Skip to content

Commit e397acb

Browse files
pierrotsmnrd, pmeier, and nenb
authored
Fix #399 : eager loading of Chats docs, msgs, srcs (#401)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Nick Byrne <byrnen8@tcd.ie>
1 parent 04168ab commit e397acb

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

ragna/deploy/_api/database.py

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from urllib.parse import urlsplit
77

88
from sqlalchemy import create_engine, select
9-
from sqlalchemy.orm import Session
9+
from sqlalchemy.orm import Session, joinedload
1010
from sqlalchemy.orm import sessionmaker as _sessionmaker
1111

1212
from ragna.core import RagnaException
@@ -136,30 +136,49 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat:
136136
)
137137

138138

139+
def _select_chat(*, eager: bool = False) -> Any:
140+
selector = select(orm.Chat)
141+
if eager:
142+
selector = selector.options( # type: ignore[attr-defined]
143+
joinedload(orm.Chat.messages).joinedload(orm.Message.sources),
144+
joinedload(orm.Chat.documents),
145+
)
146+
return selector
147+
148+
139149
def get_chats(session: Session, *, user: str) -> list[schemas.Chat]:
140150
return [
141151
_orm_to_schema_chat(chat)
142152
for chat in session.execute(
143-
select(orm.Chat).where(orm.Chat.user_id == _get_user_id(session, user))
153+
_select_chat(eager=True).where(
154+
orm.Chat.user_id == _get_user_id(session, user)
155+
)
144156
)
145157
.scalars()
158+
.unique()
146159
.all()
147160
]
148161

149162

150-
def _get_orm_chat(session: Session, *, user: str, id: uuid.UUID) -> orm.Chat:
151-
chat: Optional[orm.Chat] = session.execute(
152-
select(orm.Chat).where(
153-
(orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user))
163+
def _get_orm_chat(
164+
session: Session, *, user: str, id: uuid.UUID, eager: bool = False
165+
) -> orm.Chat:
166+
chat: Optional[orm.Chat] = (
167+
session.execute(
168+
_select_chat(eager=eager).where(
169+
(orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user))
170+
)
154171
)
155-
).scalar_one_or_none()
172+
.unique()
173+
.scalar_one_or_none()
174+
)
156175
if chat is None:
157176
raise RagnaException()
158177
return chat
159178

160179

161180
def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat:
162-
return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id))
181+
return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True))
163182

164183

165184
def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source:

ragna/deploy/_api/orm.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,9 @@ class Chat(Base):
8686
source_storage = Column(types.String)
8787
assistant = Column(types.String)
8888
params = Column(Json)
89-
messages = relationship("Message", cascade="all, delete")
89+
messages = relationship(
90+
"Message", cascade="all, delete", order_by="Message.timestamp"
91+
)
9092
prepared = Column(types.Boolean)
9193

9294

ragna/deploy/_api/schemas.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,7 @@ class Message(BaseModel):
5656
content: str
5757
role: ragna.core.MessageRole
5858
sources: list[Source] = Field(default_factory=list)
59-
timestamp: datetime.datetime = Field(
60-
default_factory=lambda: datetime.datetime.utcnow()
61-
)
59+
timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
6260

6361
@classmethod
6462
def from_core(cls, message: ragna.core.Message) -> Message:

tests/deploy/api/test_e2e.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import json
2+
import time
23

34
import pytest
45
from fastapi.testclient import TestClient
@@ -12,6 +13,10 @@
1213

1314
class TestAssistant(RagnaDemoAssistant):
1415
def answer(self, prompt, sources, *, multiple_answer_chunks: bool):
16+
# Simulate a "real" assistant through a small delay. See
17+
# https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440
18+
# for why this is needed.
19+
time.sleep(1e-3)
1520
content = next(super().answer(prompt, sources))
1621

1722
if multiple_answer_chunks:

0 commit comments

Comments
 (0)