From fcbb9187a242a06284a44b089f4bf04378917b74 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Fri, 17 May 2024 04:07:30 -0700 Subject: [PATCH 1/2] Configure Azure Developer Pipeline From eb9ec6d642a65c48458d0800482f3327becac2aa Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 5 Jun 2024 05:33:21 +0000 Subject: [PATCH 2/2] Token improvements --- src/fastapi_app/__init__.py | 6 ++--- src/fastapi_app/openai_clients.py | 6 ++--- src/fastapi_app/postgres_engine.py | 36 ++++++++++++++---------------- src/fastapi_app/query_rewriter.py | 4 ++-- src/requirements.txt | 27 ++++++++++------------ 5 files changed, 37 insertions(+), 42 deletions(-) diff --git a/src/fastapi_app/__init__.py b/src/fastapi_app/__init__.py index 969a4271..de1d0fc8 100644 --- a/src/fastapi_app/__init__.py +++ b/src/fastapi_app/__init__.py @@ -2,7 +2,7 @@ import logging import os -import azure.identity.aio +import azure.identity from dotenv import load_dotenv from environs import Env from fastapi import FastAPI @@ -27,9 +27,9 @@ async def lifespan(app: FastAPI): "Using managed identity for client ID %s", client_id, ) - azure_credential = azure.identity.aio.ManagedIdentityCredential(client_id=client_id) + azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id) else: - azure_credential = azure.identity.aio.DefaultAzureCredential() + azure_credential = azure.identity.DefaultAzureCredential() except Exception as e: logger.warning("Failed to authenticate to Azure: %s", e) diff --git a/src/fastapi_app/openai_clients.py b/src/fastapi_app/openai_clients.py index c3d4fd2d..a3fcb1cd 100644 --- a/src/fastapi_app/openai_clients.py +++ b/src/fastapi_app/openai_clients.py @@ -1,7 +1,7 @@ import logging import os -import azure.identity.aio +import azure.identity import openai logger = logging.getLogger("ragapp") @@ -12,7 +12,7 @@ async def create_openai_chat_client(azure_credential): if OPENAI_CHAT_HOST == "azure": logger.info("Authenticating to OpenAI using Azure Identity...") - token_provider = azure.identity.aio.get_bearer_token_provider( + token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) openai_chat_client = openai.AsyncAzureOpenAI( @@ -40,7 +40,7 @@ async def create_openai_chat_client(azure_credential): async def create_openai_embed_client(azure_credential): OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST") if OPENAI_EMBED_HOST == "azure": - token_provider = azure.identity.aio.get_bearer_token_provider( + token_provider = azure.identity.get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) openai_embed_client = openai.AsyncAzureOpenAI( diff --git a/src/fastapi_app/postgres_engine.py b/src/fastapi_app/postgres_engine.py index 54d159e5..0dcaf814 100644 --- a/src/fastapi_app/postgres_engine.py +++ b/src/fastapi_app/postgres_engine.py @@ -1,19 +1,25 @@ import logging import os -from azure.identity.aio import DefaultAzureCredential +from azure.identity import DefaultAzureCredential +from sqlalchemy import event from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine logger = logging.getLogger("ragapp") async def create_postgres_engine(*, host, username, database, password, sslmode, azure_credential) -> AsyncEngine: + def get_password_from_azure_credential(): + token = azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default") + return token.token + + token_based_password = False if host.endswith(".database.azure.com"): + token_based_password = True logger.info("Authenticating to Azure Database for PostgreSQL using Azure Identity...") if azure_credential is None: raise ValueError("Azure credential must be provided for Azure Database for PostgreSQL") - token = await azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default") - password = token.token + password = get_password_from_azure_credential() else: logger.info("Authenticating to PostgreSQL using password...") @@ -27,16 +33,20 @@ async def create_postgres_engine(*, host, username, database, password, sslmode, echo=False, ) + @event.listens_for(engine.sync_engine, "do_connect") + def update_password_token(dialect, conn_rec, cargs, cparams): + if token_based_password: + logger.info("Updating password token for Azure Database for PostgreSQL") + cparams["password"] = get_password_from_azure_credential() + return engine async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine: - must_close = False if azure_credential is None and os.environ["POSTGRES_HOST"].endswith(".database.azure.com"): azure_credential = DefaultAzureCredential() - must_close = True - engine = await create_postgres_engine( + return await create_postgres_engine( host=os.environ["POSTGRES_HOST"], username=os.environ["POSTGRES_USERNAME"], database=os.environ["POSTGRES_DATABASE"], @@ -45,19 +55,12 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine: azure_credential=azure_credential, ) - if must_close: - await azure_credential.close() - - return engine - async def create_postgres_engine_from_args(args, azure_credential=None) -> AsyncEngine: - must_close = False if azure_credential is None and args.host.endswith(".database.azure.com"): azure_credential = DefaultAzureCredential() - must_close = True - engine = await create_postgres_engine( + return await create_postgres_engine( host=args.host, username=args.username, database=args.database, @@ -65,8 +68,3 @@ async def create_postgres_engine_from_args(args, azure_credential=None) -> Async sslmode=args.sslmode, azure_credential=azure_credential, ) - - if must_close: - await azure_credential.close() - - return engine diff --git a/src/fastapi_app/query_rewriter.py b/src/fastapi_app/query_rewriter.py index 7eef8989..9cf4fffe 100644 --- a/src/fastapi_app/query_rewriter.py +++ b/src/fastapi_app/query_rewriter.py @@ -26,7 +26,7 @@ def build_search_function() -> list[ChatCompletionToolParam]: "properties": { "comparison_operator": { "type": "string", - "description": "Operator to compare the column value, either '>', '<', '>=', '<=', '=='", # noqa + "description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa }, "value": { "type": "number", @@ -40,7 +40,7 @@ def build_search_function() -> list[ChatCompletionToolParam]: "properties": { "comparison_operator": { "type": "string", - "description": "Operator to compare the column value, either '==' or '!='", + "description": "Operator to compare the column value, either '=' or '!='", }, "value": { "type": "string", diff --git a/src/requirements.txt b/src/requirements.txt index 4ae104d4..04daddd3 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -8,9 +8,9 @@ aiohttp==3.9.5 # via fastapi_app (pyproject.toml) aiosignal==1.3.1 # via aiohttp -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.3.0 +anyio==4.4.0 # via # httpx # openai @@ -53,10 +53,8 @@ email-validator==2.1.1 environs==11.0.0 # via fastapi_app (pyproject.toml) fastapi==0.111.0 - # via - # fastapi-cli - # fastapi_app (pyproject.toml) -fastapi-cli==0.0.3 + # via fastapi_app (pyproject.toml) +fastapi-cli==0.0.4 # via fastapi frozenlist==1.4.1 # via @@ -107,7 +105,7 @@ multidict==6.0.5 # yarl numpy==1.26.4 # via pgvector -openai==1.30.1 +openai==1.30.4 # via # fastapi_app (pyproject.toml) # openai-messages-token-helper @@ -128,11 +126,11 @@ portalocker==2.8.2 # via msal-extensions pycparser==2.22 # via cffi -pydantic==2.7.1 +pydantic==2.7.2 # via # fastapi # openai -pydantic-core==2.18.2 +pydantic-core==2.18.3 # via pydantic pygments==2.18.0 # via rich @@ -149,9 +147,9 @@ python-multipart==0.0.9 # via fastapi pyyaml==6.0.1 # via uvicorn -regex==2024.5.10 +regex==2024.5.15 # via tiktoken -requests==2.31.0 +requests==2.32.2 # via # azure-core # msal @@ -179,7 +177,7 @@ tqdm==4.66.4 # via openai typer==0.12.3 # via fastapi-cli -typing-extensions==4.11.0 +typing-extensions==4.12.0 # via # azure-core # fastapi @@ -192,14 +190,13 @@ ujson==5.10.0 # via fastapi urllib3==2.2.1 # via requests -uvicorn[standard]==0.29.0 +uvicorn[standard]==0.30.0 # via # fastapi - # fastapi-cli # fastapi_app (pyproject.toml) uvloop==0.19.0 # via uvicorn -watchfiles==0.21.0 +watchfiles==0.22.0 # via uvicorn websockets==12.0 # via uvicorn