Skip to content

Commit 082fc64

Browse files
author
nik
committed
Fix feedback table in server
1 parent 66406f4 commit 082fc64

File tree

3 files changed

+80
-77
lines changed

3 files changed

+80
-77
lines changed

adala/environments/servers/base.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,11 @@
1010
STORAGE_DB = 'feedback.db'
1111

1212

13-
class GroundTruth(BaseModel):
13+
class Feedback(BaseModel):
1414
prediction_id: int
15-
skill_output: str
16-
gt_match: Optional[bool] = None
17-
gt_data: Optional[str] = None
18-
19-
20-
class Prediction(BaseModel):
21-
id: int
22-
input: Dict[str, Any]
23-
skill_name: str
24-
output: str
15+
prediction_column: str
16+
fb_match: Optional[bool] = None
17+
fb_message: Optional[str] = None
2518

2619

2720
router = APIRouter()
@@ -38,34 +31,34 @@ async def init_db(self):
3831
print(f'Initializing database {STORAGE_DB}...')
3932
async with aiosqlite.connect(STORAGE_DB) as db:
4033
await db.execute('''
41-
CREATE TABLE IF NOT EXISTS ground_truth (
34+
CREATE TABLE IF NOT EXISTS feedback (
4235
prediction_id INTEGER NOT NULL,
43-
skill_name TEXT NOT NULL,
44-
gt_match BOOLEAN,
45-
gt_data TEXT,
46-
PRIMARY KEY (prediction_id, skill_name)
36+
prediction_column TEXT NOT NULL,
37+
fb_match BOOLEAN,
38+
fb_message TEXT,
39+
PRIMARY KEY (prediction_id, prediction_column)
4740
)
4841
''')
4942
await db.commit()
5043

5144
async def request_feedback(
5245
self,
53-
predictions: List[Prediction],
46+
predictions: List[Dict[str, Any]],
5447
skills: List[Dict[str, Any]],
5548
db: aiosqlite.Connection
5649
):
5750
raise NotImplementedError
5851

59-
async def retrieve_ground_truth(self, db: aiosqlite.Connection):
60-
cursor = await db.execute('SELECT prediction_id, skill_name, gt_match, gt_data FROM ground_truth')
52+
async def retrieve_feedback(self, db: aiosqlite.Connection):
53+
cursor = await db.execute('SELECT prediction_id, prediction_column, fb_match, fb_message FROM feedback')
6154
rows = await cursor.fetchall()
62-
return [GroundTruth(prediction_id=row[0], skill_name=row[1], gt_match=row[2], gt_data=row[3]) for row in rows]
55+
return [Feedback(prediction_id=row[0], prediction_column=row[1], fb_match=row[2], fb_message=row[3]) for row in rows]
6356

