diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index 43902ff2c..11e010b46 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -77,9 +77,8 @@ def create_app_inner( for entry_point in select_from_extension(group="diracx.access_policies") ] ) - + available_access_policy_names = [] for access_policy_name in available_access_policy_names: - access_policy_classes = BaseAccessPolicy.available_implementations( access_policy_name ) @@ -95,6 +94,11 @@ def create_app_inner( app.dependency_overrides[access_policy_class.check] = partial( check_permissions, access_policy ) + from diracx.routers.job_manager.access_policies import WMSAccessPolicy + + app.dependency_overrides[WMSAccessPolicy.check] = partial( + check_permissions, WMSAccessPolicy() + ) fail_startup = True # Add the SQL DBs to the application diff --git a/diracx-routers/src/diracx/routers/auth/access_policies.py b/diracx-routers/src/diracx/routers/auth/access_policies.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/diracx-routers/src/diracx/routers/job_manager/__init__.py b/diracx-routers/src/diracx/routers/job_manager/__init__.py index dae7275a6..c139702af 100644 --- a/diracx-routers/src/diracx/routers/job_manager/__init__.py +++ b/diracx-routers/src/diracx/routers/job_manager/__init__.py @@ -4,7 +4,7 @@ import logging from datetime import datetime, timezone from http import HTTPStatus -from typing import Annotated, Any, TypedDict +from typing import Annotated, Any, Callable, TypedDict from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query from pydantic import BaseModel, root_validator @@ -32,9 +32,10 @@ from ..auth import AuthorizedUserInfo, verify_dirac_access_token from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB from ..fastapi_classes import DiracxRouter -from .access_policies import ActionType, WMSAccessPolicyCallable +from .access_policies import ActionType, WMSAccessPolicy from .sandboxes import router as sandboxes_router +WMSAccessPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)] MAX_PARAMETRIC_JOBS = 20 logger = logging.getLogger(__name__) diff --git a/diracx-routers/src/diracx/routers/job_manager/access_policies copy.py b/diracx-routers/src/diracx/routers/job_manager/access_policies copy.py new file mode 100644 index 000000000..2729f003f --- /dev/null +++ b/diracx-routers/src/diracx/routers/job_manager/access_policies copy.py @@ -0,0 +1,186 @@ +# from __future__ import annotations + +# import contextlib +# import functools +# import os +# from enum import StrEnum, auto +# from typing import Annotated, AsyncIterator, Callable, Self + +# from fastapi import Depends, HTTPException, status + +# from diracx.core.extensions import select_from_extension +# from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +# from diracx.db.sql import JobDB + +# from ..auth import AuthorizedUserInfo, verify_dirac_access_token + + +# class ActionType(StrEnum): +# CREATE = auto() +# READ = auto() +# MANAGE = auto() +# QUERY = auto() + + +# async def default_wms_policy( +# user_info: AuthorizedUserInfo, +# /, +# *, +# action: ActionType, +# job_db: JobDB, +# job_ids: list[int] | None = None, +# ): +# """Implement the JobPolicy""" +# if action == ActionType.CREATE: +# if job_ids is not None: +# raise NotImplementedError( +# "job_ids is not None with ActionType.CREATE. This shouldn't happen" +# ) +# if NORMAL_USER not in user_info.properties: +# raise HTTPException(status.HTTP_403_FORBIDDEN) +# return + +# if JOB_ADMINISTRATOR in user_info.properties: +# return + +# if NORMAL_USER not in user_info.properties: +# raise HTTPException(status.HTTP_403_FORBIDDEN) + +# if action == ActionType.QUERY: +# if job_ids is not None: +# raise NotImplementedError( +# "job_ids is not None with ActionType.QUERY. This shouldn't happen" +# ) +# return + +# if job_ids is None: +# raise NotImplementedError("job_ids is None. his shouldn't happen") + +# # TODO: check the CS global job monitoring flag + +# job_owners = await job_db.summary( +# ["Owner", "VO"], +# [{"parameter": "JobID", "operator": "in", "values": job_ids}], +# ) + +# expected_owner = { +# "Owner": user_info.preferred_username, +# "VO": user_info.vo, +# "count": len(set(job_ids)), +# } +# # All the jobs belong to the user doing the query +# # and all of them are present +# if job_owners == [expected_owner]: +# return + +# raise HTTPException(status.HTTP_403_FORBIDDEN) + + +# class BaseAccessPolicy: + +# policy: Callable + +# @classmethod +# def check(cls) -> Self: +# raise NotImplementedError("This should never be called") + +# @contextlib.asynccontextmanager +# async def lifetime_function(self) -> AsyncIterator[None]: +# """A context manager that can be used to run code at startup and shutdown.""" +# yield + +# @classmethod +# def available_implementations( +# cls, access_policy_name: str +# ) -> list[type[BaseAccessPolicy]]: +# """Return the available implementations of the AccessPolicy in reverse priority order.""" +# policy_classes: list[type[BaseAccessPolicy]] = [ +# entry_point.load() +# for entry_point in select_from_extension( +# group="diracx.access_policies", name=access_policy_name +# ) +# ] +# if not policy_classes: +# raise NotImplementedError( +# f"Could not find any matches for {access_policy_name=}" +# ) +# return policy_classes + + +# class WMSAccessPolicy(BaseAccessPolicy): +# policy = staticmethod(default_wms_policy) + + +# def check_permissions( +# access_policy_instance, +# user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +# ): +# """ +# This is what every route should depend on to check user permissions. + +# It yield an access policy that needs to be checked. +# If this is declared as a dependency but not called +# """ +# has_been_called = False + +# # # TODO: query the CS to find the actual policy +# # policy = default_wms_policy + +# @functools.wraps(access_policy_instance.policy) +# async def wrapped_policy(**kwargs): +# """This wrapper is just to update the has_been_called flag""" +# nonlocal has_been_called +# has_been_called = True +# return await access_policy_instance.policy(user_info, **kwargs) + +# try: +# yield wrapped_policy +# finally: +# if not has_been_called: +# # TODO nice error message with inspect +# # That should really not happen +# print( +# "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION", +# "(PS: I hope you are in a CI)", +# flush=True, +# ) +# os._exit(1) + + +# # def check_permissions_alone( +# # policy, +# # user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +# # ): +# # """ +# # This is what every route should depend on to check user permissions. + +# # It yield an access policy that needs to be checked. +# # If this is declared as a dependency but not called +# # """ +# # has_been_called = False + +# # # # TODO: query the CS to find the actual policy +# # # policy = default_wms_policy + +# # @functools.wraps(policy) +# # async def wrapped_policy(**kwargs): +# # """This wrapper is just to update the has_been_called flag""" +# # nonlocal has_been_called +# # has_been_called = True +# # return await policy(user_info, **kwargs) + +# # try: +# # yield wrapped_policy +# # finally: +# # if not has_been_called: +# # # TODO nice error message with inspect +# # # That should really not happen +# # print( +# # "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION", +# # "(PS: I hope you are in a CI)", +# # flush=True, +# # ) +# # os._exit(1) + + +# WMSAccessPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/job_manager/access_policies.py b/diracx-routers/src/diracx/routers/job_manager/access_policies.py index ad30018f1..01f9f9acd 100644 --- a/diracx-routers/src/diracx/routers/job_manager/access_policies.py +++ b/diracx-routers/src/diracx/routers/job_manager/access_policies.py @@ -1,18 +1,10 @@ from __future__ import annotations -import contextlib import functools -import os from enum import StrEnum, auto -from typing import Annotated, AsyncIterator, Callable, Self +from typing import Annotated, Self -from fastapi import Depends, HTTPException, status - -from diracx.core.extensions import select_from_extension -from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER -from diracx.db.sql import JobDB - -from ..auth import AuthorizedUserInfo, verify_dirac_access_token +from fastapi import Depends class ActionType(StrEnum): @@ -22,165 +14,40 @@ class ActionType(StrEnum): QUERY = auto() -async def default_wms_policy( - user_info: AuthorizedUserInfo, - /, - *, - action: ActionType, - job_db: JobDB, - job_ids: list[int] | None = None, -): - """Implement the JobPolicy""" - if action == ActionType.CREATE: - if job_ids is not None: - raise NotImplementedError( - "job_ids is not None with ActionType.CREATE. This shouldn't happen" - ) - if NORMAL_USER not in user_info.properties: - raise HTTPException(status.HTTP_403_FORBIDDEN) - return - - if JOB_ADMINISTRATOR in user_info.properties: - return - - if NORMAL_USER not in user_info.properties: - raise HTTPException(status.HTTP_403_FORBIDDEN) - - if action == ActionType.QUERY: - if job_ids is not None: - raise NotImplementedError( - "job_ids is not None with ActionType.QUERY. This shouldn't happen" - ) - return - - if job_ids is None: - raise NotImplementedError("job_ids is None. his shouldn't happen") - - # TODO: check the CS global job monitoring flag - - job_owners = await job_db.summary( - ["Owner", "VO"], - [{"parameter": "JobID", "operator": "in", "values": job_ids}], - ) - - expected_owner = { - "Owner": user_info.preferred_username, - "VO": user_info.vo, - "count": len(set(job_ids)), - } - # All the jobs belong to the user doing the query - # and all of them are present - if job_owners == [expected_owner]: - return - - raise HTTPException(status.HTTP_403_FORBIDDEN) +def get_user_name(): + return "toto" class BaseAccessPolicy: - - policy: Callable - @classmethod def check(cls) -> Self: raise NotImplementedError("This should never be called") - @contextlib.asynccontextmanager - async def lifetime_function(self) -> AsyncIterator[None]: - """A context manager that can be used to run code at startup and shutdown.""" - yield - @classmethod - def available_implementations( - cls, access_policy_name: str - ) -> list[type[BaseAccessPolicy]]: - """Return the available implementations of the AccessPolicy in reverse priority order.""" - policy_classes: list[type[BaseAccessPolicy]] = [ - entry_point.load() - for entry_point in select_from_extension( - group="diracx.access_policies", name=access_policy_name - ) - ] - if not policy_classes: - raise NotImplementedError( - f"Could not find any matches for {access_policy_name=}" - ) - return policy_classes +def policy_implementation(user_name): + print(f"Is {user_name} allowed ? ") class WMSAccessPolicy(BaseAccessPolicy): - policy = staticmethod(default_wms_policy) + # policy = staticmethod(policy_implementation) + @staticmethod + def policy(user_name): + print(f"Is {user_name} allowed ? ") + # def depend_real(sub: Annotated[str, Depends(get_user_name)]): + # return f"That's the real deal {sub}" -def check_permissions( - access_policy_instance, - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], -): - """ - This is what every route should depend on to check user permissions. - It yield an access policy that needs to be checked. - If this is declared as a dependency but not called - """ - has_been_called = False +# def check_permissions_wrapper(an_int: int, sub: Annotated[str, Depends(get_user_name)]): +# return f"I am a fraud {an_int} {sub}" - # # TODO: query the CS to find the actual policy - # policy = default_wms_policy - @functools.wraps(access_policy_instance.policy) - async def wrapped_policy(**kwargs): - """This wrapper is just to update the has_been_called flag""" - nonlocal has_been_called - has_been_called = True - return await access_policy_instance.policy(user_info, **kwargs) - - try: - yield wrapped_policy - finally: - if not has_been_called: - # TODO nice error message with inspect - # That should really not happen - print( - "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION", - "(PS: I hope you are in a CI)", - flush=True, - ) - os._exit(1) - - -def check_permissions_alone( - policy, - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +def check_permissions( + obj: BaseAccessPolicy, user_name: Annotated[str, Depends(get_user_name)] ): - """ - This is what every route should depend on to check user permissions. - - It yield an access policy that needs to be checked. - If this is declared as a dependency but not called - """ - has_been_called = False - - # # TODO: query the CS to find the actual policy - # policy = default_wms_policy - - @functools.wraps(policy) + @functools.wraps(obj.policy) async def wrapped_policy(**kwargs): - """This wrapper is just to update the has_been_called flag""" - nonlocal has_been_called - has_been_called = True - return await policy(user_info, **kwargs) - - try: - yield wrapped_policy - finally: - if not has_been_called: - # TODO nice error message with inspect - # That should really not happen - print( - "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION", - "(PS: I hope you are in a CI)", - flush=True, - ) - os._exit(1) - - -WMSAccessPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)] + return obj.policy(user_name, **kwargs) + + yield wrapped_policy + # return obj.policy(sub) diff --git a/diracx-routers/tests/test_config_manager.py b/diracx-routers/tests/test_config_manager.py index 92c181b02..70f59b2b6 100644 --- a/diracx-routers/tests/test_config_manager.py +++ b/diracx-routers/tests/test_config_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fastapi import status diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index f58d32146..3a5d725e3 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone from http import HTTPStatus