Skip to content

Commit

Permalink
Integrate fastapi-sqla: tests setup and minor fixes (AlmaLinux/build-…
Browse files Browse the repository at this point in the history
…system#230)

- Test fixtures changed for fastapi_sqla
- Dynamic configuring of async_db key with get_async_db_key() dependency
- Dramatiq tasks' fastapi_sqla_setup in event_loop
  • Loading branch information
bklvsky committed Apr 10, 2024
1 parent 4e1e268 commit 9094fa6
Show file tree
Hide file tree
Showing 66 changed files with 454 additions and 358 deletions.
6 changes: 3 additions & 3 deletions alws/auth/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from alws.config import settings
# from alws.dependencies import get_async_session
from alws.dependencies import get_async_db_key
from alws.models import User, UserAccessToken, UserOauthAccount
from fastapi_sqla import AsyncSessionDependency

Expand All @@ -22,14 +22,14 @@


async def get_user_db(
session: AsyncSession = Depends(AsyncSessionDependency(key="async"))
session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key()))
):
yield SQLAlchemyUserDatabase(
session, User, oauth_account_table=UserOauthAccount)


async def get_access_token_db(
session: AsyncSession = Depends(AsyncSessionDependency(key="async")),
session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())),
):
yield SQLAlchemyAccessTokenDatabase(session, UserAccessToken)

Expand Down
10 changes: 7 additions & 3 deletions alws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,22 @@ class Settings(BaseSettings):
)
fastapi_sqla__async__sqlalchemy_echo_pool: bool = True

fastapi_sqla__sync__sqlalchemy_url: str = (
sqlalchemy_url: str = (
'postgresql+psycopg2://postgres:password@db/almalinux-bs'
)
fastapi_sqla__sync__sqlalchemy_pool_pre_ping: bool = True
fastapi_sqla__sync__sqlalchemy_pool_recycle: int = 3600
sqlalchemy_pool_pre_ping: bool = True
sqlalchemy_pool_recycle: int = 3600

fastapi_sqla__pulp__sqlalchemy_url: str = (
'postgresql+psycopg2://postgres:password@pulp:5432/pulp'
)
fastapi_sqla__pulp__sqlalchemy_pool_pre_ping: bool = True
fastapi_sqla__pulp__sqlalchemy_pool_recycle: int = 3600

fastapi_sqla__test__sqlalchemy_url: str = (
'postgresql+asyncpg://postgres:password@db/test-almalinux-bs'
)

github_client: str
github_client_secret: str

Expand Down
8 changes: 4 additions & 4 deletions alws/crud/errata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ErrataReleaseStatus,
GitHubIssueStatus,
)
# from alws.dependencies import get_db, get_pulp_db
from alws.dependencies import get_async_db_key
from alws.pulp_models import (
RpmPackage,
UpdateCollection,
Expand Down Expand Up @@ -1374,7 +1374,7 @@ async def release_errata_record(record_id: str, platform_id: int, force: bool):
settings.pulp_password,
)
# async with asynccontextmanager(get_db)() as session:
async with open_async_session(key="async") as session:
async with open_async_session(key=get_async_db_key()) as session:
session: AsyncSession
query = generate_query_for_release([record_id])
query = query.filter(models.NewErrataRecord.platform_id == platform_id)
Expand Down Expand Up @@ -1447,7 +1447,7 @@ async def bulk_errata_records_release(records_ids: List[str]):
release_tasks = []
repos_to_publish = []
# async with asynccontextmanager(get_db)() as session:
async with open_async_session(key="async") as session:
async with open_async_session(key=get_async_db_key()) as session:
await session.execute(
update(models.NewErrataRecord)
.where(models.NewErrataRecord.id.in_(records_ids))
Expand All @@ -1459,7 +1459,7 @@ async def bulk_errata_records_release(records_ids: List[str]):
# await session.commit() # auto commit on exit of the contextmanager

# async with asynccontextmanager(get_db)() as session:
async with open_async_session(key="async") as session:
async with open_async_session(key=get_async_db_key()) as session:
session: AsyncSession
db_records = await session.execute(
generate_query_for_release(records_ids),
Expand Down
3 changes: 2 additions & 1 deletion alws/crud/sign_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from alws.config import settings
from alws.constants import GenKeyStatus, SignStatus
from alws.crud.user import get_user
from alws.dependencies import get_async_db_key
# from alws.database import Session
from alws.errors import (
BuildAlreadySignedError,
Expand Down Expand Up @@ -506,7 +507,7 @@ async def __failed_post_processing(
srpms_mapping = defaultdict(list)

logging.info("Start processing task %s", sign_task_id)
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
builds = await db.execute(
select(models.Build)
.where(models.Build.id == payload.build_id)
Expand Down
3 changes: 3 additions & 0 deletions alws/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ async def get_redis() -> aioredis.Redis:
yield client
finally:
await client.close()

def get_async_db_key() -> str:
return "async"
10 changes: 5 additions & 5 deletions alws/dramatiq/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from alws.crud import build_node as build_node_crud
from alws.crud import test
from sqlalchemy.orm import Session
# from alws.dependencies import get_db
from alws.dependencies import get_async_db_key
from alws.dramatiq import event_loop
from alws.errors import (
ArtifactConversionError,
Expand Down Expand Up @@ -74,7 +74,6 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate):
))
module_build_index = {}

await setup_all()
if has_modules:
# with SyncSession() as db, db.begin():

Expand Down Expand Up @@ -106,7 +105,7 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate):
# db.commit()
# db.close()

async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
# async with db.begin():
build = await fetch_build(db, build_id)
planner = BuildPlanner(
Expand Down Expand Up @@ -166,8 +165,7 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate):

async def _build_done(request: build_node_schema.BuildDone):
# async for db in get_db():
await setup_all()
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
try:
await build_node_crud.safe_build_done(db, request)
except Exception as e:
Expand Down Expand Up @@ -293,6 +291,7 @@ async def _all_build_tasks_completed(
)
def start_build(build_id: int, build_request: Dict[str, Any]):
parsed_build = build_schema.BuildCreate(**build_request)
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_start_build(build_id, parsed_build))


Expand All @@ -311,4 +310,5 @@ def start_build(build_id: int, build_request: Dict[str, Any]):
)
def build_done(request: Dict[str, Any]):
parsed_build = build_node_schema.BuildDone(**request)
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_build_done(parsed_build))
4 changes: 2 additions & 2 deletions alws/dramatiq/errata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
__all__ = ["release_errata"]