64-
async def store_ground_truths(self, ground_truths: List[GroundTruth], db: aiosqlite.Connection):
57+
async def store_feedback(self, feedbacks: List[Feedback], db: aiosqlite.Connection):
6558
await db.executemany('''
66-
INSERT OR REPLACE INTO ground_truth (prediction_id, skill_name, gt_match, gt_data)
59+
INSERT OR REPLACE INTO feedback (prediction_id, prediction_column, fb_match, fb_message)
6760
VALUES (?, ?, ?, ?)
68-
''', [(gt.prediction_id, gt.skill_name, gt.gt_match, gt.gt_data) for gt in ground_truths])
61+
''', [(fb.prediction_id, fb.prediction_column, fb.fb_match, fb.fb_message) for fb in feedbacks])
6962
await db.commit()
7063

7164

@@ -79,23 +72,23 @@ async def get_db() -> aiosqlite.Connection:
7972
app = BaseAPI()
8073

8174

82-
@router.post("/feedback")
83-
async def create_feedback(
75+
@router.post("/request-feedback")
76+
async def request_feedback(
8477
request: Request,
8578
predictions: List[Dict[str, Any]],
8679
skills: List[Dict[str, Any]],
8780
db: aiosqlite.Connection = Depends(get_db)
8881
):
8982
app = request.app
9083
await app.request_feedback(predictions, skills, db)
91-
return {"message": "Feedback received successfully"}
84+
return {"message": "Feedback requested successfully"}
9285

9386

94-
@router.get("/ground-truth", response_model=List[GroundTruth])
95-
async def get_ground_truth(request: Request, db: aiosqlite.Connection = Depends(get_db)):
87+
@router.get("/feedback", response_model=List[Feedback])
88+
async def get_feedback(request: Request, db: aiosqlite.Connection = Depends(get_db)):
9689
app = request.app
97-
ground_truths = await app.retrieve_ground_truth(db)
98-
return ground_truths
90+
fb = await app.retrieve_feedback(db)
91+
return fb
9992

10093

10194
@app.on_event("startup")

adala/environments/servers/discord_bot.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import List, Dict, Any
88
from discord.ext import commands
99
from discord.ui import View
10-
from adala.environments.servers.base import BaseAPI, Prediction, GroundTruth, STORAGE_DB
10+
from adala.environments.servers.base import BaseAPI, Feedback, STORAGE_DB
1111

1212

1313
intents = discord.Intents.default()
@@ -43,17 +43,17 @@ async def on_message(message):
4343
initial_message_id = message.reference.message_id
4444
print(f'Got reply in thread for message id: {initial_message_id}', message)
4545
async with aiosqlite.connect(STORAGE_DB) as db:
46-
async with db.execute('SELECT * FROM discord_gt_message WHERE message_id = ?', (initial_message_id,)) as cursor:
46+
async with db.execute('SELECT * FROM discord_fb_message WHERE message_id = ?', (initial_message_id,)) as cursor:
4747
row = await cursor.fetchone()
4848
if row is None:
4949
print(f'No ground truth message found for message id: {initial_message_id}')
5050
return
51-
prediction_id, skill_name = int(row[1]), row[2]
51+
prediction_id, prediction_column = int(row[1]), row[2]
5252

5353
# update ground truth with reply
54-
await db.execute('UPDATE ground_truth SET gt_data = ? '
55-
'WHERE prediction_id = ? AND skill_name = ?',
56-
(message.content, prediction_id, skill_name))
54+
await db.execute('UPDATE feedback SET fb_message = ? '
55+
'WHERE prediction_id = ? AND prediction_column = ?',
56+
(message.content, prediction_id, prediction_column))
5757
await db.commit()
5858

5959
# Process other messages normally
@@ -63,46 +63,46 @@ async def on_message(message):
6363
@bot.event
6464
async def on_interaction(interaction: discord.Interaction):
6565

66-
async def update_ground_truth_match(prediction_id: int, skill_name: str, match: bool):
66+
async def update_feedback_match(prediction_id: int, prediction_column: str, match: bool):
6767
async with aiosqlite.connect(STORAGE_DB) as db:
68-
await db.execute('UPDATE ground_truth SET gt_match = ?, gt_data = NULL '
69-
'WHERE prediction_id = ? AND skill_name = ?',
70-
(match, prediction_id, skill_name))
68+
await db.execute('UPDATE feedback SET fb_match = ?, fb_message = NULL '
69+
'WHERE prediction_id = ? AND prediction_column = ?',
70+
(match, prediction_id, prediction_column))
7171
await db.commit()
7272
print(f'Updated ground truth for prediction id: {prediction_id} with match: {match}')
7373

7474
if interaction.type == discord.InteractionType.component:
7575
await interaction.response.defer(ephemeral=True)
7676

7777
custom_id = interaction.data['custom_id']
78-
action, prediction_id_str, skill_name = custom_id.split(':')
78+
action, prediction_id_str, prediction_column = custom_id.split(':')
7979
prediction_id = int(prediction_id_str) # Convert prediction_id to int
8080

8181
if action == 'accept':
8282
# Handle the accept action
83-
await update_ground_truth_match(prediction_id, skill_name, True)
83+
await update_feedback_match(prediction_id, prediction_column, True)
8484
# React with a checkmark emoji to the message
8585
await interaction.message.add_reaction('✅')
8686
elif action == 'reject':
8787
# Handle the reject action
88-
await update_ground_truth_match(prediction_id, skill_name, False)
88+
await update_feedback_match(prediction_id, prediction_column, False)
8989
# React with a cross mark emoji to the message
9090
await interaction.message.add_reaction('❌')
9191

9292

9393
class AcceptRejectView(View):
9494

95-
def __init__(self, prediction_id: int, skill_name: str, *args, **kwargs):
95+
def __init__(self, prediction_id: int, prediction_column: str, *args, **kwargs):
9696
super().__init__(*args, **kwargs)
9797
self.add_item(discord.ui.Button(
9898
label='Accept',
9999
style=discord.ButtonStyle.success,
100-
custom_id=f'accept:{prediction_id}:{skill_name}'
100+
custom_id=f'accept:{prediction_id}:{prediction_column}'
101101
))
102102
self.add_item(discord.ui.Button(
103103
label='Reject',
104104
style=discord.ButtonStyle.danger,
105-
custom_id=f'reject:{prediction_id}:{skill_name}'
105+
custom_id=f'reject:{prediction_id}:{prediction_column}'
106106
))
107107

108108
async def interaction_check(self, interaction: discord.Interaction) -> bool:
@@ -115,10 +115,10 @@ class DiscordAPI(BaseAPI):
115115
async def init_db_gt_message(self):
116116
async with aiosqlite.connect(STORAGE_DB) as db:
117117
await db.execute('''
118-
CREATE TABLE IF NOT EXISTS discord_gt_message (
118+
CREATE TABLE IF NOT EXISTS discord_fb_message (
119119
id INTEGER PRIMARY KEY AUTOINCREMENT,
120120
prediction_id INTEGER NOT NULL,
121-
skill_name TEXT NOT NULL,
121+
prediction_column TEXT NOT NULL,
122122
message_id INTEGER NOT NULL
123123
)
124124
''')
@@ -140,29 +140,29 @@ async def request_feedback(
140140
channel = bot.get_channel(CHANNEL_ID)
141141
if not channel:
142142
raise Exception(f'Channel with id {CHANNEL_ID} not found')
143-
ground_truths = []
143+
fbs = []
144144
skill_outputs = sum([skill['outputs'] for skill in skills], [])
145145
for skill_output in skill_outputs:
146146
for prediction in predictions:
147147
text = '========================\n'
148148
text += '\n'.join(f'**{k}**: {v}' for k, v in prediction.items() if k not in skill_outputs + ['index'])
149149
text += f'\n\n__**{skill_output}**__: {prediction[skill_output]}'
150-
ground_truth = GroundTruth(prediction_id=prediction['index'], skill_name=skill_output)
151-
150+
text = text[:2000]
151+
fb = Feedback(prediction_id=prediction['index'], prediction_column=skill_output)
152152
message = await channel.send(
153153
text, view=AcceptRejectView(
154-
prediction_id=ground_truth.prediction_id,
155-
skill_name=ground_truth.skill_name)
154+
prediction_id=fb.prediction_id,
155+
prediction_column=fb.prediction_column)
156156
)
157-
ground_truths.append(ground_truth)
157+
fbs.append(fb)
158158
await db.execute('''
159-
INSERT INTO discord_gt_message (prediction_id, skill_name, message_id)
159+
INSERT INTO discord_fb_message (prediction_id, prediction_column, message_id)
160160
VALUES (?, ?, ?)
161-
''', (ground_truth.prediction_id, ground_truth.skill_name, message.id))
161+
''', (fb.prediction_id, fb.prediction_column, message.id))
162162
await db.commit()
163163

