diff --git a/alws/auth/dependencies.py b/alws/auth/dependencies.py index fe9d3a334..6decec672 100644 --- a/alws/auth/dependencies.py +++ b/alws/auth/dependencies.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from alws.config import settings -# from alws.dependencies import get_async_session +from alws.dependencies import get_async_db_key from alws.models import User, UserAccessToken, UserOauthAccount from fastapi_sqla import AsyncSessionDependency @@ -22,14 +22,14 @@ async def get_user_db( - session: AsyncSession = Depends(AsyncSessionDependency(key="async")) + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): yield SQLAlchemyUserDatabase( session, User, oauth_account_table=UserOauthAccount) async def get_access_token_db( - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): yield SQLAlchemyAccessTokenDatabase(session, UserAccessToken) diff --git a/alws/config.py b/alws/config.py index 66d94cc28..0426544b8 100644 --- a/alws/config.py +++ b/alws/config.py @@ -43,11 +43,11 @@ class Settings(BaseSettings): ) fastapi_sqla__async__sqlalchemy_echo_pool: bool = True - fastapi_sqla__sync__sqlalchemy_url: str = ( + sqlalchemy_url: str = ( 'postgresql+psycopg2://postgres:password@db/almalinux-bs' ) - fastapi_sqla__sync__sqlalchemy_pool_pre_ping: bool = True - fastapi_sqla__sync__sqlalchemy_pool_recycle: int = 3600 + sqlalchemy_pool_pre_ping: bool = True + sqlalchemy_pool_recycle: int = 3600 fastapi_sqla__pulp__sqlalchemy_url: str = ( 'postgresql+psycopg2://postgres:password@pulp:5432/pulp' @@ -55,6 +55,10 @@ class Settings(BaseSettings): fastapi_sqla__pulp__sqlalchemy_pool_pre_ping: bool = True fastapi_sqla__pulp__sqlalchemy_pool_recycle: int = 3600 + fastapi_sqla__test__sqlalchemy_url: str = ( + 'postgresql+asyncpg://postgres:password@db/test-almalinux-bs' + ) + github_client: str github_client_secret: str diff --git a/alws/crud/errata.py b/alws/crud/errata.py index 7f82bf92b..1bd33f57c 100644 --- a/alws/crud/errata.py +++ b/alws/crud/errata.py @@ -32,7 +32,7 @@ ErrataReleaseStatus, GitHubIssueStatus, ) -# from alws.dependencies import get_db, get_pulp_db +from alws.dependencies import get_async_db_key from alws.pulp_models import ( RpmPackage, UpdateCollection, @@ -1374,7 +1374,7 @@ async def release_errata_record(record_id: str, platform_id: int, force: bool): settings.pulp_password, ) # async with asynccontextmanager(get_db)() as session: - async with open_async_session(key="async") as session: + async with open_async_session(key=get_async_db_key()) as session: session: AsyncSession query = generate_query_for_release([record_id]) query = query.filter(models.NewErrataRecord.platform_id == platform_id) @@ -1447,7 +1447,7 @@ async def bulk_errata_records_release(records_ids: List[str]): release_tasks = [] repos_to_publish = [] # async with asynccontextmanager(get_db)() as session: - async with open_async_session(key="async") as session: + async with open_async_session(key=get_async_db_key()) as session: await session.execute( update(models.NewErrataRecord) .where(models.NewErrataRecord.id.in_(records_ids)) @@ -1459,7 +1459,7 @@ async def bulk_errata_records_release(records_ids: List[str]): # await session.commit() # auto commit on exit of the contextmanager # async with asynccontextmanager(get_db)() as session: - async with open_async_session(key="async") as session: + async with open_async_session(key=get_async_db_key()) as session: session: AsyncSession db_records = await session.execute( generate_query_for_release(records_ids), diff --git a/alws/crud/sign_task.py b/alws/crud/sign_task.py index fc21bda1d..896e7154a 100644 --- a/alws/crud/sign_task.py +++ b/alws/crud/sign_task.py @@ -16,6 +16,7 @@ from alws.config import settings from alws.constants import GenKeyStatus, SignStatus from alws.crud.user import get_user +from alws.dependencies import get_async_db_key # from alws.database import Session from alws.errors import ( BuildAlreadySignedError, @@ -506,7 +507,7 @@ async def __failed_post_processing( srpms_mapping = defaultdict(list) logging.info("Start processing task %s", sign_task_id) - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: builds = await db.execute( select(models.Build) .where(models.Build.id == payload.build_id) diff --git a/alws/dependencies.py b/alws/dependencies.py index 1e78e7e8f..fb0952cd8 100644 --- a/alws/dependencies.py +++ b/alws/dependencies.py @@ -57,3 +57,6 @@ async def get_redis() -> aioredis.Redis: yield client finally: await client.close() + +def get_async_db_key() -> str: + return "async" \ No newline at end of file diff --git a/alws/dramatiq/build.py b/alws/dramatiq/build.py index 781c05378..f1c06e97d 100644 --- a/alws/dramatiq/build.py +++ b/alws/dramatiq/build.py @@ -22,7 +22,7 @@ from alws.crud import build_node as build_node_crud from alws.crud import test from sqlalchemy.orm import Session -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.dramatiq import event_loop from alws.errors import ( ArtifactConversionError, @@ -74,7 +74,6 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate): )) module_build_index = {} - await setup_all() if has_modules: # with SyncSession() as db, db.begin(): @@ -106,7 +105,7 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate): # db.commit() # db.close() - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: # async with db.begin(): build = await fetch_build(db, build_id) planner = BuildPlanner( @@ -166,8 +165,7 @@ async def _start_build(build_id: int, build_request: build_schema.BuildCreate): async def _build_done(request: build_node_schema.BuildDone): # async for db in get_db(): - await setup_all() - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: try: await build_node_crud.safe_build_done(db, request) except Exception as e: @@ -293,6 +291,7 @@ async def _all_build_tasks_completed( ) def start_build(build_id: int, build_request: Dict[str, Any]): parsed_build = build_schema.BuildCreate(**build_request) + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_start_build(build_id, parsed_build)) @@ -311,4 +310,5 @@ def start_build(build_id: int, build_request: Dict[str, Any]): ) def build_done(request: Dict[str, Any]): parsed_build = build_node_schema.BuildDone(**request) + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_build_done(parsed_build)) diff --git a/alws/dramatiq/errata.py b/alws/dramatiq/errata.py index 4d64a8135..feffe5ece 100644 --- a/alws/dramatiq/errata.py +++ b/alws/dramatiq/errata.py @@ -10,7 +10,6 @@ __all__ = ["release_errata"] async def _release_errata_record(record_id: str, platform_id: int, force: bool): - await setup_all() await release_errata_record( record_id, platform_id, @@ -18,7 +17,6 @@ async def _release_errata_record(record_id: str, platform_id: int, force: bool): ) async def _bulk_errata_records_release(records_ids: typing.List[str]): - await setup_all() await bulk_errata_records_release(records_ids) @dramatiq.actor( @@ -28,6 +26,7 @@ async def _bulk_errata_records_release(records_ids: typing.List[str]): time_limit=DRAMATIQ_TASK_TIMEOUT, ) def release_errata(record_id: str, platform_id: int, force: bool): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete( _release_errata_record( record_id, @@ -44,4 +43,5 @@ def release_errata(record_id: str, platform_id: int, force: bool): time_limit=DRAMATIQ_TASK_TIMEOUT, ) def bulk_errata_release(records_ids: typing.List[str]): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_bulk_errata_records_release(records_ids)) diff --git a/alws/dramatiq/products.py b/alws/dramatiq/products.py index 82f34f9aa..e537e3d68 100644 --- a/alws/dramatiq/products.py +++ b/alws/dramatiq/products.py @@ -13,6 +13,7 @@ from alws.config import settings from alws.constants import DRAMATIQ_TASK_TIMEOUT, BuildTaskStatus # from alws.database import Session +from alws.dependencies import get_async_db_key from alws.dramatiq import event_loop from alws.utils.log_utils import setup_logger from alws.utils.pulp_client import PulpClient @@ -276,9 +277,8 @@ async def _perform_product_modification( build_id, product_id, ) - await setup_all() # async with Session() as db, db.begin(): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: db_product = ( ( await db.execute( @@ -388,6 +388,7 @@ def perform_product_modification( product_id: int, modification: str, ): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete( _perform_product_modification(build_id, product_id, modification) ) diff --git a/alws/dramatiq/releases.py b/alws/dramatiq/releases.py index a3a4a7536..ba123bcf2 100644 --- a/alws/dramatiq/releases.py +++ b/alws/dramatiq/releases.py @@ -5,8 +5,8 @@ from fastapi_sqla import open_async_session from alws.constants import DRAMATIQ_TASK_TIMEOUT from alws.crud import release as r_crud +from alws.dependencies import get_async_db_key from alws.dramatiq import event_loop -# from alws.dependencies import get_db from alws.utils.fastapi_sqla_setup import setup_all @@ -14,15 +14,13 @@ async def _commit_release(release_id, user_id): - await setup_all() # async with asynccontextmanager(get_db)() as db: - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: await r_crud.commit_release(db, release_id, user_id) async def _revert_release(release_id, user_id): - await setup_all() - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: await r_crud.revert_release(db, release_id, user_id) @@ -33,6 +31,7 @@ async def _revert_release(release_id, user_id): time_limit=DRAMATIQ_TASK_TIMEOUT, ) def execute_release_plan(release_id: int, user_id: int): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_commit_release(release_id, user_id)) @@ -43,4 +42,5 @@ def execute_release_plan(release_id: int, user_id: int): time_limit=DRAMATIQ_TASK_TIMEOUT, ) def revert_release(release_id: int, user_id: int): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_revert_release(release_id, user_id)) diff --git a/alws/dramatiq/sign_task.py b/alws/dramatiq/sign_task.py index 885fc3e54..15d2feba9 100644 --- a/alws/dramatiq/sign_task.py +++ b/alws/dramatiq/sign_task.py @@ -15,7 +15,6 @@ async def _complete_sign_task( task_id: int, payload: typing.Dict[str, typing.Any] ): - await setup_all() await sign_task.complete_sign_task( task_id, sign_schema.SignTaskComplete(**payload) ) @@ -30,4 +29,5 @@ async def _complete_sign_task( time_limit=DRAMATIQ_TASK_TIMEOUT, ) def complete_sign_task(task_id: int, payload: typing.Dict[str, typing.Any]): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_complete_sign_task(task_id, payload)) diff --git a/alws/dramatiq/tests.py b/alws/dramatiq/tests.py index 2dff0e78f..abdf7a7b1 100644 --- a/alws/dramatiq/tests.py +++ b/alws/dramatiq/tests.py @@ -7,6 +7,7 @@ from alws.constants import DRAMATIQ_TASK_TIMEOUT, TestTaskStatus from alws.crud import test as t_crud # from alws.database import Session +from alws.dependencies import get_async_db_key from alws.dramatiq import event_loop from alws.schemas.test_schema import TestTaskResult from alws.utils.fastapi_sqla_setup import setup_all @@ -16,8 +17,7 @@ async def _complete_test_task(task_id: int, task_result: TestTaskResult): - await setup_all() - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: try: logging.info('Start processing test task %s', task_id) await t_crud.complete_test_task(db, task_id, task_result) @@ -41,5 +41,6 @@ async def _complete_test_task(task_id: int, task_result: TestTaskResult): time_limit=DRAMATIQ_TASK_TIMEOUT ) def complete_test_task(task_id: int, payload: typing.Dict[str, typing.Any]): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete( _complete_test_task(task_id, TestTaskResult(**payload))) diff --git a/alws/dramatiq/user.py b/alws/dramatiq/user.py index b7640125f..336b25080 100644 --- a/alws/dramatiq/user.py +++ b/alws/dramatiq/user.py @@ -7,7 +7,7 @@ from alws import models from alws.constants import DRAMATIQ_TASK_TIMEOUT from alws.crud import build as build_crud -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.dramatiq import event_loop from alws.utils.fastapi_sqla_setup import setup_all @@ -15,8 +15,7 @@ async def _perform_user_removal(user_id: int): - await setup_all() - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: # Remove builds build_ids = (await db.execute( select(models.Build.id).where( @@ -35,4 +34,5 @@ async def _perform_user_removal(user_id: int): @dramatiq.actor(max_retries=0, priority=0, time_limit=DRAMATIQ_TASK_TIMEOUT) def perform_user_removal(user_id: int): + event_loop.run_until_complete(setup_all()) event_loop.run_until_complete(_perform_user_removal(user_id)) diff --git a/alws/models.py b/alws/models.py index 20b2bef1e..0e56d5812 100644 --- a/alws/models.py +++ b/alws/models.py @@ -35,6 +35,7 @@ ) # from alws.database import Base, engine from alws.database import Base +from alws.dependencies import get_async_db_key __all__ = [ "Build", @@ -2115,7 +2116,7 @@ class PerformanceStats(Base): async def create_tables(): # async with engine.begin() as conn: - async with open_async_session(key="async") as conn: + async with open_async_session(key=get_async_db_key()) as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/alws/routers/build_node.py b/alws/routers/build_node.py index 4bb9d346e..19034a5ea 100644 --- a/alws/routers/build_node.py +++ b/alws/routers/build_node.py @@ -12,20 +12,20 @@ from alws.config import settings from alws.constants import BuildTaskRefType, BuildTaskStatus from alws.crud import build_node -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.schemas import build_node_schema router = APIRouter( prefix="/build_node", tags=["builds"], - dependencies=[Depends(get_current_user)], + # dependencies=[Depends(get_current_user)], ) @router.post("/ping") async def ping( node_status: build_node_schema.Ping, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): if not node_status.active_tasks: return {} @@ -37,7 +37,7 @@ async def ping( async def build_done( build_done_: build_node_schema.BuildDone, response: Response, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): build_task = await build_node.get_build_task(db, build_done_.task_id) if BuildTaskStatus.is_finished(build_task.status): @@ -60,7 +60,7 @@ async def build_done( ) async def get_task( request: build_node_schema.RequestTask, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): task = await build_node.get_available_build_task(db, request) if not task: diff --git a/alws/routers/builds.py b/alws/routers/builds.py index 2a22914be..7c3f43a21 100644 --- a/alws/routers/builds.py +++ b/alws/routers/builds.py @@ -11,7 +11,7 @@ from alws.crud import build_node from alws.crud import platform as platform_crud from alws.crud import platform_flavors as flavors_crud -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.errors import BuildError, DataNotFoundError from alws.schemas import build_schema @@ -31,7 +31,7 @@ async def create_build( build: build_schema.BuildCreate, user: models.User = Depends(get_current_user), - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await build_crud.create_build(db, build, user.id) @@ -58,7 +58,7 @@ async def get_builds_per_page( released: typing.Optional[bool] = None, signed: typing.Optional[bool] = None, is_running: typing.Optional[bool] = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await build_crud.get_builds( db=db, @@ -82,7 +82,7 @@ async def get_builds_per_page( @router.post('/get_module_preview/', response_model=build_schema.ModulePreview) async def get_module_preview( module_request: build_schema.ModulePreviewRequest, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): platform = await platform_crud.get_platform( db, @@ -104,7 +104,7 @@ async def get_module_preview( @public_router.get('/{build_id}/', response_model=build_schema.Build) async def get_build( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): db_build = await build_crud.get_builds(db, build_id) if db_build is None: @@ -118,7 +118,7 @@ async def get_build( @router.patch('/{build_id}/restart-failed', status_code=status.HTTP_200_OK) async def restart_failed_build_items( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await build_node.update_failed_build_items(db, build_id) @@ -126,7 +126,7 @@ async def restart_failed_build_items( @router.patch("/{build_id}/cancel", status_code=status.HTTP_200_OK) async def cancel_idle_build_items( build_id: int, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await build_node.mark_build_tasks_as_cancelled(session, build_id) @@ -137,7 +137,7 @@ async def cancel_idle_build_items( ) async def parallel_restart_failed_build_items( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await build_node.update_failed_build_items_in_parallel(db, build_id) @@ -145,7 +145,7 @@ async def parallel_restart_failed_build_items( @router.delete('/{build_id}/remove', status_code=status.HTTP_204_NO_CONTENT) async def remove_build( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): try: await build_crud.remove_build_job(db, build_id) diff --git a/alws/routers/coprs.py b/alws/routers/coprs.py index 9679d71c0..b5c65fb5f 100644 --- a/alws/routers/coprs.py +++ b/alws/routers/coprs.py @@ -10,7 +10,7 @@ # from alws import database from alws import models -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.utils.copr import ( generate_repo_config, get_clean_copr_chroot, @@ -25,7 +25,7 @@ @copr_router.get('/api_3/project/search') async def search_repos( query: str, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ) -> typing.Dict: query = select(models.Product).where( models.Product.name == query, @@ -40,7 +40,7 @@ async def search_repos( @copr_router.get('/api_3/project/list') async def list_repos( ownername: str, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ) -> typing.Dict: query = select(models.Product).where( models.Product.owner.has(username=ownername), @@ -61,7 +61,7 @@ async def get_dnf_repo_config( name: str, platform: str, arch: str, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): chroot = f'{platform}-{arch}' clean_chroot = get_clean_copr_chroot(chroot) diff --git a/alws/routers/errata.py b/alws/routers/errata.py index 0eecebfb0..0b5c135c4 100644 --- a/alws/routers/errata.py +++ b/alws/routers/errata.py @@ -10,7 +10,7 @@ # from alws.config import settings from alws.constants import ErrataReleaseStatus from alws.crud import errata as errata_crud -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.dramatiq import bulk_errata_release, release_errata from alws.schemas import errata_schema @@ -29,7 +29,7 @@ @router.post("/", response_model=errata_schema.CreateErrataResponse) async def create_errata_record( errata: errata_schema.BaseErrataRecord, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): record = await errata_crud.create_errata_record( db, @@ -42,7 +42,7 @@ async def create_errata_record( async def get_errata_record( errata_id: str, errata_platform_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): errata_record = await errata_crud.get_errata_record( db, @@ -61,7 +61,7 @@ async def get_errata_record( async def get_oval_xml( platform_name: str, only_released: bool = False, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await errata_crud.get_oval_xml(db, platform_name, only_released) @@ -75,7 +75,7 @@ async def list_errata_records( platformId: Optional[int] = None, cveId: Optional[str] = None, status: Optional[ErrataReleaseStatus] = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await errata_crud.list_errata_records( db, @@ -115,7 +115,7 @@ async def get_updateinfo_xml( @router.post("/update/", response_model=errata_schema.ErrataRecord) async def update_errata_record( errata: errata_schema.UpdateErrataRequest, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await errata_crud.update_errata_record(db, errata) @@ -125,7 +125,7 @@ async def update_errata_record( # See https://github.com/AlmaLinux/build-system/issues/207 @router.get("/all/", response_model=List[errata_schema.CompactErrataRecord]) async def list_all_errata_records( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): records = await errata_crud.list_errata_records(db, compact=True) return [ @@ -143,7 +143,7 @@ async def list_all_errata_records( ) async def update_package_status( packages: List[errata_schema.ChangeErrataPackageStatusRequest], - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: return { @@ -161,7 +161,7 @@ async def release_errata_record( record_id: str, platform_id: int, force: bool = False, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): db_record = await errata_crud.get_errata_record( session, @@ -198,7 +198,7 @@ async def bulk_release_errata_records(records_ids: List[str]): @router.post('/reset-matched-packages') async def reset_matched_packages( record_id: str, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await errata_crud.reset_matched_errata_packages(record_id, session) return {'message': f'Packages for record {record_id} have been matched'} diff --git a/alws/routers/platform_flavors.py b/alws/routers/platform_flavors.py index 3296aee79..aa559fa88 100644 --- a/alws/routers/platform_flavors.py +++ b/alws/routers/platform_flavors.py @@ -6,7 +6,7 @@ from fastapi_sqla import AsyncSessionDependency from alws.auth import get_current_user -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.schemas import platform_flavors_schema as pf_schema from alws.crud import platform_flavors as pf_crud @@ -20,7 +20,7 @@ @router.post('/', response_model=pf_schema.FlavourResponse) async def create_flavour( flavour: pf_schema.CreateFlavour, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pf_crud.create_flavour(db, flavour) @@ -28,13 +28,13 @@ async def create_flavour( @router.patch('/', response_model=pf_schema.FlavourResponse) async def update_flavour( flavour: pf_schema.UpdateFlavour, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pf_crud.update_flavour(db, flavour) @router.get('/', response_model=List[pf_schema.FlavourResponse]) async def get_flavours( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pf_crud.list_flavours(db) diff --git a/alws/routers/platforms.py b/alws/routers/platforms.py index 127423c44..1c1323747 100644 --- a/alws/routers/platforms.py +++ b/alws/routers/platforms.py @@ -7,7 +7,7 @@ # from alws import database from alws.auth import get_current_user from alws.crud import platform as pl_crud, repository -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.schemas import platform_schema @@ -26,7 +26,7 @@ @router.post('/', response_model=platform_schema.Platform) async def create_platform( platform: platform_schema.PlatformCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pl_crud.create_platform(db, platform) @@ -34,14 +34,14 @@ async def create_platform( @router.put('/', response_model=platform_schema.Platform) async def modify_platform( platform: platform_schema.PlatformModify, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pl_crud.modify_platform(db, platform) @public_router.get('/', response_model=typing.List[platform_schema.Platform]) async def get_platforms( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await pl_crud.get_platforms(db) @@ -50,7 +50,7 @@ async def get_platforms( response_model=platform_schema.Platform) async def add_repositories_to_platform( platform_id: int, repositories_ids: typing.List[int], - db: AsyncSession = Depends(AsyncSessionDependency(key="async"))): + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key()))): return await repository.add_to_platform( db, platform_id, repositories_ids) @@ -59,6 +59,6 @@ async def add_repositories_to_platform( response_model=platform_schema.Platform) async def remove_repositories_to_platform( platform_id: int, repositories_ids: typing.List[int], - db: AsyncSession = Depends(AsyncSessionDependency(key="async"))): + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key()))): return await repository.remove_from_platform( db, platform_id, repositories_ids) diff --git a/alws/routers/products.py b/alws/routers/products.py index b6ec31233..00447a1ff 100644 --- a/alws/routers/products.py +++ b/alws/routers/products.py @@ -11,7 +11,7 @@ from alws.auth import get_current_user from alws.crud import products, sign_task -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.models import User from alws.schemas import ( product_schema, @@ -40,7 +40,7 @@ async def get_products( pageNumber: Optional[int] = None, search_string: Optional[str] = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await products.get_products( db, page_number=pageNumber, search_string=search_string @@ -50,7 +50,7 @@ async def get_products( @public_router.post("/", response_model=product_schema.Product) async def create_product( product: product_schema.ProductCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user: User = Depends(get_current_user), ): # async with db.begin(): @@ -69,7 +69,7 @@ async def create_product( @public_router.get("/{product_id}/", response_model=product_schema.Product) async def get_product( product_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): db_product = await products.get_products(db, product_id=product_id) if db_product is None: @@ -87,7 +87,7 @@ async def get_product( async def add_to_product( product: str, build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user: User = Depends(get_current_user), ): try: @@ -110,7 +110,7 @@ async def add_to_product( async def remove_from_product( product: str, build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user: User = Depends(get_current_user), ): try: @@ -134,7 +134,7 @@ async def remove_from_product( ) async def remove_product( product_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user: User = Depends(get_current_user), ): try: @@ -156,7 +156,7 @@ async def remove_product( ) async def create_gen_key_task( product_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user: User = Depends(get_current_user), ): product = await products.get_products(db=db, product_id=product_id) diff --git a/alws/routers/releases.py b/alws/routers/releases.py index 47636919c..25dec6ed1 100644 --- a/alws/routers/releases.py +++ b/alws/routers/releases.py @@ -9,7 +9,7 @@ from alws.auth import get_current_user from alws.constants import ReleaseStatus from alws.crud import release as r_crud -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.dramatiq import execute_release_plan, revert_release from alws.schemas import release_schema @@ -38,7 +38,7 @@ async def get_releases( platform_id: typing.Optional[int] = None, status: typing.Optional[int] = None, package_name: typing.Optional[str] = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), ): return await r_crud.get_releases( @@ -54,7 +54,7 @@ async def get_releases( @public_router.get("/{release_id}/", response_model=release_schema.Release) async def get_release( release_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), ): return await r_crud.get_releases(db, release_id=release_id) @@ -63,7 +63,7 @@ async def get_release( @router.post("/new/", response_model=release_schema.Release) async def create_new_release( payload: release_schema.ReleaseCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), user: models.User = Depends(get_current_user), ): @@ -75,7 +75,7 @@ async def create_new_release( async def update_release( release_id: int, payload: release_schema.ReleaseUpdate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), user: models.User = Depends(get_current_user), ): @@ -89,7 +89,7 @@ async def update_release( ) async def commit_release( release_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), user: models.User = Depends(get_current_user), ): @@ -113,7 +113,7 @@ async def commit_release( ) async def revert_db_release( release_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), user: models.User = Depends(get_current_user), ): @@ -137,7 +137,7 @@ async def revert_db_release( ) async def delete_release( release_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), # db: AsyncSession = Depends(get_db), user: models.User = Depends(get_current_user), ): diff --git a/alws/routers/repositories.py b/alws/routers/repositories.py index 39fe14553..aec9a45bc 100644 --- a/alws/routers/repositories.py +++ b/alws/routers/repositories.py @@ -7,7 +7,7 @@ # from alws import database from alws.auth import get_current_user from alws.crud import repository -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.utils.exporter import fs_export_repository from alws.schemas import repository_schema @@ -21,7 +21,7 @@ @router.get('/', response_model=typing.List[repository_schema.Repository]) async def get_repositories( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await repository.get_repositories(db) @@ -30,7 +30,7 @@ async def get_repositories( response_model=typing.Union[None, repository_schema.Repository]) async def get_repository( repository_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): result = await repository.get_repositories(db, repository_id=repository_id) if result: @@ -41,6 +41,6 @@ async def get_repository( @router.post('/exports/', response_model=typing.List[str]) async def filesystem_export_repository( repository_ids: typing.List[int], - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await fs_export_repository(repository_ids, db) diff --git a/alws/routers/roles.py b/alws/routers/roles.py index bf50aaf34..53a1ace01 100644 --- a/alws/routers/roles.py +++ b/alws/routers/roles.py @@ -7,7 +7,7 @@ # from alws import database from alws.auth import get_current_user from alws.crud import roles -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.schemas import role_schema @@ -20,6 +20,6 @@ @router.get('/', response_model=typing.List[role_schema.Role]) async def get_roles( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await roles.get_roles(db) diff --git a/alws/routers/sign_key.py b/alws/routers/sign_key.py index 0b0080349..8608f7905 100644 --- a/alws/routers/sign_key.py +++ b/alws/routers/sign_key.py @@ -12,7 +12,7 @@ # from alws import database from alws.auth import get_current_user from alws.crud import sign_key -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.errors import PlatformMissingError, SignKeyAlreadyExistsError from alws.schemas import sign_schema @@ -25,7 +25,7 @@ @router.get('/', response_model=typing.List[sign_schema.SignKey]) async def get_sign_keys( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user=Depends(get_current_user), ): return await sign_key.get_sign_keys(db, user) @@ -38,7 +38,7 @@ async def get_sign_keys( ) async def create_sign_key( payload: sign_schema.SignKeyCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): try: return await sign_key.create_sign_key(db, payload) @@ -50,6 +50,6 @@ async def create_sign_key( async def modify_sign_key( sign_key_id: int, payload: sign_schema.SignKeyUpdate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await sign_key.update_sign_key(db, sign_key_id, payload) diff --git a/alws/routers/sign_task.py b/alws/routers/sign_task.py index f5195e2a6..dd2965a3a 100644 --- a/alws/routers/sign_task.py +++ b/alws/routers/sign_task.py @@ -13,7 +13,7 @@ from alws.auth import get_current_user from alws.crud import sign_task from alws.dependencies import get_redis -# from alws.dependencies import get_db, get_redis +from alws.dependencies import get_async_db_key from alws.schemas import sign_schema router = APIRouter( @@ -31,7 +31,7 @@ @public_router.get('/', response_model=typing.List[sign_schema.SignTask]) async def get_sign_tasks( build_id: int = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await sign_task.get_sign_tasks(db, build_id=build_id) @@ -39,7 +39,7 @@ async def get_sign_tasks( @router.post('/', response_model=sign_schema.SignTask) async def create_sign_task( payload: sign_schema.SignTaskCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), user=Depends(get_current_user), ): return await sign_task.create_sign_task(db, payload, user.id) @@ -51,7 +51,7 @@ async def create_sign_task( ) async def get_available_sign_task( payload: sign_schema.SignTaskGet, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): result = await sign_task.get_available_sign_task(db, payload.key_ids) if any( @@ -71,7 +71,7 @@ async def get_available_sign_task( async def complete_sign_task( sign_task_id: int, payload: sign_schema.SignTaskComplete, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): task = await sign_task.get_sign_task(db, sign_task_id) task.ts = datetime.datetime.utcnow() + datetime.timedelta(hours=2) @@ -138,7 +138,7 @@ async def iter_sync_sign_tasks( response_model=typing.Union[dict, sign_schema.AvailableGenKeyTask], ) async def get_avaiable_gen_key_task( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): gen_key_task = await sign_task.get_available_gen_key_task(db) if gen_key_task: @@ -159,7 +159,7 @@ async def get_avaiable_gen_key_task( async def complete_gen_key_task( gen_key_task_id: int, payload: sign_schema.GenKeyTaskComplete, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): sign_key = await sign_task.complete_gen_key_task( gen_key_task_id=gen_key_task_id, diff --git a/alws/routers/teams.py b/alws/routers/teams.py index 3531f2cc8..1efb35dca 100644 --- a/alws/routers/teams.py +++ b/alws/routers/teams.py @@ -12,7 +12,7 @@ # from alws import database from alws.auth import get_current_superuser from alws.crud import teams -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.errors import TeamError from alws.schemas import team_schema @@ -33,7 +33,7 @@ typing.List[team_schema.Team], team_schema.TeamResponse]) async def get_teams( pageNumber: int = None, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await teams.get_teams(db, page_number=pageNumber) @@ -41,7 +41,7 @@ async def get_teams( @public_router.get('/{team_id}/', response_model=team_schema.Team) async def get_team( team_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await teams.get_teams(db, team_id=team_id) @@ -50,7 +50,7 @@ async def get_team( async def add_members( team_id: int, payload: team_schema.TeamMembersUpdate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: db_team = await teams.update_members(db, payload, team_id, 'add') @@ -66,7 +66,7 @@ async def add_members( async def remove_members( team_id: int, payload: team_schema.TeamMembersUpdate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: db_team = await teams.update_members(db, payload, team_id, 'remove') @@ -81,7 +81,7 @@ async def remove_members( @router.post('/create/', response_model=team_schema.Team) async def create_team( payload: team_schema.TeamCreate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: db_team = await teams.create_team(db, payload) @@ -96,7 +96,7 @@ async def create_team( @router.delete('/{team_id}/remove/', status_code=status.HTTP_202_ACCEPTED) async def remove_team( team_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): try: await teams.remove_team(db, team_id) diff --git a/alws/routers/test_repositories.py b/alws/routers/test_repositories.py index 516f573c7..4edc2e796 100644 --- a/alws/routers/test_repositories.py +++ b/alws/routers/test_repositories.py @@ -6,7 +6,7 @@ from alws.auth import get_current_user from alws.crud import test_repository -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.errors import DataNotFoundError, TestRepositoryError from alws.schemas import test_repository_schema @@ -27,7 +27,7 @@ async def get_repositories( pageNumber: typing.Optional[int] = None, name: typing.Optional[str] = None, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await test_repository.get_repositories( session, @@ -42,7 +42,7 @@ async def get_repositories( ) async def get_repository( repository_id: int, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await test_repository.get_repositories( session, @@ -53,7 +53,7 @@ async def get_repository( @router.post('/create/', response_model=test_repository_schema.TestRepository) async def create_repository( payload: test_repository_schema.TestRepositoryCreate, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: db_repo = await test_repository.create_repository(session, payload) @@ -76,7 +76,7 @@ async def create_repository( async def update_test_repository( repository_id: int, payload: test_repository_schema.TestRepositoryUpdate, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: await test_repository.update_repository( @@ -97,7 +97,7 @@ async def update_test_repository( ) async def remove_test_repository( repository_id: int, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: await test_repository.delete_repository(session, repository_id) @@ -115,7 +115,7 @@ async def remove_test_repository( async def create_package_mapping( repository_id: int, payload: test_repository_schema.PackageTestRepositoryCreate, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: await test_repository.create_package_mapping( @@ -137,7 +137,7 @@ async def create_package_mapping( async def bulk_create_package_mapping( repository_id: int, payload: typing.List[test_repository_schema.PackageTestRepositoryCreate], - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: await test_repository.bulk_create_package_mapping( @@ -158,7 +158,7 @@ async def bulk_create_package_mapping( ) async def remove_package_mapping( package_id: int, - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): try: await test_repository.delete_package_mapping(session, package_id) @@ -176,7 +176,7 @@ async def remove_package_mapping( async def bulk_delete_package_mapping( repository_id: int, package_ids: typing.List[int], - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await test_repository.bulk_delete_package_mapping( session=session, diff --git a/alws/routers/tests.py b/alws/routers/tests.py index 9b41e5c3f..4fc907fd4 100644 --- a/alws/routers/tests.py +++ b/alws/routers/tests.py @@ -7,7 +7,7 @@ from alws import dramatiq from alws.auth import get_current_user from alws.crud import test -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.schemas import test_schema router = APIRouter( @@ -36,7 +36,7 @@ async def update_test_task_result( response_model=List[test_schema.TestTaskPayload], ) async def get_test_tasks( - session: AsyncSession = Depends(AsyncSessionDependency(key="async")) + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await test.get_available_test_tasks(session) @@ -44,7 +44,7 @@ async def get_test_tasks( @router.put('/build/{build_id}/restart') async def restart_build_tests( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await test.restart_build_tests(db, build_id) return {'ok': True} @@ -53,7 +53,7 @@ async def restart_build_tests( @router.put('/build_task/{build_task_id}/restart') async def restart_build_task_tests( build_task_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await test.restart_build_task_tests(db, build_task_id) return {'ok': True} @@ -62,7 +62,7 @@ async def restart_build_task_tests( @router.put('/build/{build_id}/cancel') async def cancel_build_tests( build_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): await test.cancel_build_tests(db, build_id) return {'ok': True} @@ -74,7 +74,7 @@ async def cancel_build_tests( ) async def get_latest_test_results( build_task_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await test.get_test_tasks_by_build_task(db, build_task_id) @@ -85,7 +85,7 @@ async def get_latest_test_results( ) async def get_test_logs( build_task_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await test.get_test_logs(build_task_id, db) @@ -97,7 +97,7 @@ async def get_test_logs( async def get_latest_test_results_by_revision( build_task_id: int, revision: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): return await test.get_test_tasks_by_build_task( db, diff --git a/alws/routers/uploads.py b/alws/routers/uploads.py index db086a20b..390f9da29 100644 --- a/alws/routers/uploads.py +++ b/alws/routers/uploads.py @@ -5,7 +5,7 @@ from fastapi_sqla import AsyncSessionDependency from alws.auth import get_current_user -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.utils.uploader import MetadataUploader router = APIRouter( @@ -20,7 +20,7 @@ async def upload_repometada( modules: typing.Optional[UploadFile] = None, comps: typing.Optional[UploadFile] = None, repository: str = Form(...), - session: AsyncSession = Depends(AsyncSessionDependency(key="async")), + session: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), ): # Temporary disable modules.yaml upload msg = "" diff --git a/alws/routers/users.py b/alws/routers/users.py index 7e36253a6..1f3867046 100644 --- a/alws/routers/users.py +++ b/alws/routers/users.py @@ -12,7 +12,7 @@ # from alws import database from alws.auth import get_current_superuser, get_current_user from alws.crud import user as user_crud -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.errors import UserError, PermissionDenied from alws.schemas import user_schema, role_schema from alws.models import User @@ -29,7 +29,7 @@ response_model=typing.List[user_schema.User], ) async def get_all_users( - db: AsyncSession = Depends(AsyncSessionDependency(key="async")) + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())) ): return await user_crud.get_all_users(db) @@ -40,7 +40,7 @@ async def get_all_users( ) async def modify_user( user_id: int, payload: user_schema.UserUpdate, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), _=Depends(get_current_superuser) ) -> user_schema.UserOpResult: try: @@ -56,7 +56,7 @@ async def modify_user( @router.delete('/{user_id}/remove') async def remove_user( user_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), _=Depends(get_current_superuser) ) -> user_schema.UserOpResult: try: @@ -77,7 +77,7 @@ async def remove_user( ) async def get_user_roles( user_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), _=Depends(get_current_user) ): return await user_crud.get_user_roles(db, user_id) @@ -89,7 +89,7 @@ async def get_user_roles( ) async def add_roles( user_id: int, roles_ids: typing.List[int], - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), current_user: User = Depends(get_current_user) ) -> user_schema.UserOpResult: try: @@ -110,7 +110,7 @@ async def add_roles( ) async def remove_roles( user_id: int, roles_ids: typing.List[int], - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), current_user: User = Depends(get_current_user) ) -> user_schema.UserOpResult: try: @@ -131,7 +131,7 @@ async def remove_roles( ) async def get_user_teams( user_id: int, - db: AsyncSession = Depends(AsyncSessionDependency(key="async")), + db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), _=Depends(get_current_user) ): return await user_crud.get_user_teams(db, user_id) diff --git a/alws/utils/fastapi_sqla_setup.py b/alws/utils/fastapi_sqla_setup.py index 5da783528..bb688e602 100644 --- a/alws/utils/fastapi_sqla_setup.py +++ b/alws/utils/fastapi_sqla_setup.py @@ -3,22 +3,21 @@ from fastapi_sqla.sqla import startup, _DEFAULT_SESSION_KEY from fastapi_sqla.async_sqla import startup as async_startup -DEFAULT_SESSION_KEY = _DEFAULT_SESSION_KEY app = FastAPI() setup(app) -sync_keys = ['pulp', DEFAULT_SESSION_KEY] -async_keys = ['async'] +sync_keys = ['pulp', _DEFAULT_SESSION_KEY] +async_keys = ['async', 'test'] async def setup_all(): sync_setup() await async_setup() -async def async_setup(*args): +async def async_setup(): for key in async_keys: await async_startup(key) -def sync_setup(*args): +def sync_setup(): for key in sync_keys: startup(key) \ No newline at end of file diff --git a/scripts/add_release_status_errata.py b/scripts/add_release_status_errata.py index 7b319ecbb..fdea543f3 100644 --- a/scripts/add_release_status_errata.py +++ b/scripts/add_release_status_errata.py @@ -14,7 +14,7 @@ from alws.models import ErrataRecord from alws.pulp_models import UpdateRecord -from alws.utils.fastapi_sqla_setup import sync_setup, DEFAULT_SESSION_KEY +from alws.utils.fastapi_sqla_setup import sync_setup logging.basicConfig( level=logging.INFO, @@ -26,7 +26,7 @@ def main(): logging.info("Start checking release status for ALBS errata records") - sync_setup("pulp", DEFAULT_SESSION_KEY) + sync_setup() # with PulpSession() as pulp_db, SyncSession() as albs_db, albs_db.begin(): with open_session(key="pulp") as pulp_db, open_session() as albs_db: pulp_records: typing.List[UpdateRecord.id] = ( diff --git a/scripts/add_teams_to_community_products.py b/scripts/add_teams_to_community_products.py index aaf9629d2..3fd792f4a 100644 --- a/scripts/add_teams_to_community_products.py +++ b/scripts/add_teams_to_community_products.py @@ -9,15 +9,16 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) from alws import models +from alws.dependencies import get_async_db_key # from alws import database, models from alws.crud.teams import create_team, create_team_roles from alws.schemas.team_schema import TeamCreate -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all async def main(): - await async_setup('async') - async with open_async_session(key="async") as session: + await setup_all() + async with open_async_session(get_async_db_key()) as session: products = (await session.execute(select(models.Product).where( models.Product.is_community.is_(True)).options( selectinload(models.Product.team).selectinload( diff --git a/scripts/albs-1003-fixes.py b/scripts/albs-1003-fixes.py index a51122418..fa1067e22 100644 --- a/scripts/albs-1003-fixes.py +++ b/scripts/albs-1003-fixes.py @@ -11,10 +11,11 @@ # from alws.dependencies import get_db, get_pulp_db from fastapi_sqla import open_session, open_async_session +from alws.dependencies import get_async_db_key from alws.models import ErrataRecord from alws.pulp_models import UpdateRecord from alws.utils.errata import debrand_description_and_title -from alws.utils.fastapi_sqla_setup import sync_setup, async_setup +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( format="%(asctime)s %(levelname)-8s %(message)s", @@ -29,8 +30,8 @@ async def main(): affected_updateinfos = {} - await async_setup('async') - async with open_async_session(key="async") as session: + await setup_all() + async with open_async_session(get_async_db_key()) as session: records = await session.execute( select(ErrataRecord).where( ErrataRecord.original_description.like("%[rhel%") @@ -44,7 +45,6 @@ async def main(): affected_updateinfos[record.id] = debranded_description # await session.commit() # auto commit on the exit of the contextmanager - sync_setup('pulp') with open_session(key="pulp") as pulp_session: records = pulp_session.execute( select(UpdateRecord).where( diff --git a/scripts/albs-1147.py b/scripts/albs-1147.py index 0b8ce3257..ccd5a19a9 100644 --- a/scripts/albs-1147.py +++ b/scripts/albs-1147.py @@ -12,7 +12,7 @@ from fastapi_sqla import open_session, open_async_session from alws.config import settings -# from alws.dependencies import get_pulp_db, get_db +from alws.dependencies import get_async_db_key from alws.utils.errata import debrand_description_and_title from alws.utils import pulp_client @@ -24,7 +24,7 @@ UpdateRecord, ) -from alws.utils.fastapi_sqla_setup import sync_setup, async_setup +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( format="%(message)s", @@ -95,8 +95,7 @@ async def main(write=False): pulp = pulp_client.PulpClient( settings.pulp_host, settings.pulp_user, settings.pulp_password ) - sync_setup('pulp') - await async_setup('async') + await setup_all() with open_session(key="pulp") as session: result = session.execute( @@ -160,7 +159,7 @@ async def main(write=False): log.info(f'{os.linesep * 2}Looking for records in almalinux\'s ' f'\'errata_records\' table...') - async with open_async_session(key="async") as session: + async with open_async_session(get_async_db_key()) as session: result = await session.execute( select(ErrataRecord).where( ErrataRecord.id.in_(list(affected_records.keys())) diff --git a/scripts/albs-682-fixes.py b/scripts/albs-682-fixes.py index b6e22ab05..841ffdcda 100644 --- a/scripts/albs-682-fixes.py +++ b/scripts/albs-682-fixes.py @@ -40,7 +40,7 @@ from alws.pulp_models import UpdateRecord, UpdatePackage from alws.models import ErrataRecord, Platform, Repository -from alws.utils.fastapi_sqla_setup import sync_setup, DEFAULT_SESSION_KEY +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( @@ -330,8 +330,7 @@ async def update_pulp_repos(): async def main(): - sync_setup('pulp', DEFAULT_SESSION_KEY) - + await setup_all() pulp_client.PULP_SEMAPHORE = asyncio.Semaphore(10) await prepare_albs_errata_cache() diff --git a/scripts/albs-952_fill_errata_module.py b/scripts/albs-952_fill_errata_module.py index 9f443482c..824dc8dfa 100644 --- a/scripts/albs-952_fill_errata_module.py +++ b/scripts/albs-952_fill_errata_module.py @@ -9,15 +9,15 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) from alws import models -# from alws.dependencies import get_db -from alws.utils.fastapi_sqla_setup import async_setup +from alws.dependencies import get_async_db_key +from alws.utils.fastapi_sqla_setup import setup_all async def main(): module_regex = re.compile('Module ([\d\w\-\_]+:[\d\.\w]+) is enabled') updated_records = [] - await async_setup('async') - async with open_async_session(key="async") as db: + await setup_all() + async with open_async_session(key=get_async_db_key()) as db: errata_records = (await db.execute( select(models.ErrataRecord))).scalars().all() for record in errata_records: diff --git a/scripts/bootstrap_permissions.py b/scripts/bootstrap_permissions.py index 780359683..8c7b1b718 100644 --- a/scripts/bootstrap_permissions.py +++ b/scripts/bootstrap_permissions.py @@ -18,10 +18,11 @@ ) from alws.crud.products import create_product from alws.crud.teams import create_team +from alws.dependencies import get_async_db_key from alws.schemas.product_schema import ProductCreate from alws.schemas.team_schema import TeamCreate -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all async def ensure_system_user_exists(session: AsyncSession) -> models.User: @@ -51,8 +52,8 @@ async def ensure_system_user_exists(session: AsyncSession) -> models.User: async def main(): - await async_setup('async') - async with open_async_session(key="async") as db: + await setup_all() + async with open_async_session(get_async_db_key()) as db: system_user = await ensure_system_user_exists(db) alma_team = await create_team( session=db, diff --git a/scripts/bootstrap_repositories.py b/scripts/bootstrap_repositories.py index cbda7db46..86ae2665d 100644 --- a/scripts/bootstrap_repositories.py +++ b/scripts/bootstrap_repositories.py @@ -12,12 +12,13 @@ from fastapi_sqla import open_async_session # from alws import database +from alws.dependencies import get_async_db_key from alws.crud import platform as pl_crud from alws.crud import repository as repo_crud from alws.schemas import platform_schema, remote_schema, repository_schema from alws.utils.pulp_client import PulpClient -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all REPO_CACHE = {} @@ -77,7 +78,7 @@ async def get_repository( production: bool, logger: logging.Logger, ): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: if production: repo_payload = repo_info.copy() repo_payload.pop("remote_url") @@ -145,7 +146,7 @@ async def get_repository( async def get_remote(repo_info: dict, remote_sync_policy: str): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: remote_payload = repo_info.copy() remote_payload["name"] = f'{repo_info["name"]}-{repo_info["arch"]}' remote_payload.pop("type", None) @@ -160,7 +161,7 @@ async def get_remote(repo_info: dict, remote_sync_policy: str): async def update_remote(remote_id, remote_data: dict): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: return await repo_crud.update_repository_remote( db=db, remote_id=remote_id, @@ -169,21 +170,21 @@ async def update_remote(remote_id, remote_data: dict): async def update_platform(platform_data: dict): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: await pl_crud.modify_platform( db, platform_schema.PlatformModify(**platform_data) ) async def update_repository(repo_id: int, repo_data: dict): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: await repo_crud.update_repository( db, repo_id, repository_schema.RepositoryUpdate(**repo_data) ) async def get_repositories_for_update(platform_name: str): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: return await repo_crud.get_repositories_by_platform_name( db, platform_name ) @@ -194,7 +195,7 @@ async def add_repositories_to_platform( ): platform_name = platform_data.get("name") platform_instance = None - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: for platform in await pl_crud.get_platforms( db, is_reference=platform_data.get("is_reference", False) ): @@ -227,7 +228,7 @@ def main(): pulp_client = PulpClient(pulp_host, pulp_user, pulp_password) - sync(async_setup('async')) + sync(setup_all()) for platform_data in platforms_data: if args.only_update: diff --git a/scripts/compare_beta_to_stable.py b/scripts/compare_beta_to_stable.py index c3f7be02b..e41a7230a 100644 --- a/scripts/compare_beta_to_stable.py +++ b/scripts/compare_beta_to_stable.py @@ -22,10 +22,11 @@ ) from alws.config import settings # from alws.database import Session +from alws.dependencies import get_async_db_key from alws.utils import pulp_client from alws.utils.debuginfo import is_debuginfo -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all SUPPORTED_ARCHES = ("src", "aarch64", "i686", "ppc64le", @@ -165,7 +166,7 @@ async def get_packages_list( async def run(self, args): # async with Session() as session: - async with open_async_session(key="async") as session: + async with open_async_session(key=get_async_db_key()) as session: stable_repositories = await get_repositories( session, args.stable_platform, arch=args.arch ) @@ -206,7 +207,7 @@ async def main(): logging.FileHandler(f"stable_beta_comparator.{current_ts}.log"), ], ) - await async_setup('async') + await setup_all() comparator = PackagesComparator() await comparator.run(args) diff --git a/scripts/errata_fix_script.py b/scripts/errata_fix_script.py index 6d75c720e..a6cada6a7 100644 --- a/scripts/errata_fix_script.py +++ b/scripts/errata_fix_script.py @@ -14,11 +14,12 @@ # from alws.database import PulpSession, SyncSession from alws.config import settings from alws.constants import ErrataReferenceType, ErrataPackageStatus +from alws.dependencies import get_async_db_key from alws.pulp_models import UpdateRecord, UpdatePackage from alws.models import ErrataRecord, ErrataReference # from alws.dependencies import get_db -from alws.utils.fastapi_sqla_setup import sync_setup, async_setup, DEFAULT_SESSION_KEY +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( @@ -184,7 +185,7 @@ async def update_albs_db(): ErrataReference.title == '', ) ) - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: for reference in (await db.execute(query)).scalars().all(): reference.title = reference.ref_id logging.info("Update albs db is done") @@ -196,8 +197,7 @@ async def main(): update_pulp_db(), update_albs_db(), ] - sync_setup('pulp', DEFAULT_SESSION_KEY) - await async_setup('async') + await setup_all() await asyncio.gather(*tasks) diff --git a/scripts/errata_pkgs_matcher.py b/scripts/errata_pkgs_matcher.py index 422997dfb..b22f936a1 100644 --- a/scripts/errata_pkgs_matcher.py +++ b/scripts/errata_pkgs_matcher.py @@ -13,6 +13,7 @@ from alws.constants import ErrataPackageStatus # from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.models import ( BuildTask, BuildTaskArtifact, @@ -27,7 +28,7 @@ get_uuid_from_pulp_href, ) -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( format="%(message)s", @@ -50,8 +51,8 @@ async def main( advisory_id, ) added = not_found = 0 - await async_setup('async') - async with open_async_session(key="async") as db: + await setup_all() + async with open_async_session(key=get_async_db_key()) as db: advisory = ( ( await db.execute( diff --git a/scripts/fix_releases_products.py b/scripts/fix_releases_products.py index c03b33992..008252015 100644 --- a/scripts/fix_releases_products.py +++ b/scripts/fix_releases_products.py @@ -12,14 +12,14 @@ from alws import models from alws.constants import DEFAULT_PRODUCT -from alws.dependencies import get_db -from alws.utils.fastapi_sqla_setup import async_setup +from alws.dependencies import get_async_db_key +from alws.utils.fastapi_sqla_setup import setup_all async def main(): - await async_setup('async') + await setup_all() # async with asynccontextmanager(get_db)() as db, db.begin(): - async with open_async_session('async') as db: + async with open_async_session(get_async_db_key()) as db: product_id = (await db.execute(select(models.Product.id).where( models.Product.name == DEFAULT_PRODUCT))).scalar() # Assign all previous releases to AlmaLinux product diff --git a/scripts/generate_errata_title.py b/scripts/generate_errata_title.py index c1d679a73..4238c79e1 100644 --- a/scripts/generate_errata_title.py +++ b/scripts/generate_errata_title.py @@ -9,17 +9,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) from alws import models -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all async def main(): severity_regex = re.compile('^(Important|Critical|Moderate|Low): ') updated_records = [] - await async_setup('async') - async with open_async_session(key="async") as db: + await setup_all() + async with open_async_session(key=get_async_db_key()) as db: errata_records = (await db.execute( select(models.ErrataRecord))).scalars().all() for record in errata_records: diff --git a/scripts/generate_token.py b/scripts/generate_token.py index fa7e82a22..bd7f1b6ee 100644 --- a/scripts/generate_token.py +++ b/scripts/generate_token.py @@ -11,11 +11,11 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) # from alws.dependencies import get_db -from alws.auth.dependencies import get_user_db +from alws.auth.dependencies import get_user_db, get_async_db_key from alws.auth.user_manager import get_user_manager from alws.models import User -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all def parse_args(): @@ -32,10 +32,10 @@ def parse_args(): async def gen_token(secret: str, email: str = None, username: str = None): - await async_setup('async') + await setup_all strategy = JWTStrategy(secret, lifetime_seconds=1 * 31557600) - get_async_session_context = open_async_session(key='async') + get_async_session_context = open_async_session(key=get_async_db_key()) get_user_db_context = contextlib.asynccontextmanager(get_user_db) get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) async with get_async_session_context() as session: diff --git a/scripts/manage_flavours.py b/scripts/manage_flavours.py index 1e1b153ab..253ec0108 100644 --- a/scripts/manage_flavours.py +++ b/scripts/manage_flavours.py @@ -13,9 +13,11 @@ # from alws import database from alws.crud import platform_flavors as pf_crud from alws.crud import repository as repo_crud +from alws.dependencies import get_async_db_key + from alws.schemas import platform_flavors_schema, repository_schema -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all def parse_args(): @@ -65,7 +67,7 @@ def parse_args(): async def update_flavour(flavour_data: dict, logger: logging.Logger): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: data = platform_flavors_schema.UpdateFlavour(**flavour_data) flavor = await pf_crud.update_flavour(db, data) if not flavor: @@ -77,7 +79,7 @@ async def update_flavour(flavour_data: dict, logger: logging.Logger): async def prune_flavours( flavours_data: [], logger: logging.Logger, confirmation_yes ): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: flavour_names_in_config = [e.get('name') for e in flavours_data] flavours_in_db = await pf_crud.list_flavours(db) @@ -119,7 +121,7 @@ async def prune_flavours( async def add_flavor(flavor_data: dict, logger: logging.Logger): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: flavour = await pf_crud.find_flavour_by_name(db, flavor_data["name"]) if flavour: logger.error("Flavor %s is already added", flavor_data["name"]) @@ -142,7 +144,7 @@ def main(): loader = yaml.Loader(f) flavours_data = loader.get_data() - sync(async_setup('async')) + sync(setup_all()) if args.prune: logger.info("Start to prune") diff --git a/scripts/manage_users.py b/scripts/manage_users.py index 7e43d25f6..4805341c3 100644 --- a/scripts/manage_users.py +++ b/scripts/manage_users.py @@ -17,10 +17,10 @@ roles as role_crud, teams as team_crud, user as user_crud, ) -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.perms.roles import RolesList -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all def parse_args(): @@ -53,8 +53,8 @@ def parse_args(): async def main() -> int: arguments = parse_args() - await async_setup("async") - async with open_async_session(key="async") as db: + await setup_all() + async with open_async_session(key=get_async_db_key()) as db: user = (await db.execute(select(models.User).where( models.User.email == arguments.email).options( selectinload(models.User.roles), diff --git a/scripts/migrate_old_distros.py b/scripts/migrate_old_distros.py index 6d159302d..cf4db67d6 100644 --- a/scripts/migrate_old_distros.py +++ b/scripts/migrate_old_distros.py @@ -10,7 +10,7 @@ from alws import models -from alws.utils.fastapi_sqla_setup import sync_setup, DEFAULT_SESSION_KEY +from alws.utils.fastapi_sqla_setup import sync_setup # from alws.database import SyncSession @@ -60,5 +60,5 @@ def migrate_old_records(): if __name__ == '__main__': - sync_setup(DEFAULT_SESSION_KEY) + sync_setup() migrate_old_records() diff --git a/scripts/migrate_pulp_modularity.py b/scripts/migrate_pulp_modularity.py index 6fba061fd..db5b28ca9 100644 --- a/scripts/migrate_pulp_modularity.py +++ b/scripts/migrate_pulp_modularity.py @@ -25,7 +25,7 @@ from alws.utils.parsing import parse_rpm_nevra from alws.utils.pulp_client import PulpClient -from alws.utils.fastapi_sqla_setup import sync_setup +from alws.utils.fastapi_sqla_setup import setup_all ROOT_FOLDER = '/srv/pulp/media/' @@ -353,7 +353,7 @@ async def main(): logging.StreamHandler(), ], ) - sync_setup('pulp') + await setup_all() time1 = time.time() step = 100 pulp_host = os.environ["PULP_HOST"] diff --git a/scripts/move_logs_to_new_repos.py b/scripts/move_logs_to_new_repos.py index 2e6eb9fd9..6210aa612 100644 --- a/scripts/move_logs_to_new_repos.py +++ b/scripts/move_logs_to_new_repos.py @@ -17,6 +17,7 @@ from alws.config import settings # from alws.dependencies import get_db, get_pulp_db +from alws.dependencies import get_async_db_key from alws.models import ( Build, BuildTask, @@ -32,7 +33,7 @@ from alws.utils import pulp_client from alws.utils.file_utils import hash_content -from alws.utils.fastapi_sqla_setup import sync_setup, async_setup +from alws.utils.fastapi_sqla_setup import setup_all logging.basicConfig( format="%(asctime)s %(levelname)-8s %(message)s", @@ -55,8 +56,7 @@ async def main(): settings.pulp_password, ) - await async_setup('async') - sync_setup('pulp') + await setup_all() def get_log_names_from_repo(repo: Repository): result = {} @@ -99,7 +99,7 @@ async def safe_delete(href: str): href, ) - async with open_async_session(key="async") as session: + async with open_async_session(key=get_async_db_key()) as session: with open_session(key="pulp") as pulp_session: builds = ( ( diff --git a/scripts/noarch_checker.py b/scripts/noarch_checker.py index 9eeda92bb..added606f 100644 --- a/scripts/noarch_checker.py +++ b/scripts/noarch_checker.py @@ -18,10 +18,10 @@ from alws.models import Platform, Product, Repository from alws.config import settings -# from alws.dependencies import get_db +from alws.dependencies import get_async_db_key from alws.utils import pulp_client -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all class NoarchProcessor: @@ -331,9 +331,9 @@ def parse_args(): async def main(): args = parse_args() pulp_client.PULP_SEMAPHORE = asyncio.Semaphore(10) - await async_setup('async') + await setup_all() # async with asynccontextmanager(get_db)() as session: - async with open_async_session('async') as session: + async with open_async_session(key=get_async_db_key()) as session: processor = NoarchProcessor( session=session, source_obj_name=args.source, diff --git a/scripts/packages_exporter.py b/scripts/packages_exporter.py index 731ed9c5e..5bd47f174 100644 --- a/scripts/packages_exporter.py +++ b/scripts/packages_exporter.py @@ -37,6 +37,7 @@ from alws import models from alws.config import settings from alws.constants import SignStatusEnum +from alws.dependencies import get_async_db_key from alws.utils.errata import ( extract_errata_metadata, extract_errata_metadata_modern, @@ -50,7 +51,7 @@ from alws.utils.osv import export_errata_to_osv from alws.utils.pulp_client import PulpClient -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all KNOWN_SUBKEYS_CONFIG = os.path.abspath( os.path.expanduser("~/config/known_subkeys.json") @@ -357,7 +358,7 @@ async def get_exporter_data( repo_exporter_dict["publication_href"] = publication_href return fs_exporter_href, repo_exporter_dict - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: query = select(models.Repository).where( models.Repository.id.in_(repository_ids) ) @@ -645,7 +646,7 @@ async def export_repos_from_pulp( selectinload(models.Platform.sign_keys), ) ) - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: db_platforms = await db.execute(query) db_platforms = db_platforms.scalars().all() @@ -687,7 +688,7 @@ async def export_repos_from_release( self.logger.info( "Start exporting packages from release id=%s", release_id ) - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: db_release = await db.execute( select(models.Release).where(models.Release.id == release_id) ) @@ -867,7 +868,7 @@ def repo_post_processing(exporter: Exporter, repo_path: str): def main(): args = parse_args() init_sentry() - sync(async_setup('async')) + sync(setup_all()) platforms_dict = {} key_id_by_platform = None diff --git a/scripts/remove_unnecessary_versions_of_repositories.py b/scripts/remove_unnecessary_versions_of_repositories.py index e8edf1700..dba861e82 100644 --- a/scripts/remove_unnecessary_versions_of_repositories.py +++ b/scripts/remove_unnecessary_versions_of_repositories.py @@ -33,13 +33,14 @@ # from alws import database from alws import models from alws.crud import build as build_crud +from alws.dependencies import get_async_db_key from alws.utils.pulp_client import PulpClient from alws.errors import ( BuildError, DataNotFoundError, ) -from alws.utils.fastapi_sqla_setup import async_setup +from alws.utils.fastapi_sqla_setup import setup_all def parse_args(): @@ -60,7 +61,7 @@ def parse_args(): # Get old unsigned, unreleased and unrelated builds async def get_old_unsigned_builds(logger: logging.Logger): - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: build_dependency = select( models.BuildDependency.c.build_dependency ).scalar_subquery() @@ -86,7 +87,7 @@ async def get_old_unsigned_builds(logger: logging.Logger): async def remove_builds(builds: list, logger: logging.Logger): for build in builds: - async with open_async_session(key="async") as db: + async with open_async_session(key=get_async_db_key()) as db: try: logger.debug("Delete build with id: %s", build.id) await build_crud.remove_build_job(db, build.id) @@ -168,7 +169,7 @@ async def main(): logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig(level=logging.INFO) - await async_setup('async') + await setup_all() pulp_client = PulpClient(pulp_host, pulp_user, pulp_password) logger.info("Get old unsigned builds") builds = await get_old_unsigned_builds(logger) diff --git a/scripts/wrong_cas_hash_fixes.py b/scripts/wrong_cas_hash_fixes.py index d15ea89a0..061b9ea9f 100644 --- a/scripts/wrong_cas_hash_fixes.py +++ b/scripts/wrong_cas_hash_fixes.py @@ -13,11 +13,11 @@ from alws.models import Build, BuildTask, BuildTaskArtifact from alws.pulp_models import CoreArtifact, CoreContentArtifact -from alws.utils.fastapi_sqla_setup import sync_setup, DEFAULT_SESSION_KEY +from alws.utils.fastapi_sqla_setup import sync_setup def main(): - sync_setup('pulp', DEFAULT_SESSION_KEY) + sync_setup() first_subq = ( select(Build.id) diff --git a/tests/fixtures/builds.py b/tests/fixtures/builds.py index 498d49b16..eb00c08d0 100644 --- a/tests/fixtures/builds.py +++ b/tests/fixtures/builds.py @@ -4,6 +4,8 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.session import Session +from fastapi_sqla import open_async_session from alws.crud.build import create_build, get_builds from alws.dramatiq.build import _start_build @@ -948,66 +950,77 @@ def build_payload() -> typing.Dict[str, typing.Any]: @pytest.mark.anyio @pytest.fixture async def modular_build( - session: AsyncSession, + async_session: AsyncSession, modular_build_payload: dict, ) -> typing.AsyncIterable[Build]: - yield await create_build( - session, + build = await create_build( + async_session, BuildCreate(**modular_build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @pytest.fixture async def virt_modular_build( - session: AsyncSession, + async_session: AsyncSession, virt_build_payload: dict, ) -> typing.AsyncIterable: - yield await create_build( - session, + build = await create_build( + async_session, BuildCreate(**virt_build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @pytest.fixture async def ruby_modular_build( - session: AsyncSession, + async_session: AsyncSession, ruby_build_payload: dict, ) -> typing.AsyncIterable: - yield await create_build( - session, + # async with open_async_session('test') as async_session: + build = await create_build( + async_session, BuildCreate(**ruby_build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @pytest.fixture async def subversion_modular_build( - session: AsyncSession, + async_session: AsyncSession, subversion_build_payload: dict, ) -> typing.AsyncIterable: - yield await create_build( - session, + build = await create_build( + async_session, BuildCreate(**subversion_build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @pytest.fixture async def llvm_modular_build( - session: AsyncSession, + async_session: AsyncSession, llvm_build_payload: dict, ) -> typing.AsyncIterable: - yield await create_build( - session, + build = await create_build( + async_session, BuildCreate(**llvm_build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @@ -1015,20 +1028,22 @@ async def llvm_modular_build( async def regular_build( base_platform, base_product, - session: AsyncSession, + async_session: AsyncSession, build_payload: dict, ) -> typing.AsyncIterable[Build]: - yield await create_build( - session, + build = await create_build( + async_session, BuildCreate(**build_payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() + yield build @pytest.mark.anyio @pytest.fixture async def regular_build_with_user_product( - session: AsyncSession, + async_session: AsyncSession, build_payload: dict, create_build_rpm_repo, create_log_repo, @@ -1037,7 +1052,7 @@ async def regular_build_with_user_product( payload = copy.deepcopy(build_payload) user_product_id = ( ( - await session.execute( + await async_session.execute( select(Product.id).where(Product.is_community.is_(True)) ) ) @@ -1046,12 +1061,13 @@ async def regular_build_with_user_product( ) payload['product_id'] = user_product_id build = await create_build( - session, + async_session, BuildCreate(**payload), user_id=ADMIN_USER_ID, ) + await async_session.commit() await _start_build(build.id, BuildCreate(**payload)) - yield await get_builds(session, build_id=build.id) + yield await get_builds(async_session, build_id=build.id) @pytest.fixture @@ -1077,16 +1093,16 @@ async def func(arg, arg2): @pytest.mark.anyio @pytest.fixture async def build_for_release( - session: AsyncSession, + async_session: AsyncSession, regular_build: Build, ) -> typing.AsyncIterable[Build]: - yield await get_builds(session, build_id=regular_build.id) + yield await get_builds(async_session, build_id=regular_build.id) @pytest.mark.anyio @pytest.fixture async def modular_build_for_release( - session: AsyncSession, + async_session: AsyncSession, modular_build: Build, ) -> typing.AsyncIterable[Build]: - yield await get_builds(session, build_id=modular_build.id) + yield await get_builds(async_session, build_id=modular_build.id) diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index f61acb948..b0a37ca66 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -1,16 +1,24 @@ import typing from contextlib import asynccontextmanager +from unittest.mock import patch import pytest -from sqlalchemy import insert, select +from sqlalchemy import insert, select, delete from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +# from sqlalchemy.orm.session import Session +from sqlalchemy.orm.session import sessionmaker from sqlalchemy.pool import NullPool +from fastapi_sqla import open_async_session from alws import models from alws.config import settings from alws.database import Base +from alws.utils.fastapi_sqla_setup import setup_all from tests.constants import ADMIN_USER_ID, CUSTOM_USER_ID + +TEST_DB_KEY = 'test' + engine = create_async_engine( settings.test_database_url, poolclass=NullPool, @@ -18,22 +26,31 @@ ) -async def get_session(): - async with AsyncSession( - engine, - expire_on_commit=False, - ) as sess: - try: - yield sess - finally: - await sess.close() +@pytest.fixture +def async_session_factory(): + """Fastapi-sqla async_session_factory() fixture overload, disabling expire_on_commit.""" + return sessionmaker(class_=AsyncSession, expire_on_commit=False) + + +@pytest.fixture(autouse=True) +def patch_db_key() -> str: + with patch("alws.dependencies.get_async_db_key") as get_async_db_key: + get_async_db_key.return_value = TEST_DB_KEY + yield @pytest.mark.anyio @pytest.fixture -async def session() -> typing.AsyncIterator[AsyncSession]: - async with asynccontextmanager(get_session)() as db_session: - yield db_session +async def async_session( + async_sqla_connection, + async_session_factory, + async_sqla_reflection, + # patch_new_engine +): + """Fastapi-sqla async_session() fixture overload.""" + session = async_session_factory(bind=async_sqla_connection) + yield session + await session.close() def get_user_data(): @@ -55,7 +72,7 @@ def get_user_data(): ] -async def create_user(data: dict): +async def create_user(async_session: AsyncSession, data: dict): data = { "id": data["id"], "username": data["username"], @@ -63,14 +80,13 @@ async def create_user(data: dict): "is_superuser": data["is_superuser"], "is_verified": data["is_verified"], } - async with asynccontextmanager(get_session)() as db_session: - user = await db_session.execute( - select(models.User).where(models.User.id == data["id"]), - ) - if user.scalars().first(): - return - await db_session.execute(insert(models.User).values(**data)) - await db_session.commit() + user = await async_session.execute( + select(models.User).where(models.User.id == data["id"]), + ) + if user.scalars().first(): + return + await async_session.execute(insert(models.User).values(**data)) + await async_session.commit() @pytest.mark.anyio @@ -78,8 +94,46 @@ async def create_user(data: dict): async def create_tables(): async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - for user_data in get_user_data(): - await create_user(user_data) + + await setup_all() + async with open_async_session('test') as async_session: + for user_data in get_user_data(): + await create_user(async_session, user_data) yield + async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) + +@pytest.fixture +def sqla_modules(): + from alws.models import ( + Build, + BuildTask, + ErrataRecord, + NewErrataRecord, + Platform, + SignKey, + SignTask, + User, + UserAccessToken, + UserAction, + UserOauthAccount, + UserRole, + Team, + TestRepository, + ) # noqa + +@pytest.fixture(scope="session") +def db_url(): + """Fastapi-sqla fixture. Sync database url.""" + return settings.sqlalchemy_url + +@pytest.fixture(scope="session") +def async_sqlalchemy_url(): + """Fastapi-sqla fixture. Async database url.""" + return settings.test_database_url + +@pytest.fixture(scope="session") +def alembic_ini_path(): + """Fastapi-sqla fixture. Path for alembic.ini file.""" + return "./alws/alembic.ini" diff --git a/tests/fixtures/dramatiq.py b/tests/fixtures/dramatiq.py index 667f075d1..5f2b39cb8 100644 --- a/tests/fixtures/dramatiq.py +++ b/tests/fixtures/dramatiq.py @@ -136,7 +136,7 @@ def prepare_build_done_payload( @pytest.mark.anyio @pytest.fixture async def build_done( - session: AsyncSession, + async_session: AsyncSession, regular_build: Build, start_build, create_entity, @@ -144,11 +144,11 @@ async def build_done( mock_get_pulp_packages, get_packages_info_from_pulp, ): - build = await get_builds(db=session, build_id=regular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=regular_build.id) + await async_session.close() for build_task in build.tasks: await safe_build_done( - session, + async_session, BuildDone( **prepare_build_done_payload( build_task.id, @@ -159,17 +159,18 @@ async def build_done( ) ), ) - build = await get_builds(db=session, build_id=regular_build.id) + await async_session.commit() + build = await get_builds(db=async_session, build_id=regular_build.id) for build_task in build.tasks: assert build_task.status == BuildTaskStatus.COMPLETED - await session.close() - await test.create_test_tasks_for_build_id(session, build.id) + await async_session.close() + await test.create_test_tasks_for_build_id(async_session, build.id) @pytest.mark.anyio @pytest.fixture async def modular_build_done( - session: AsyncSession, + async_session: AsyncSession, modular_build: Build, start_modular_build, create_entity, @@ -177,11 +178,11 @@ async def modular_build_done( get_repo_modules_yaml, get_repo_modules, ): - build = await get_builds(db=session, build_id=modular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=modular_build.id) + await async_session.close() for build_task in build.tasks: await safe_build_done( - session, + async_session, BuildDone( **prepare_build_done_payload( build_task.id, @@ -197,7 +198,7 @@ async def modular_build_done( @pytest.mark.anyio @pytest.fixture async def virt_build_done( - session: AsyncSession, + async_session: AsyncSession, virt_modular_build: Build, modify_repository, start_modular_virt_build, @@ -206,8 +207,8 @@ async def virt_build_done( get_repo_virt_modules_yaml, get_repo_modules, ): - build = await get_builds(db=session, build_id=virt_modular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=virt_modular_build.id) + await async_session.close() for build_task in build.tasks: status = "done" packages = [] @@ -244,7 +245,7 @@ async def virt_build_done( status = "excluded" await safe_build_done( - session, + async_session, BuildDone( **prepare_build_done_payload( build_task.id, @@ -258,7 +259,7 @@ async def virt_build_done( @pytest.mark.anyio @pytest.fixture async def ruby_build_done( - session: AsyncSession, + async_session: AsyncSession, ruby_modular_build: Build, modify_repository, start_modular_ruby_build, @@ -267,8 +268,8 @@ async def ruby_build_done( get_repo_ruby_modules_yaml, get_repo_modules, ): - build = await get_builds(db=session, build_id=ruby_modular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=ruby_modular_build.id) + await async_session.close() for build_task in build.tasks: packages = [ "ruby-3.1.2-141.module_el8.1.0+8+503f6fbd.src.rpm", @@ -288,7 +289,7 @@ async def ruby_build_done( "rubygem-pg-doc-3.3.7-141.module_el8.1.0+8+503f6fbd.noarch.rpm", ] await safe_build_done( - session, + async_session, BuildDone(**prepare_build_done_payload(build_task.id, packages)), ) @@ -296,7 +297,7 @@ async def ruby_build_done( @pytest.mark.anyio @pytest.fixture async def subversion_build_done( - session: AsyncSession, + async_session: AsyncSession, subversion_modular_build: Build, modify_repository, start_modular_subversion_build, @@ -305,8 +306,8 @@ async def subversion_build_done( get_repo_subversion_modules_yaml, get_repo_modules, ): - build = await get_builds(db=session, build_id=subversion_modular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=subversion_modular_build.id) + await async_session.close() for build_task in build.tasks: packages = [ "subversion-1.10.2-5.module_el8.6.0+3347+66c1e1d6.src.rpm", @@ -317,7 +318,7 @@ async def subversion_build_done( f"subversion-ruby-1.10.2-5.module_el8.6.0+3347+66c1e1d6.{build_task.arch}.rpm", ] await safe_build_done( - session, + async_session, BuildDone(**prepare_build_done_payload(build_task.id, packages)), ) @@ -325,7 +326,7 @@ async def subversion_build_done( @pytest.mark.anyio @pytest.fixture async def llvm_build_done( - session: AsyncSession, + async_session: AsyncSession, llvm_modular_build: Build, modify_repository, start_modular_llvm_build, @@ -334,8 +335,8 @@ async def llvm_build_done( get_repo_llvm_modules_yaml, get_repo_modules, ): - build = await get_builds(db=session, build_id=llvm_modular_build.id) - await session.close() + build = await get_builds(db=async_session, build_id=llvm_modular_build.id) + await async_session.close() for build_task in build.tasks: packages = [] if "python" in build_task.ref.url: @@ -349,6 +350,6 @@ async def llvm_build_done( f"llvm-13.0.1-1.module+el8.6.0+14118+d530a951.{build_task.arch}.rpm", ] await safe_build_done( - session, + async_session, BuildDone(**prepare_build_done_payload(build_task.id, packages)), ) diff --git a/tests/fixtures/errata.py b/tests/fixtures/errata.py index 3fa0d3592..56ebc4060 100644 --- a/tests/fixtures/errata.py +++ b/tests/fixtures/errata.py @@ -107,13 +107,16 @@ def func(*args, **kwargs): @pytest.mark.anyio @pytest.fixture async def create_errata( - session: AsyncSession, + async_session: AsyncSession, errata_create_payload: typing.Dict[str, typing.Any], ): await create_errata_record( - session, + async_session, BaseErrataRecord(**errata_create_payload), ) + await async_session.commit() + yield + await async_session.rollback() @pytest.fixture diff --git a/tests/fixtures/platforms.py b/tests/fixtures/platforms.py index d13d91d7d..57f01bea2 100644 --- a/tests/fixtures/platforms.py +++ b/tests/fixtures/platforms.py @@ -13,7 +13,7 @@ @pytest.mark.anyio @pytest.fixture async def base_platform( - session: AsyncSession, + async_session: AsyncSession, ) -> AsyncIterable[models.Platform]: with open("reference_data/platforms.yaml", "rt") as file: loader = yaml.Loader(file) @@ -22,7 +22,7 @@ async def base_platform( schema["repos"] = [] platform = ( ( - await session.execute( + await async_session.execute( select(models.Platform).where( models.Platform.name == schema["name"], ) @@ -40,6 +40,6 @@ async def base_platform( **repository_schema.RepositoryCreate(**repo).model_dump() ) platform.repos.append(repository) - session.add(platform) - await session.commit() + async_session.add(platform) + await async_session.commit() yield platform diff --git a/tests/fixtures/products.py b/tests/fixtures/products.py index 33267bb71..48a38f680 100644 --- a/tests/fixtures/products.py +++ b/tests/fixtures/products.py @@ -75,11 +75,13 @@ def user_product_create_payload(request) -> dict: @pytest.mark.anyio @pytest.fixture async def base_product( - session: AsyncSession, product_create_payload: dict, create_repo + async_session: AsyncSession, + product_create_payload: dict, + create_repo ) -> AsyncIterable[Product]: product = ( ( - await session.execute( + await async_session.execute( select(Product).where( Product.name == product_create_payload["name"], ), @@ -90,23 +92,24 @@ async def base_product( ) if not product: product = await create_product( - session, + async_session, ProductCreate(**product_create_payload), ) + await async_session.commit() yield product @pytest.mark.anyio @pytest.fixture async def user_product( - session: AsyncSession, + async_session: AsyncSession, user_product_create_payload: dict, create_repo, create_file_repository, ) -> AsyncIterable[Product]: product = ( ( - await session.execute( + await async_session.execute( select(Product).where( Product.name == user_product_create_payload["name"], ), @@ -117,9 +120,8 @@ async def user_product( ) if not product: product = await create_product( - session, + async_session, ProductCreate(**user_product_create_payload), ) - session.add(product) - await session.commit() + await async_session.commit() yield product diff --git a/tests/fixtures/sign_keys.py b/tests/fixtures/sign_keys.py index 0df831f98..501900edd 100644 --- a/tests/fixtures/sign_keys.py +++ b/tests/fixtures/sign_keys.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from fastapi_sqla import open_async_session from alws.crud.sign_key import create_sign_key from alws.models import SignKey @@ -20,9 +21,9 @@ def basic_sign_key_payload() -> dict: } -async def __create_sign_key(session: AsyncSession, payload: dict) -> SignKey: - await create_sign_key(session, SignKeyCreate(**payload)) - sign_key_cursor = await session.execute( +async def __create_sign_key(async_session: AsyncSession, payload: dict) -> SignKey: + await create_sign_key(async_session, SignKeyCreate(**payload)) + sign_key_cursor = await async_session.execute( select(SignKey).where(SignKey.keyid == payload['keyid']) ) sign_key = sign_key_cursor.scalars().first() @@ -32,10 +33,11 @@ async def __create_sign_key(session: AsyncSession, payload: dict) -> SignKey: @pytest.mark.anyio @pytest.fixture async def sign_key( - session: AsyncSession, + async_session: AsyncSession, basic_sign_key_payload, ) -> typing.AsyncIterable[SignKey]: - sign_key = await __create_sign_key(session, basic_sign_key_payload) + sign_key = await __create_sign_key(async_session, basic_sign_key_payload) + await async_session.commit() yield sign_key - await session.execute(delete(SignKey)) - await session.commit() + await async_session.execute(delete(SignKey)) + await async_session.commit() diff --git a/tests/mock_classes.py b/tests/mock_classes.py index 19e9dc40f..f375a5074 100644 --- a/tests/mock_classes.py +++ b/tests/mock_classes.py @@ -10,7 +10,7 @@ from alws.dependencies import get_db from alws.utils import jwt_utils from tests.constants import ADMIN_USER_ID -from tests.fixtures.database import get_session +# from tests.fixtures.database import get_session @pytest.mark.anyio @@ -59,7 +59,7 @@ def generate_jwt_token( @classmethod def setup_class(cls): - app.dependency_overrides[get_db] = get_session + # app.dependency_overrides[get_db] = get_session # not get_db, we're using fastapi_sqla dependency cls.token = cls.generate_jwt_token(str(cls.user_id)) cls.headers.update( { diff --git a/tests/test_api/test_products.py b/tests/test_api/test_products.py index 369d8921d..6919258a7 100644 --- a/tests/test_api/test_products.py +++ b/tests/test_api/test_products.py @@ -33,7 +33,7 @@ async def test_add_to_product( self, regular_build: Build, user_product: Product, - session: AsyncSession, + async_session: AsyncSession, ): product_id = user_product.id product_name = user_product.name @@ -54,7 +54,7 @@ async def test_add_to_product( await _perform_product_modification(build_id, product_id, "add") db_product = ( ( - await session.execute( + await async_session.execute( select(Product) .where(Product.id == product_id) .options(selectinload(Product.builds)) @@ -69,7 +69,7 @@ async def test_add_to_product( async def test_remove_from_product( self, user_product: Product, - session: AsyncSession, + async_session: AsyncSession, ): product_id = user_product.id product_name = user_product.name @@ -86,7 +86,7 @@ async def test_remove_from_product( await _perform_product_modification(build_id, product_id, "remove") db_product = ( ( - await session.execute( + await async_session.execute( select(Product) .where(Product.id == product_id) .options(selectinload(Product.builds)) @@ -101,7 +101,7 @@ async def test_remove_from_product( async def test_user_product_remove_when_build_is_running( self, - session: AsyncSession, + async_session: AsyncSession, user_product: Product, regular_build_with_user_product: Build, ): @@ -112,9 +112,9 @@ async def test_user_product_remove_when_build_is_running( ), response.text # we need to delete active build for further product deletion for task in regular_build_with_user_product.tasks: - await session.delete(task) - await session.delete(regular_build_with_user_product) - await session.commit() + await async_session.delete(task) + await async_session.delete(regular_build_with_user_product) + await async_session.commit() async def test_user_product_remove( self, diff --git a/tests/test_api/test_releases.py b/tests/test_api/test_releases.py index 97a1e4052..d448283b8 100644 --- a/tests/test_api/test_releases.py +++ b/tests/test_api/test_releases.py @@ -71,7 +71,7 @@ async def test_create_community_release( async def test_commit_release( self, - session: AsyncSession, + async_session: AsyncSession, base_product: models.Product, disable_packages_check_in_prod_repos, disable_sign_verify, @@ -95,7 +95,7 @@ async def test_commit_release( ) message = f"Cannot commit release:\n{response.text}" assert response.status_code == self.status_codes.HTTP_200_OK, message - await commit_release(session, release_id, self.user_id) + await commit_release(async_session, release_id, self.user_id) response = await self.make_request( "get", f"/api/v1/releases/{release_id}/", @@ -106,7 +106,7 @@ async def test_commit_release( async def test_commit_community_release( self, - session: AsyncSession, + async_session: AsyncSession, user_product: models.Product, modify_repository, create_rpm_publication, @@ -131,7 +131,7 @@ async def test_commit_community_release( ) message = f"Cannot commit release:\n{response.text}" assert response.status_code == self.status_codes.HTTP_200_OK, message - await commit_release(session, release_id, self.user_id) + await commit_release(async_session, release_id, self.user_id) response = await self.make_request( "get", f"/api/v1/releases/{release_id}/", @@ -158,7 +158,7 @@ async def test_get_release( async def test_revert_release( self, - session: AsyncSession, + async_session: AsyncSession, base_product: models.Product, modify_repository, create_rpm_publication, @@ -174,7 +174,7 @@ async def test_revert_release( for row in response.json() if row["product"]["id"] == base_product.id )["id"] - await revert_release(session, release_id, self.user_id) + await revert_release(async_session, release_id, self.user_id) response = await self.make_request( "get", f"/api/v1/releases/{release_id}/", @@ -184,7 +184,7 @@ async def test_revert_release( assert release["status"] == ReleaseStatus.REVERTED, last_log builds = ( ( - await session.execute( + await async_session.execute( select(models.Build).where( models.Build.release_id == release_id, ), @@ -198,7 +198,7 @@ async def test_revert_release( pkg_dict.get("package", {}).get("artifact_href", "") for pkg_dict in release["plan"].get("packages", []) ] - errata_pkgs = await session.execute( + errata_pkgs = await async_session.execute( select(models.NewErrataToALBSPackage).where( models.NewErrataToALBSPackage.status == ErrataPackageStatus.released, @@ -221,7 +221,7 @@ async def test_revert_release( async def test_revert_community_release( self, - session: AsyncSession, + async_session: AsyncSession, user_product: models.Product, modify_repository, create_rpm_publication, @@ -237,7 +237,7 @@ async def test_revert_community_release( for row in response.json() if row["product"]["id"] == user_product.id )["id"] - await revert_release(session, release_id, self.user_id) + await revert_release(async_session, release_id, self.user_id) response = await self.make_request( "get", f"/api/v1/releases/{release_id}/", @@ -247,7 +247,7 @@ async def test_revert_community_release( assert release["status"] == ReleaseStatus.REVERTED, last_log builds = ( ( - await session.execute( + await async_session.execute( select(models.Build).where( models.Build.release_id == release_id, ), diff --git a/tests/test_api/test_uploads.py b/tests/test_api/test_uploads.py index 6b2d25492..2b27bb9af 100644 --- a/tests/test_api/test_uploads.py +++ b/tests/test_api/test_uploads.py @@ -43,7 +43,7 @@ async def test_module_upload_prod_repo( async def test_module_upload_build_repo( self, - session: AsyncSession, + async_session: AsyncSession, modules_yaml: bytes, base_platform, base_product, @@ -61,7 +61,7 @@ async def test_module_upload_build_repo( rpm_modules = ( ( - await session.execute( + await async_session.execute( select(RpmModule).where( RpmModule.id.in_( select(BuildTask.rpm_module_id) diff --git a/tests/test_unit/test_products.py b/tests/test_unit/test_products.py index 479382242..b8ee8dd1c 100644 --- a/tests/test_unit/test_products.py +++ b/tests/test_unit/test_products.py @@ -189,7 +189,7 @@ async def test_group_tasks_by_ref_id(self, build_tasks, expected): @pytest.fixture() async def create_build_and_artifacts( self, - session: AsyncSession, + async_session: AsyncSession, base_platform, base_product, create_build_rpm_repo, @@ -197,13 +197,14 @@ async def create_build_and_artifacts( modify_repository, ) -> Build: created_build = await create_build( - session, BuildCreate(**build), user_id=ADMIN_USER_ID + async_session, BuildCreate(**build), user_id=ADMIN_USER_ID ) + await async_session.commit() await _start_build(created_build.id, BuildCreate(**build)) db_build = ( ( - await session.execute( + await async_session.execute( select(Build) .where(Build.id == created_build.id) .options(selectinload(Build.tasks)) @@ -215,18 +216,18 @@ async def create_build_and_artifacts( for task, artifact in zip(db_build.tasks, build_task_artifacts): artifact["build_task_id"] = task.id - await session.execute(insert(BuildTaskArtifact).values(**artifact)) - await session.commit() + await async_session.execute(insert(BuildTaskArtifact).values(**artifact)) + await async_session.commit() return db_build @pytest.fixture async def tasks_and_expected_output( - self, session: AsyncSession, create_build_and_artifacts, request + self, async_session: AsyncSession, create_build_and_artifacts, request ) -> Tuple[List[BuildTask], List[str]]: db_build = ( ( - await session.execute( + await async_session.execute( select(Build) .where(Build.id == create_build_and_artifacts.id) .options(selectinload(Build.tasks))