async def _release_errata_record(record_id: str, platform_id: int, force: bool):
await setup_all()
await release_errata_record(
record_id,
platform_id,
force,
)

async def _bulk_errata_records_release(records_ids: typing.List[str]):
await setup_all()
await bulk_errata_records_release(records_ids)

@dramatiq.actor(
Expand All @@ -28,6 +26,7 @@ async def _bulk_errata_records_release(records_ids: typing.List[str]):
time_limit=DRAMATIQ_TASK_TIMEOUT,
)
def release_errata(record_id: str, platform_id: int, force: bool):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(
_release_errata_record(
record_id,
Expand All @@ -44,4 +43,5 @@ def release_errata(record_id: str, platform_id: int, force: bool):
time_limit=DRAMATIQ_TASK_TIMEOUT,
)
def bulk_errata_release(records_ids: typing.List[str]):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_bulk_errata_records_release(records_ids))
5 changes: 3 additions & 2 deletions alws/dramatiq/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from alws.config import settings
from alws.constants import DRAMATIQ_TASK_TIMEOUT, BuildTaskStatus
# from alws.database import Session
from alws.dependencies import get_async_db_key
from alws.dramatiq import event_loop
from alws.utils.log_utils import setup_logger
from alws.utils.pulp_client import PulpClient
Expand Down Expand Up @@ -276,9 +277,8 @@ async def _perform_product_modification(
build_id,
product_id,
)
await setup_all()
# async with Session() as db, db.begin():
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
db_product = (
(
await db.execute(
Expand Down Expand Up @@ -388,6 +388,7 @@ def perform_product_modification(
product_id: int,
modification: str,
):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(
_perform_product_modification(build_id, product_id, modification)
)
10 changes: 5 additions & 5 deletions alws/dramatiq/releases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,22 @@
from fastapi_sqla import open_async_session
from alws.constants import DRAMATIQ_TASK_TIMEOUT
from alws.crud import release as r_crud
from alws.dependencies import get_async_db_key
from alws.dramatiq import event_loop
# from alws.dependencies import get_db
from alws.utils.fastapi_sqla_setup import setup_all


__all__ = ["execute_release_plan"]


async def _commit_release(release_id, user_id):
await setup_all()
# async with asynccontextmanager(get_db)() as db:
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
await r_crud.commit_release(db, release_id, user_id)


async def _revert_release(release_id, user_id):
await setup_all()
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
await r_crud.revert_release(db, release_id, user_id)


Expand All @@ -33,6 +31,7 @@ async def _revert_release(release_id, user_id):
time_limit=DRAMATIQ_TASK_TIMEOUT,
)
def execute_release_plan(release_id: int, user_id: int):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_commit_release(release_id, user_id))


Expand All @@ -43,4 +42,5 @@ def execute_release_plan(release_id: int, user_id: int):
time_limit=DRAMATIQ_TASK_TIMEOUT,
)
def revert_release(release_id: int, user_id: int):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_revert_release(release_id, user_id))
2 changes: 1 addition & 1 deletion alws/dramatiq/sign_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
async def _complete_sign_task(
task_id: int, payload: typing.Dict[str, typing.Any]
):
await setup_all()
await sign_task.complete_sign_task(
task_id, sign_schema.SignTaskComplete(**payload)
)
Expand All @@ -30,4 +29,5 @@ async def _complete_sign_task(
time_limit=DRAMATIQ_TASK_TIMEOUT,
)
def complete_sign_task(task_id: int, payload: typing.Dict[str, typing.Any]):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_complete_sign_task(task_id, payload))
5 changes: 3 additions & 2 deletions alws/dramatiq/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from alws.constants import DRAMATIQ_TASK_TIMEOUT, TestTaskStatus
from alws.crud import test as t_crud
# from alws.database import Session
from alws.dependencies import get_async_db_key
from alws.dramatiq import event_loop
from alws.schemas.test_schema import TestTaskResult
from alws.utils.fastapi_sqla_setup import setup_all
Expand All @@ -16,8 +17,7 @@