164164
# TODO: do we need to store it in advance?
165-
await self.store_ground_truths(ground_truths, db)
165+
await self.store_feedback(fbs, db)
166166

167167

168168
app = DiscordAPI()

adala/environments/web.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import requests
22
import time
33
from typing import Optional
4-
from .base import StaticEnvironment
5-
from .servers.base import GroundTruth
4+
from .base import StaticEnvironment, EnvironmentFeedback
5+
from .servers.base import Feedback
66
from adala.skills import SkillSet
77
from adala.utils.internal_data import InternalDataFrame, InternalSeries
88
from collections import defaultdict
@@ -13,17 +13,22 @@ class WebStaticEnvironment(StaticEnvironment):
1313
"""
1414
Web environment interacts with server API to request feedback and retrieve ground truth.
1515
Following endpoints are expected:
16-
- POST /feedback
17-
- GET /ground-truth
16+
- POST /request-feedback
17+
- GET /feedback
1818
"""
1919
url: str
2020

21+
def _get_fb_records(self):
22+
fb_records = requests.get(f'{self.url}/feedback', timeout=3).json()
23+
fb_records = [Feedback(**r) for r in fb_records]
24+
return fb_records
25+
2126
def get_feedback(
2227
self,
2328
skills: SkillSet,
2429
predictions: InternalDataFrame,
2530
num_feedbacks: Optional[int] = None,
26-
):
31+
) -> EnvironmentFeedback:
2732
"""
2833
Request feedback for the predictions.
2934
@@ -45,30 +50,35 @@ def get_feedback(
4550
'predictions': predictions.reset_index().to_dict(orient='records')
4651
}
4752

48-
requests.post(f'{self.url}/feedback', json=payload, timeout=3)
53+
requests.post(f'{self.url}/request-feedback', json=payload, timeout=3)
4954

5055
# wait for feedback
5156
with Progress() as progress:
5257
task = progress.add_task(f"Waiting for feedback...", total=3600)
53-
gt_records = []
54-
while len(gt_records) < num_feedbacks:
58+
fb_records = []
59+
while len(fb_records) < num_feedbacks:
5560
progress.advance(task, 10)
5661
time.sleep(10)
57-
gt_records = self.get_gt_records()
62+
fb_records = self._get_fb_records()
63+
print('ZZZZ', fb_records)
5864

59-
if not gt_records:
65+
if not fb_records:
6066
raise RuntimeError('No ground truth found.')
6167

62-
gt = defaultdict(dict)
63-
for g in gt_records:
64-
gt[g.skill_output][g.prediction_id] = g.gt_data or True
68+
match = defaultdict(list)
69+
feedback = defaultdict(list)
70+
index = []
71+
for f in fb_records:
72+
match[f.prediction_column].append(f.fb_match)
73+
feedback[f.prediction_column].append(f.fb_message)
74+
index.append(f.prediction_id)
75+
76+
print(11111, match)
77+
print(2222, feedback)
78+
print(3333, fb_records)
6579

66-
df = InternalDataFrame({skill: InternalSeries(g) for skill, g in gt.items()})
80+
match = InternalDataFrame(match, index=index)
81+
feedback = InternalDataFrame(feedback, index=index)
6782

68-
return df
6983

70-
def get_gt_records(self):
71-
gt_records = requests.get(f'{self.url}/ground-truth', timeout=3).json()
72-
gt_records = [GroundTruth(**r) for r in gt_records]
73-
gt_records = [r for r in gt_records if r.gt_data or r.gt_match]
74-
return gt_records
84+
return EnvironmentFeedback(match=match, feedback=feedback)

0 commit comments

Comments
 (0)