Skip to content

Update token in do_connect event #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/fastapi_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os

import azure.identity.aio
import azure.identity
from dotenv import load_dotenv
from environs import Env
from fastapi import FastAPI
Expand All @@ -27,9 +27,9 @@ async def lifespan(app: FastAPI):
"Using managed identity for client ID %s",
client_id,
)
azure_credential = azure.identity.aio.ManagedIdentityCredential(client_id=client_id)
azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id)
else:
azure_credential = azure.identity.aio.DefaultAzureCredential()
azure_credential = azure.identity.DefaultAzureCredential()
except Exception as e:
logger.warning("Failed to authenticate to Azure: %s", e)

Expand Down
6 changes: 3 additions & 3 deletions src/fastapi_app/openai_clients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os

import azure.identity.aio
import azure.identity
import openai

logger = logging.getLogger("ragapp")
Expand All @@ -12,7 +12,7 @@ async def create_openai_chat_client(azure_credential):
if OPENAI_CHAT_HOST == "azure":
logger.info("Authenticating to OpenAI using Azure Identity...")

token_provider = azure.identity.aio.get_bearer_token_provider(
token_provider = azure.identity.get_bearer_token_provider(
azure_credential, "https://cognitiveservices.azure.com/.default"
)
openai_chat_client = openai.AsyncAzureOpenAI(
Expand Down Expand Up @@ -40,7 +40,7 @@ async def create_openai_chat_client(azure_credential):
async def create_openai_embed_client(azure_credential):
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
if OPENAI_EMBED_HOST == "azure":
token_provider = azure.identity.aio.get_bearer_token_provider(
token_provider = azure.identity.get_bearer_token_provider(
azure_credential, "https://cognitiveservices.azure.com/.default"
)
openai_embed_client = openai.AsyncAzureOpenAI(
Expand Down
36 changes: 17 additions & 19 deletions src/fastapi_app/postgres_engine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import logging
import os

from azure.identity.aio import DefaultAzureCredential
from azure.identity import DefaultAzureCredential
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

logger = logging.getLogger("ragapp")


async def create_postgres_engine(*, host, username, database, password, sslmode, azure_credential) -> AsyncEngine:
def get_password_from_azure_credential():
token = azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
return token.token

token_based_password = False
if host.endswith(".database.azure.com"):
token_based_password = True
logger.info("Authenticating to Azure Database for PostgreSQL using Azure Identity...")
if azure_credential is None:
raise ValueError("Azure credential must be provided for Azure Database for PostgreSQL")
token = await azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
password = token.token
password = get_password_from_azure_credential()
else:
logger.info("Authenticating to PostgreSQL using password...")

Expand All @@ -27,16 +33,20 @@ async def create_postgres_engine(*, host, username, database, password, sslmode,
echo=False,
)

@event.listens_for(engine.sync_engine, "do_connect")
def update_password_token(dialect, conn_rec, cargs, cparams):
if token_based_password:
logger.info("Updating password token for Azure Database for PostgreSQL")
cparams["password"] = get_password_from_azure_credential()

return engine


async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
must_close = False
if azure_credential is None and os.environ["POSTGRES_HOST"].endswith(".database.azure.com"):
azure_credential = DefaultAzureCredential()
must_close = True

engine = await create_postgres_engine(
return await create_postgres_engine(
host=os.environ["POSTGRES_HOST"],
username=os.environ["POSTGRES_USERNAME"],
database=os.environ["POSTGRES_DATABASE"],
Expand All @@ -45,28 +55,16 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
azure_credential=azure_credential,
)

if must_close:
await azure_credential.close()

return engine


async def create_postgres_engine_from_args(args, azure_credential=None) -> AsyncEngine:
must_close = False
if azure_credential is None and args.host.endswith(".database.azure.com"):
azure_credential = DefaultAzureCredential()
must_close = True

engine = await create_postgres_engine(
return await create_postgres_engine(
host=args.host,
username=args.username,
database=args.database,
password=args.password,
sslmode=args.sslmode,
azure_credential=azure_credential,
)

if must_close:
await azure_credential.close()

return engine
4 changes: 2 additions & 2 deletions src/fastapi_app/query_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
"properties": {
"comparison_operator": {
"type": "string",
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '=='", # noqa
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa
},
"value": {
"type": "number",
Expand All @@ -40,7 +40,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
"properties": {
"comparison_operator": {
"type": "string",
"description": "Operator to compare the column value, either '==' or '!='",
"description": "Operator to compare the column value, either '=' or '!='",
},
"value": {
"type": "string",
Expand Down
27 changes: 12 additions & 15 deletions src/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ aiohttp==3.9.5
# via fastapi_app (pyproject.toml)
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
annotated-types==0.7.0
# via pydantic
anyio==4.3.0
anyio==4.4.0
# via
# httpx
# openai
Expand Down Expand Up @@ -53,10 +53,8 @@ email-validator==2.1.1
environs==11.0.0
# via fastapi_app (pyproject.toml)
fastapi==0.111.0
# via
# fastapi-cli
# fastapi_app (pyproject.toml)
fastapi-cli==0.0.3
# via fastapi_app (pyproject.toml)
fastapi-cli==0.0.4
# via fastapi
frozenlist==1.4.1
# via
Expand Down Expand Up @@ -107,7 +105,7 @@ multidict==6.0.5
# yarl
numpy==1.26.4
# via pgvector
openai==1.30.1
openai==1.30.4
# via
# fastapi_app (pyproject.toml)
# openai-messages-token-helper
Expand All @@ -128,11 +126,11 @@ portalocker==2.8.2
# via msal-extensions
pycparser==2.22
# via cffi
pydantic==2.7.1
pydantic==2.7.2
# via
# fastapi
# openai
pydantic-core==2.18.2
pydantic-core==2.18.3
# via pydantic
pygments==2.18.0
# via rich
Expand All @@ -149,9 +147,9 @@ python-multipart==0.0.9
# via fastapi
pyyaml==6.0.1
# via uvicorn
regex==2024.5.10
regex==2024.5.15
# via tiktoken
requests==2.31.0
requests==2.32.2
# via
# azure-core
# msal
Expand Down Expand Up @@ -179,7 +177,7 @@ tqdm==4.66.4
# via openai
typer==0.12.3
# via fastapi-cli
typing-extensions==4.11.0
typing-extensions==4.12.0
# via
# azure-core
# fastapi
Expand All @@ -192,14 +190,13 @@ ujson==5.10.0
# via fastapi
urllib3==2.2.1
# via requests
uvicorn[standard]==0.29.0
uvicorn[standard]==0.30.0
# via
# fastapi
# fastapi-cli
# fastapi_app (pyproject.toml)
uvloop==0.19.0
# via uvicorn
watchfiles==0.21.0
watchfiles==0.22.0
# via uvicorn
websockets==12.0
# via uvicorn
Expand Down