async def _complete_test_task(task_id: int, task_result: TestTaskResult):
await setup_all()
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
try:
logging.info('Start processing test task %s', task_id)
await t_crud.complete_test_task(db, task_id, task_result)
Expand All @@ -41,5 +41,6 @@ async def _complete_test_task(task_id: int, task_result: TestTaskResult):
time_limit=DRAMATIQ_TASK_TIMEOUT
)
def complete_test_task(task_id: int, payload: typing.Dict[str, typing.Any]):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(
_complete_test_task(task_id, TestTaskResult(**payload)))
6 changes: 3 additions & 3 deletions alws/dramatiq/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
from alws import models
from alws.constants import DRAMATIQ_TASK_TIMEOUT
from alws.crud import build as build_crud
# from alws.dependencies import get_db
from alws.dependencies import get_async_db_key
from alws.dramatiq import event_loop
from alws.utils.fastapi_sqla_setup import setup_all

__all__ = ['perform_user_removal']


async def _perform_user_removal(user_id: int):
await setup_all()
async with open_async_session(key="async") as db:
async with open_async_session(key=get_async_db_key()) as db:
# Remove builds
build_ids = (await db.execute(
select(models.Build.id).where(
Expand All @@ -35,4 +34,5 @@ async def _perform_user_removal(user_id: int):

@dramatiq.actor(max_retries=0, priority=0, time_limit=DRAMATIQ_TASK_TIMEOUT)
def perform_user_removal(user_id: int):
event_loop.run_until_complete(setup_all())
event_loop.run_until_complete(_perform_user_removal(user_id))
3 changes: 2 additions & 1 deletion alws/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
# from alws.database import Base, engine
from alws.database import Base
from alws.dependencies import get_async_db_key

__all__ = [
"Build",
Expand Down Expand Up @@ -2115,7 +2116,7 @@ class PerformanceStats(Base):

async def create_tables():
# async with engine.begin() as conn:
async with open_async_session(key="async") as conn:
async with open_async_session(key=get_async_db_key()) as conn:
await conn.run_sync(Base.metadata.create_all)


Expand Down
10 changes: 5 additions & 5 deletions alws/routers/build_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@
from alws.config import settings
from alws.constants import BuildTaskRefType, BuildTaskStatus
from alws.crud import build_node
# from alws.dependencies import get_db
from alws.dependencies import get_async_db_key
from alws.schemas import build_node_schema

router = APIRouter(
prefix="/build_node",
tags=["builds"],
dependencies=[Depends(get_current_user)],
# dependencies=[Depends(get_current_user)],
)


@router.post("/ping")
async def ping(
node_status: build_node_schema.Ping,
db: AsyncSession = Depends(AsyncSessionDependency(key="async")),
db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())),
):
if not node_status.active_tasks:
return {}
Expand All @@ -37,7 +37,7 @@ async def ping(
async def build_done(
build_done_: build_node_schema.BuildDone,
response: Response,
db: AsyncSession = Depends(AsyncSessionDependency(key="async")),
db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())),
):
build_task = await build_node.get_build_task(db, build_done_.task_id)
if BuildTaskStatus.is_finished(build_task.status):
Expand All @@ -60,7 +60,7 @@ async def build_done(
)
async def get_task(
request: build_node_schema.RequestTask,
db: AsyncSession = Depends(AsyncSessionDependency(key="async")),
db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())),
):
task = await build_node.get_available_build_task(db, request)
if not task:
Expand Down
Loading

0 comments on commit 9094fa6

Please sign in to comment.