Skip to content

Commit ae0871d

Browse files
committed
Make mypy happy
1 parent 7632736 commit ae0871d

File tree

7 files changed

+20
-12
lines changed

7 files changed

+20
-12
lines changed

src/backend/fastapi_app/embeddings.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ async def compute_text_embedding(
1010
openai_client: AsyncOpenAI | AsyncAzureOpenAI,
1111
embed_model: str,
1212
embed_deployment: str | None = None,
13-
embedding_dimensions: int = 1536,
13+
embedding_dimensions: int | None = None,
1414
) -> list[float]:
1515
SUPPORTED_DIMENSIONS_MODEL = {
1616
"text-embedding-ada-002": False,
@@ -21,9 +21,12 @@ async def compute_text_embedding(
2121
class ExtraArgs(TypedDict, total=False):
2222
dimensions: int
2323

24-
dimensions_args: ExtraArgs = (
25-
{"dimensions": embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL.get(embed_model) else {}
26-
)
24+
dimensions_args: ExtraArgs = {}
25+
if SUPPORTED_DIMENSIONS_MODEL.get(embed_model):
26+
if embedding_dimensions is None:
27+
raise ValueError(f"Model {embed_model} requires embedding dimensions")
28+
else:
29+
dimensions_args = {"dimensions": embedding_dimensions}
2730

2831
embedding = await openai_client.embeddings.create(
2932
# Azure OpenAI takes the deployment name as the model name

src/backend/fastapi_app/postgres_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_password_from_azure_credential():
3030

3131
engine = create_async_engine(
3232
DATABASE_URI,
33-
echo=False,
33+
echo=True,
3434
)
3535

3636
@event.listens_for(engine.sync_engine, "do_connect")

src/backend/fastapi_app/postgres_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Item(Base):
2121
description: Mapped[str] = mapped_column()
2222
price: Mapped[float] = mapped_column()
2323
embedding_ada002: Mapped[Vector] = mapped_column(Vector(1536)) # ada-002
24-
embedding_nomic: Mapped[Vector] | None = mapped_column(Vector(768), nullable=True) # nomic-embed-text
24+
embedding_nomic: Mapped[Vector] = mapped_column(Vector(768)) # nomic-embed-text
2525

2626
def to_dict(self, include_embedding: bool = False):
2727
model_dict = asdict(self)

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
1515
embed_deployment: str | None, # Not needed for non-Azure OpenAI or for retrieval_mode="text"
1616
embed_model: str,
17-
embed_dimensions: int,
17+
embed_dimensions: int | None,
1818
embedding_column: str,
1919
):
2020
self.db_session = db_session

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,18 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
4545

4646

4747
@router.get("/similar", response_model=list[ItemWithDistance])
48-
async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
48+
async def similar_handler(
49+
context: CommonDeps, database_session: DBSession, id: int, n: int = 5
50+
) -> list[ItemWithDistance]:
4951
"""A similarity API to find items similar to items with given ID."""
5052
item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
5153
if not item:
5254
raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)
55+
5356
closest = await database_session.execute(
54-
select(Item, Item.embedding.l2_distance(item.embedding))
57+
select(Item, Item.embedding_ada002.l2_distance(item.embedding_ada002))
5558
.filter(Item.id != id)
56-
.order_by(Item.embedding.l2_distance(item.embedding))
59+
.order_by(Item.embedding_ada002.l2_distance(item.embedding_ada002))
5760
.limit(n)
5861
)
5962
return [
@@ -78,6 +81,7 @@ async def search_handler(
7881
embed_deployment=context.openai_embed_deployment,
7982
embed_model=context.openai_embed_model,
8083
embed_dimensions=context.openai_embed_dimensions,
84+
embedding_column=context.embedding_column,
8185
)
8286
results = await searcher.search_and_embed(
8387
query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search

src/backend/fastapi_app/update_embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ async def update_embeddings(in_seed_data=False):
6363

6464
async with async_sessionmaker(engine, expire_on_commit=False)() as session:
6565
async with session.begin():
66-
items = (await session.scalars(select(Item))).all()
66+
items_to_update = (await session.scalars(select(Item))).all()
6767

68-
for item in items:
68+
for item in items_to_update:
6969
setattr(
7070
item,
7171
embedding_column,

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,5 @@ async def postgres_searcher(mock_session_env, mock_default_azure_credential, db_
263263
embed_deployment="text-embedding-ada-002",
264264
embed_model="text-embedding-ada-002",
265265
embed_dimensions=1536,
266+
embedding_column="embedding_ada002",
266267
)

0 commit comments

Comments
 (0)