diff --git a/.coveragerc b/.coveragerc
index 9d44cc8..dbaa183 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,3 +1,3 @@
 [run]
 include = */app/*
-omit = */tests/*
\ No newline at end of file
+omit = */tests/*
diff --git a/.dockerignore b/.dockerignore
index 0e0da93..0afb397 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,2 +1,2 @@
 **/app/model
-**/docker/*/.env*
\ No newline at end of file
+**/docker/*/.env*
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 09a1f8c..4bc4b0e 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -32,4 +32,4 @@ jobs:
       uses: astral-sh/ruff-action@v1
     - name: Test
       run: |
-        pytest --cov --cov-report=html:coverage_reports #--random-order 
+        pytest --cov --cov-report=html:coverage_reports #--random-order
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..5ba06e7
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,23 @@
+default_install_hook_types:
+- pre-commit
+- post-checkout
+- post-merge
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v5.0.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+  - id: check-json
+  - id: check-toml
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.8.4
+  hooks:
+  - id: ruff
+    entry: ruff check --force-exclude --fix
+  - id: ruff-format
+    entry: ruff format --force-exclude
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..975c9f3
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,26 @@
+# Contributing
+
+Thank you for your interest in contributing to the project! Here are some useful instructions for
+getting you started.
+
+## Pre-commit Hooks
+
+We use pre-commit hooks to ensure code quality. To set them up, follow these steps:
+
+1. Install pre-commit if you haven't already:
+
+   ```shell
+   pip install pre-commit
+   ```
+
+2. Install the pre-commit hooks to your local Git repository:
+
+   ```shell
+   pre-commit install
+   ```
+
+3. (optional) To run the pre-commit hooks manually on all files, use:
+
+   ```shell
+   pre-commit run --all-files
+   ```
diff --git a/README.md b/README.md
index 55f2328..35c1c00 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ python app/cli/cli.py train --help
 
 ## Download models:
 
-CMS runs the NLP model packaged in a single ZIP file. To download the GA models, please follow the [instructions](https://github.com/CogStack/MedCAT#available-models). Contact [Cogstack](contact@cogstack.org) 
+CMS runs the NLP model packaged in a single ZIP file. To download the GA models, please follow the [instructions](https://github.com/CogStack/MedCAT#available-models). Contact [Cogstack](contact@cogstack.org)
 if you are interested in trying out Alpha release such as the de-identification model. To serve or train existing Hugging Face NER models, you can package the model, either downloaded from the Hugging Face Hub or cached locally, as a ZIP file by running:
 
 ```commandline
@@ -193,4 +193,4 @@ You can also "chat" with the running model using the `/stream/ws` endpoint.
For event.preventDefault(); }; -``` \ No newline at end of file +``` diff --git a/app/api/api.py b/app/api/api.py index 7c25b20..b1b9cbe 100644 --- a/app/api/api.py +++ b/app/api/api.py @@ -1,28 +1,28 @@ -import logging import asyncio import importlib +import logging import os.path -import api.globals as cms_globals - -from typing import Dict, Any, Optional from concurrent.futures import ThreadPoolExecutor -from anyio.lowlevel import RunVar +from typing import Any, Dict, Optional + from anyio import CapacityLimiter +from anyio.lowlevel import RunVar from fastapi import FastAPI, Request +from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.utils import get_openapi -from fastapi.responses import RedirectResponse, HTMLResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles -from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html from prometheus_fastapi_instrumentator import Instrumentator +from domain import Tags, TagsStreamable +from utils import get_settings + +import api.globals as cms_globals from api.auth.db import make_sure_db_and_tables from api.auth.users import Props from api.dependencies import ModelServiceDep from api.utils import add_exception_handlers, add_rate_limiter -from domain import Tags, TagsStreamable from management.tracker_client import TrackerClient -from utils import get_settings - logging.getLogger("asyncio").setLevel(logging.ERROR) logger = logging.getLogger("cms") @@ -87,25 +87,37 @@ def get_stream_server(msd_overwritten: Optional[ModelServiceDep] = None) -> Fast return app -def _get_app(msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False) -> FastAPI: - tags_metadata = [{"name": tag.name, "description": tag.value} for tag in (Tags if not streamable else TagsStreamable)] +def _get_app( + msd_overwritten: Optional[ModelServiceDep] = None, streamable: bool = False +) -> FastAPI: + tags_metadata = [ + {"name": tag.name, "description": tag.value} + for tag in (Tags if not streamable else TagsStreamable) + ] config = get_settings() - app = FastAPI(title="CogStack ModelServe", - summary="A model serving and governance system for CogStack NLP solutions", - docs_url=None, - redoc_url=None, - debug=(config.DEBUG == "true"), - openapi_tags=tags_metadata) + app = FastAPI( + title="CogStack ModelServe", + summary="A model serving and governance system for CogStack NLP solutions", + docs_url=None, + redoc_url=None, + debug=(config.DEBUG == "true"), + openapi_tags=tags_metadata, + ) add_exception_handlers(app) instrumentator = Instrumentator( - excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]).instrument(app) + excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"] + ).instrument(app) if msd_overwritten is not None: cms_globals.model_service_dep = msd_overwritten cms_globals.props = Props(config.AUTH_USER_ENABLED == "true") - app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static") + app.mount( + "/static", + StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), + name="static", + ) @app.on_event("startup") async def on_startup() -> None: @@ -160,8 +172,11 @@ def custom_openapi() -> Dict[str, Any]: openapi_schema = get_openapi( title=f"{cms_globals.model_service_dep().model_name} APIs", version=cms_globals.model_service_dep().api_version, - description="by CogStack ModelServe, a model serving and 
governance system for CogStack NLP solutions.", - routes=app.routes + description=( + "by CogStack ModelServe, a model serving and governance system for CogStack NLP" + " solutions." + ), + routes=app.routes, ) openapi_schema["info"]["x-logo"] = { "url": "https://avatars.githubusercontent.com/u/28688163?s=200&v=4" @@ -189,6 +204,7 @@ def custom_openapi() -> Dict[str, Any]: def _load_auth_router(app: FastAPI) -> FastAPI: from api.routers import authentication + importlib.reload(authentication) app.include_router(authentication.router) return app @@ -196,6 +212,7 @@ def _load_auth_router(app: FastAPI) -> FastAPI: def _load_model_card(app: FastAPI) -> FastAPI: from api.routers import model_card + importlib.reload(model_card) app.include_router(model_card.router) return app @@ -203,6 +220,7 @@ def _load_model_card(app: FastAPI) -> FastAPI: def _load_invocation_router(app: FastAPI) -> FastAPI: from api.routers import invocation + importlib.reload(invocation) app.include_router(invocation.router) return app @@ -210,6 +228,7 @@ def _load_invocation_router(app: FastAPI) -> FastAPI: def _load_supervised_training_router(app: FastAPI) -> FastAPI: from api.routers import supervised_training + importlib.reload(supervised_training) app.include_router(supervised_training.router) return app @@ -217,6 +236,7 @@ def _load_supervised_training_router(app: FastAPI) -> FastAPI: def _load_evaluation_router(app: FastAPI) -> FastAPI: from api.routers import evaluation + importlib.reload(evaluation) app.include_router(evaluation.router) return app @@ -224,6 +244,7 @@ def _load_evaluation_router(app: FastAPI) -> FastAPI: def _load_preview_router(app: FastAPI) -> FastAPI: from api.routers import preview + importlib.reload(preview) app.include_router(preview.router) return app @@ -231,6 +252,7 @@ def _load_preview_router(app: FastAPI) -> FastAPI: def _load_unsupervised_training_router(app: FastAPI) -> FastAPI: from api.routers import unsupervised_training + importlib.reload(unsupervised_training) app.include_router(unsupervised_training.router) return app @@ -238,6 +260,7 @@ def _load_unsupervised_training_router(app: FastAPI) -> FastAPI: def _load_metacat_training_router(app: FastAPI) -> FastAPI: from api.routers import metacat_training + importlib.reload(metacat_training) app.include_router(metacat_training.router) return app @@ -245,6 +268,7 @@ def _load_metacat_training_router(app: FastAPI) -> FastAPI: def _load_health_check_router(app: FastAPI) -> FastAPI: from api.routers import health_check + importlib.reload(health_check) app.include_router(health_check.router) return app @@ -252,6 +276,7 @@ def _load_health_check_router(app: FastAPI) -> FastAPI: def _load_stream_router(app: FastAPI) -> FastAPI: from api.routers import stream + importlib.reload(stream) app.include_router(stream.router, prefix="/stream") return app diff --git a/app/api/auth/README.md b/app/api/auth/README.md index 926ded2..90f23a8 100644 --- a/app/api/auth/README.md +++ b/app/api/auth/README.md @@ -26,4 +26,4 @@ Among the above arguments, `` can be calculated using the `fast from fastapi_users.password import PasswordHelper helper = PasswordHelper() print(helper.hash("RAW_PASSWORD")) -``` \ No newline at end of file +``` diff --git a/app/api/auth/backends.py b/app/api/auth/backends.py index c18bf96..c6a7acc 100644 --- a/app/api/auth/backends.py +++ b/app/api/auth/backends.py @@ -1,17 +1,27 @@ from functools import lru_cache from typing import List -from fastapi_users.authentication.transport.base import Transport + +from 
fastapi_users.authentication import ( + AuthenticationBackend, + BearerTransport, + CookieTransport, + JWTStrategy, +) from fastapi_users.authentication.strategy.base import Strategy -from fastapi_users.authentication import BearerTransport, JWTStrategy -from fastapi_users.authentication import AuthenticationBackend, CookieTransport +from fastapi_users.authentication.transport.base import Transport + from utils import get_settings @lru_cache def get_backends() -> List[AuthenticationBackend]: return [ - AuthenticationBackend(name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy), - AuthenticationBackend(name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy), + AuthenticationBackend( + name="jwt", transport=_get_bearer_transport(), get_strategy=_get_strategy + ), + AuthenticationBackend( + name="cookie", transport=_get_cookie_transport(), get_strategy=_get_strategy + ), ] @@ -24,4 +34,7 @@ def _get_cookie_transport() -> Transport: def _get_strategy() -> Strategy: - return JWTStrategy(secret=get_settings().AUTH_JWT_SECRET, lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS) + return JWTStrategy( + secret=get_settings().AUTH_JWT_SECRET, + lifetime_seconds=get_settings().AUTH_ACCESS_TOKEN_EXPIRE_SECONDS, + ) diff --git a/app/api/auth/db.py b/app/api/auth/db.py index 9d0701c..d6825d6 100644 --- a/app/api/auth/db.py +++ b/app/api/auth/db.py @@ -4,6 +4,7 @@ from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase + from utils import get_settings @@ -29,5 +30,7 @@ async def make_sure_db_and_tables() -> None: await conn.run_sync(Base.metadata.create_all) -async def get_user_db(session: AsyncSession = Depends(_get_async_session)) -> AsyncGenerator[SQLAlchemyUserDatabase, None]: +async def get_user_db( + session: AsyncSession = Depends(_get_async_session), +) -> AsyncGenerator[SQLAlchemyUserDatabase, None]: yield SQLAlchemyUserDatabase(session, User) diff --git a/app/api/auth/users.py b/app/api/auth/users.py index e9833ef..4e37a77 100644 --- a/app/api/auth/users.py +++ b/app/api/auth/users.py @@ -1,14 +1,17 @@ -import uuid import logging -from typing import Optional, AsyncGenerator, List, Callable +import uuid +from typing import AsyncGenerator, Callable, List, Optional + from fastapi import Depends, Request from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin -from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.authentication import AuthenticationBackend -from api.auth.db import User, get_user_db -from api.auth.backends import get_backends +from fastapi_users.db import SQLAlchemyUserDatabase + from utils import get_settings +from api.auth.backends import get_backends +from api.auth.db import User, get_user_db + logger = logging.getLogger("cms") @@ -19,26 +22,33 @@ class CmsUserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): async def on_after_register(self, user: User, request: Optional[Request] = None) -> None: logger.info("User %s has registered.", user.id) - async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None) -> None: + async def on_after_forgot_password( + self, user: User, token: str, request: Optional[Request] = None + ) -> None: logger.info("User %s has forgot their password. 
Reset token: %s", user.id, token) - async def on_after_request_verify(self, user: User, token: str, request: Optional[Request] = None) -> None: + async def on_after_request_verify( + self, user: User, token: str, request: Optional[Request] = None + ) -> None: logger.info("Verification requested for user %s. Verification token: %s", user.id, token) -async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)) -> AsyncGenerator: +async def get_user_manager( + user_db: SQLAlchemyUserDatabase = Depends(get_user_db), +) -> AsyncGenerator: yield CmsUserManager(user_db) class Props(object): - def __init__(self, auth_user_enabled: bool) -> None: self._auth_backends: List = [] self._fastapi_users = None self._current_active_user = lambda: None if auth_user_enabled: self._auth_backends = get_backends() - self._fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, self.auth_backends) + self._fastapi_users = FastAPIUsers[User, uuid.UUID]( + get_user_manager, self.auth_backends + ) self._current_active_user = self._fastapi_users.current_user(active=True) @property diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 1fc1082..d82a5f9 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -1,16 +1,16 @@ import logging import re -from typing import Union -from typing_extensions import Annotated +from typing import Optional, Union from fastapi import HTTPException, Query from starlette.status import HTTP_400_BAD_REQUEST +from typing_extensions import Annotated -from typing import Optional from config import Settings from registry import model_service_registry -from model_services.base import AbstractModelService + from management.model_manager import ModelManager +from model_services.base import AbstractModelService TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$") @@ -18,7 +18,6 @@ class ModelServiceDep(object): - @property def model_service(self) -> AbstractModelService: return self._model_sevice @@ -41,12 +40,11 @@ def __call__(self) -> AbstractModelService: self._model_sevice = model_service_registry[self._model_type](self._config) else: logger.error("Unknown model type: %s", self._model_type) - exit(1) # throw an exception? + exit(1) # throw an exception? 
return self._model_sevice class ModelManagerDep(object): - def __init__(self, model_service: AbstractModelService) -> None: self._model_manager = ModelManager(model_service.__class__, model_service.service_config) self._model_manager.model_service = model_service @@ -56,11 +54,16 @@ def __call__(self) -> ModelManager: def validate_tracking_id( - tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None, + tracking_id: Annotated[ + Union[str, None], Query(description="The tracking ID of the requested task") + ] = None, ) -> Union[str, None]: if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, - detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256", + detail=( + f"Invalid tracking ID '{tracking_id}'," + " must be an alphanumeric string of length 1 to 256" + ), ) return tracking_id diff --git a/app/api/routers/authentication.py b/app/api/routers/authentication.py index 5e9ad35..a039c7a 100644 --- a/app/api/routers/authentication.py +++ b/app/api/routers/authentication.py @@ -1,7 +1,11 @@ import logging -import api.globals as cms_globals + from fastapi import APIRouter + from domain import Tags + +import api.globals as cms_globals + router = APIRouter() logger = logging.getLogger("cms") diff --git a/app/api/routers/evaluation.py b/app/api/routers/evaluation.py index 97ff305..bc65b78 100644 --- a/app/api/routers/evaluation.py +++ b/app/api/routers/evaluation.py @@ -1,42 +1,49 @@ import io import json import sys -import uuid import tempfile - +import uuid from typing import List, Union + +from fastapi import APIRouter, Depends, File, Query, Request, UploadFile +from fastapi.responses import JSONResponse, StreamingResponse from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE from typing_extensions import Annotated -from fastapi import APIRouter, Query, Depends, UploadFile, Request, File -from fastapi.responses import StreamingResponse, JSONResponse + +from domain import Scope, Tags +from exception import AnnotationException +from utils import filter_by_concept_ids import api.globals as cms_globals from api.dependencies import validate_tracking_id -from domain import Tags, Scope from model_services.base import AbstractModelService from processors.metrics_collector import ( - sanity_check_model_with_trainer_export, + concat_trainer_exports, get_iaa_scores_per_concept, get_iaa_scores_per_doc, get_iaa_scores_per_span, - concat_trainer_exports, get_stats_from_trainer_export, + sanity_check_model_with_trainer_export, ) -from exception import AnnotationException -from utils import filter_by_concept_ids router = APIRouter() -@router.post("/evaluate", - tags=[Tags.Evaluating.name], - response_class=JSONResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Evaluate the model being served with a trainer export") -async def get_evaluation_with_trainer_export(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +@router.post( + "/evaluate", + tags=[Tags.Evaluating.name], + response_class=JSONResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Evaluate the model being served with a trainer export", +) 
+async def get_evaluation_with_trainer_export( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: files = [] file_names = [] for te in trainer_export: @@ -47,7 +54,9 @@ async def get_evaluation_with_trainer_export(request: Request, files.append(temp_te) file_names.append("" if te.filename is None else te.filename) try: - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) finally: for file in files: file.close() @@ -57,22 +66,44 @@ async def get_evaluation_with_trainer_export(request: Request, data_file.flush() data_file.seek(0) evaluation_id = tracking_id or str(uuid.uuid4()) - evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)) + evaluation_accepted = model_service.train_supervised( + data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names) + ) if evaluation_accepted: - return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={ + "message": "Your evaluation started successfully.", + "evaluation_id": evaluation_id, + }, + status_code=HTTP_202_ACCEPTED, + ) else: - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": ( + "Another training or evaluation on this model is still active." + " Please retry later." 
+ ) + }, + status_code=HTTP_503_SERVICE_UNAVAILABLE, + ) -@router.post("/sanity-check", - tags=[Tags.Evaluating.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Sanity check the model being served with a trainer export") -def get_sanity_check_with_trainer_export(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: +@router.post( + "/sanity-check", + tags=[Tags.Evaluating.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Sanity check the model being served with a trainer export", +) +def get_sanity_check_with_trainer_export( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> StreamingResponse: files = [] file_names = [] for te in trainer_export: @@ -83,31 +114,54 @@ def get_sanity_check_with_trainer_export(request: Request, files.append(temp_te) file_names.append("" if te.filename is None else te.filename) try: - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) finally: for file in files: file.close() concatenated = filter_by_concept_ids(concatenated, model_service.info().model_type) - metrics = sanity_check_model_with_trainer_export(concatenated, model_service, return_df=True, include_anchors=False) + metrics = sanity_check_model_with_trainer_export( + concatenated, model_service, return_df=True, include_anchors=False + ) stream = io.StringIO() metrics.to_csv(stream, index=False) tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv") - response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{tracking_id}.csv"' + response.headers["Content-Disposition"] = ( + f'attachment ; filename="sanity_check_{tracking_id}.csv"' + ) return response -@router.post("/iaa-scores", - tags=[Tags.Evaluating.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Calculate inter annotator agreement scores between two projects") -def get_inter_annotator_agreement_scores(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")], - annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")], - annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")], - scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")], - tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse: +@router.post( + "/iaa-scores", + tags=[Tags.Evaluating.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Calculate inter annotator agreement scores 
between two projects", +) +def get_inter_annotator_agreement_scores( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="A list of trainer export files to be uploaded") + ], + annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")], + annotator_b_project_id: Annotated[ + int, Query(description="The project ID from another annotator") + ], + scope: Annotated[ + str, + Query( + enum=[s.value for s in Scope], + description=( + "The scope for which the score will be calculated, e.g., per_concept, per_document" + " or per_span" + ), + ), + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), +) -> StreamingResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -122,11 +176,17 @@ def get_inter_annotator_agreement_scores(request: Request, json.dump(concatenated, combined) combined.seek(0) if scope == Scope.PER_CONCEPT.value: - iaa_scores = get_iaa_scores_per_concept(combined, annotator_a_project_id, annotator_b_project_id, return_df=True) + iaa_scores = get_iaa_scores_per_concept( + combined, annotator_a_project_id, annotator_b_project_id, return_df=True + ) elif scope == Scope.PER_DOCUMENT.value: - iaa_scores = get_iaa_scores_per_doc(combined, annotator_a_project_id, annotator_b_project_id, return_df=True) + iaa_scores = get_iaa_scores_per_doc( + combined, annotator_a_project_id, annotator_b_project_id, return_df=True + ) elif scope == Scope.PER_SPAN.value: - iaa_scores = get_iaa_scores_per_span(combined, annotator_a_project_id, annotator_b_project_id, return_df=True) + iaa_scores = get_iaa_scores_per_span( + combined, annotator_a_project_id, annotator_b_project_id, return_df=True + ) else: raise AnnotationException(f'Unknown scope: "{scope}"') stream = io.StringIO() @@ -137,14 +197,20 @@ def get_inter_annotator_agreement_scores(request: Request, return response -@router.post("/concat_trainer_exports", - tags=[Tags.Evaluating.name], - response_class=JSONResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Concatenate multiple trainer export files into a single file for download") -def get_concatenated_trainer_exports(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")], - tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> JSONResponse: +@router.post( + "/concat_trainer_exports", + tags=[Tags.Evaluating.name], + response_class=JSONResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Concatenate multiple trainer export files into a single file for download", +) +def get_concatenated_trainer_exports( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="A list of trainer export files to be uploaded") + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), +) -> JSONResponse: files = [] for te in trainer_export: temp_te = tempfile.NamedTemporaryFile() @@ -152,23 +218,33 @@ def get_concatenated_trainer_exports(request: Request, temp_te.write(line) temp_te.flush() files.append(temp_te) - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) for file in files: file.close() tracking_id = tracking_id or str(uuid.uuid4()) response = JSONResponse(concatenated, media_type="application/json; charset=utf-8") - 
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"' + response.headers["Content-Disposition"] = ( + f'attachment ; filename="concatenated_{tracking_id}.json"' + ) return response -@router.post("/annotation-stats", - tags=[Tags.Evaluating.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Get annotation stats of trainer export files") -def get_annotation_stats(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse: +@router.post( + "/annotation-stats", + tags=[Tags.Evaluating.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Get annotation stats of trainer export files", +) +def get_annotation_stats( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), +) -> StreamingResponse: files = [] file_names = [] for te in trainer_export: @@ -179,7 +255,9 @@ def get_annotation_stats(request: Request, files.append(temp_te) file_names.append("" if te.filename is None else te.filename) try: - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) finally: for file in files: file.close() diff --git a/app/api/routers/health_check.py b/app/api/routers/health_check.py index 489e53f..089c71d 100644 --- a/app/api/routers/health_check.py +++ b/app/api/routers/health_check.py @@ -1,20 +1,19 @@ -import api.globals as cms_globals from fastapi import APIRouter, Depends from fastapi.responses import PlainTextResponse + +import api.globals as cms_globals from model_services.base import AbstractModelService router = APIRouter() -@router.get("/healthz", - description="Health check endpoint", - include_in_schema=False) +@router.get("/healthz", description="Health check endpoint", include_in_schema=False) async def is_healthy() -> PlainTextResponse: return PlainTextResponse(content="OK", status_code=200) -@router.get("/readyz", - description="Readiness check endpoint", - include_in_schema=False) -async def is_ready(model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> PlainTextResponse: +@router.get("/readyz", description="Readiness check endpoint", include_in_schema=False) +async def is_ready( + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> PlainTextResponse: return PlainTextResponse(content=model_service.info().model_type, status_code=200) diff --git a/app/api/routers/invocation.py b/app/api/routers/invocation.py index 0a8114e..fa53841 100644 --- a/app/api/routers/invocation.py +++ b/app/api/routers/invocation.py @@ -1,34 +1,36 @@ -import statistics -import tempfile +import hashlib import itertools import json -import ijson -import uuid -import hashlib import logging -import pandas as pd -import api.globals as cms_globals - -from typing import Dict, List, Union, Iterator, Any +import statistics +import tempfile +import uuid from collections import defaultdict from io import BytesIO +from typing import Any, Dict, Iterator, List, Union + +import ijson +import pandas as pd +from fastapi import APIRouter, 
Body, Depends, File, Query, Request, Response, UploadFile +from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse +from pydantic import ValidationError from starlette.status import HTTP_400_BAD_REQUEST from typing_extensions import Annotated -from fastapi import APIRouter, Depends, Body, UploadFile, File, Request, Query, Response -from fastapi.responses import StreamingResponse, PlainTextResponse, JSONResponse -from pydantic import ValidationError -from domain import TextWithAnnotations, TextWithPublicKey, TextStreamItem, ModelCard, Tags -from model_services.base import AbstractModelService + +from domain import ModelCard, Tags, TextStreamItem, TextWithAnnotations, TextWithPublicKey from utils import get_settings + +import api.globals as cms_globals from api.dependencies import validate_tracking_id -from api.utils import get_rate_limiter, encrypt +from api.utils import encrypt, get_rate_limiter from management.prometheus_metrics import ( - cms_doc_annotations, - cms_avg_anno_acc_per_doc, cms_avg_anno_acc_per_concept, + cms_avg_anno_acc_per_doc, cms_avg_meta_anno_conf_per_doc, cms_bulk_processed_docs, + cms_doc_annotations, ) +from model_services.base import AbstractModelService from processors.data_batcher import mini_batch PATH_INFO = "/info" @@ -46,27 +48,37 @@ logger = logging.getLogger("cms") -@router.get(PATH_INFO, - response_model=ModelCard, - tags=[Tags.Metadata.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Get information about the model being served") -async def get_model_card(request: Request, - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> ModelCard: +@router.get( + PATH_INFO, + response_model=ModelCard, + tags=[Tags.Metadata.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Get information about the model being served", +) +async def get_model_card( + request: Request, model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> ModelCard: return model_service.info() -@router.post(PATH_PROCESS, - response_model=TextWithAnnotations, - response_model_exclude_none=True, - response_class=JSONResponse, - tags=[Tags.Annotations.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract the NER entities from a single piece of plain text") +@router.post( + PATH_PROCESS, + response_model=TextWithAnnotations, + response_model_exclude_none=True, + response_class=JSONResponse, + tags=[Tags.Annotations.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract the NER entities from a single piece of plain text", +) @limiter.limit(config.PROCESS_RATE_LIMIT) -def get_entities_from_text(request: Request, - text: Annotated[str, Body(description="The plain text to be sent to the model for NER", media_type="text/plain")], - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> TextWithAnnotations: +def get_entities_from_text( + request: Request, + text: Annotated[ + str, + Body(description="The plain text to be sent to the model for NER", media_type="text/plain"), + ], + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> TextWithAnnotations: annotations = model_service.annotate(text) _send_annotation_num_metric(len(annotations), PATH_PROCESS) @@ -77,14 +89,27 @@ def get_entities_from_text(request: Request, return TextWithAnnotations(text=text, annotations=annotations) -@router.post(PATH_PROCESS_JSON_LINES, - 
response_class=StreamingResponse, - tags=[Tags.Annotations.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract the NER entities from texts in the JSON Lines format") +@router.post( + PATH_PROCESS_JSON_LINES, + response_class=StreamingResponse, + tags=[Tags.Annotations.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract the NER entities from texts in the JSON Lines format", +) @limiter.limit(config.PROCESS_RATE_LIMIT) -def get_entities_from_jsonlines_text(request: Request, - json_lines: Annotated[str, Body(description="The texts in the jsonlines format and each line contains {\"text\": \"\"[, \"name\": \"\"]}", media_type="application/x-ndjson")]) -> Response: +def get_entities_from_jsonlines_text( + request: Request, + json_lines: Annotated[ + str, + Body( + description=( + "The texts in the jsonlines format and each line contains" + ' {"text": ""[, "name": ""]}' + ), + media_type="application/x-ndjson", + ), + ], +) -> Response: model_manager = cms_globals.model_manager_dep() stream: Iterator[Dict[str, Any]] = itertools.chain() @@ -93,23 +118,47 @@ def get_entities_from_jsonlines_text(request: Request, predicted_stream = model_manager.predict_stream(context=None, model_input=chunked_input) stream = itertools.chain(stream, predicted_stream) - return StreamingResponse(_get_jsonlines_stream(stream), media_type="application/x-ndjson; charset=utf-8") + return StreamingResponse( + _get_jsonlines_stream(stream), media_type="application/x-ndjson; charset=utf-8" + ) except json.JSONDecodeError: - return JSONResponse(status_code=HTTP_400_BAD_REQUEST, content={"message": "Invalid JSON Lines."}) + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, content={"message": "Invalid JSON Lines."} + ) except ValidationError: - return JSONResponse(status_code=HTTP_400_BAD_REQUEST, content={"message": f"Invalid JSON properties found. The schema should be {TextStreamItem.schema_json()}"}) - - -@router.post(PATH_PROCESS_BULK, - response_model=List[TextWithAnnotations], - response_model_exclude_none=True, - tags=[Tags.Annotations.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract the NER entities from multiple plain texts") + return JSONResponse( + status_code=HTTP_400_BAD_REQUEST, + content={ + "message": ( + "Invalid JSON properties found." 
+ f" The schema should be {TextStreamItem.schema_json()}" + ) + }, + ) + + +@router.post( + PATH_PROCESS_BULK, + response_model=List[TextWithAnnotations], + response_model_exclude_none=True, + tags=[Tags.Annotations.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract the NER entities from multiple plain texts", +) @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) -def get_entities_from_multiple_texts(request: Request, - texts: Annotated[List[str], Body(description="A list of plain texts to be sent to the model for NER, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")], - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> List[TextWithAnnotations]: +def get_entities_from_multiple_texts( + request: Request, + texts: Annotated[ + List[str], + Body( + description=( + "A list of plain texts to be sent to the model for NER, in the format of" + ' ["text_1", "text_2", ..., "text_n"]' + ) + ), + ], + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> List[TextWithAnnotations]: annotations_list = model_service.batch_annotate(texts) body: List[TextWithAnnotations] = [] annotation_sum = 0 @@ -126,15 +175,29 @@ def get_entities_from_multiple_texts(request: Request, return body -@router.post(PATH_PROCESS_BULK_FILE, - tags=[Tags.Annotations.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Upload a file containing a list of plain text and extract the NER entities in JSON") -def extract_entities_from_multi_text_file(request: Request, - multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: +@router.post( + PATH_PROCESS_BULK_FILE, + tags=[Tags.Annotations.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description=( + "Upload a file containing a list of plain text and extract the NER entities in JSON" + ), +) +def extract_entities_from_multi_text_file( + request: Request, + multi_text_file: Annotated[ + UploadFile, + File( + description=( + "A file containing a list of plain texts, in the format of" + ' ["text_1", "text_2", ..., "text_n"]' + ) + ), + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> StreamingResponse: with tempfile.NamedTemporaryFile() as data_file: for line in multi_text_file.file: data_file.write(line) @@ -164,22 +227,54 @@ def extract_entities_from_multi_text_file(request: Request, json_file = BytesIO(output.encode()) tracking_id = tracking_id or str(uuid.uuid4()) response = StreamingResponse(json_file, media_type="application/json") - response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"' + response.headers["Content-Disposition"] = ( + f'attachment ; filename="concatenated_{tracking_id}.json"' + ) return response -@router.post(PATH_REDACT, - tags=[Tags.Redaction.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract and redact NER entities from a single piece of plain text") +@router.post( + PATH_REDACT, + tags=[Tags.Redaction.name], + 
dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract and redact NER entities from a single piece of plain text", +) @limiter.limit(config.PROCESS_RATE_LIMIT) -def get_redacted_text(request: Request, - text: Annotated[str, Body(description="The plain text to be sent to the model for NER and redaction", media_type="text/plain")], - concepts_to_keep: Annotated[List[str], Query(description="List of concepts (Label IDs) that should not be removedd during the redaction process. List should be in the format ['label1','label2'...]")] = [], - warn_on_no_redaction: Annotated[Union[bool, None], Query(description="Return warning when no entities were detected for redaction to prevent potential info leaking")] = False, - mask: Annotated[Union[str, None], Query(description="The custom symbols used for masking detected spans")] = None, - hash: Annotated[Union[bool, None], Query(description="Whether or not to hash detected spans")] = False, - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> PlainTextResponse: +def get_redacted_text( + request: Request, + text: Annotated[ + str, + Body( + description="The plain text to be sent to the model for NER and redaction", + media_type="text/plain", + ), + ], + concepts_to_keep: Annotated[ + List[str], + Query( + description=( + "List of concepts (Label IDs) that should not be removedd during the redaction" + " process. List should be in the format ['label1','label2'...]" + ) + ), + ] = [], + warn_on_no_redaction: Annotated[ + Union[bool, None], + Query( + description=( + "Return warning when no entities were detected for redaction to prevent potential" + " info leaking" + ) + ), + ] = False, + mask: Annotated[ + Union[str, None], Query(description="The custom symbols used for masking detected spans") + ] = None, + hash: Annotated[ + Union[bool, None], Query(description="Whether or not to hash detected spans") + ] = False, + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> PlainTextResponse: annotations = model_service.annotate(text) _send_annotation_num_metric(len(annotations), PATH_REDACT) @@ -189,35 +284,50 @@ def get_redacted_text(request: Request, redacted_text = "" start_index = 0 if not annotations and warn_on_no_redaction: - return PlainTextResponse(content="WARNING: No entities were detected for redaction.", status_code=200) + return PlainTextResponse( + content="WARNING: No entities were detected for redaction.", status_code=200 + ) else: for annotation in annotations: - if annotation["label_id"] in concepts_to_keep: continue if hash: - label = hashlib.sha256(text[annotation["start"]:annotation["end"]].encode()).hexdigest() + label = hashlib.sha256( + text[annotation["start"] : annotation["end"]].encode() + ).hexdigest() elif mask is None or len(mask) == 0: label = f"[{annotation['label_name']}]" else: label = mask - redacted_text += text[start_index:annotation["start"]] + label + redacted_text += text[start_index : annotation["start"]] + label start_index = annotation["end"] redacted_text += text[start_index:] logger.debug(redacted_text) return PlainTextResponse(content=redacted_text, status_code=200) -@router.post(PATH_REDACT_WITH_ENCRYPTION, - tags=[Tags.Redaction.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Redact and encrypt NER entities from a single piece of plain text") +@router.post( + PATH_REDACT_WITH_ENCRYPTION, + tags=[Tags.Redaction.name], + dependencies=[Depends(cms_globals.props.current_active_user)], 
+ description="Redact and encrypt NER entities from a single piece of plain text", +) @limiter.limit(config.PROCESS_RATE_LIMIT) -def get_redacted_text_with_encryption(request: Request, - text_with_public_key: Annotated[TextWithPublicKey, Body()], - warn_on_no_redaction: Annotated[Union[bool, None], Query(description="Return warning when no entities were detected for redaction to prevent potential info leaking")] = False, - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +def get_redacted_text_with_encryption( + request: Request, + text_with_public_key: Annotated[TextWithPublicKey, Body()], + warn_on_no_redaction: Annotated[ + Union[bool, None], + Query( + description=( + "Return warning when no entities were detected for redaction to prevent potential" + " info leaking" + ) + ), + ] = False, + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: annotations = model_service.annotate(text_with_public_key.text) _send_annotation_num_metric(len(annotations), PATH_REDACT_WITH_ENCRYPTION) @@ -228,12 +338,17 @@ def get_redacted_text_with_encryption(request: Request, start_index = 0 encryptions = [] if not annotations and warn_on_no_redaction: - return JSONResponse(content={"message": "WARNING: No entities were detected for redaction."}) + return JSONResponse( + content={"message": "WARNING: No entities were detected for redaction."} + ) else: for idx, annotation in enumerate(annotations): label = f"[REDACTED_{idx}]" - encrypted = encrypt(text_with_public_key.text[annotation["start"]:annotation["end"]], text_with_public_key.public_key_pem) - redacted_text += text_with_public_key.text[start_index:annotation["start"]] + label + encrypted = encrypt( + text_with_public_key.text[annotation["start"] : annotation["end"]], + text_with_public_key.public_key_pem, + ) + redacted_text += text_with_public_key.text[start_index : annotation["start"]] + label encryptions.append({"label": label, "encryption": encrypted}) start_index = annotation["end"] redacted_text += text_with_public_key.text[start_index:] @@ -260,12 +375,20 @@ def _send_accuracy_metric(annotations: List[Dict], handler: str) -> None: concept_count[annotation["label_id"]] += 1 for concept, accumulated_accuracy in accumulated_concept_accuracy.items(): concept_avg_acc = accumulated_accuracy / concept_count[concept] - cms_avg_anno_acc_per_concept.labels(handler=handler, concept=concept).set(concept_avg_acc) + cms_avg_anno_acc_per_concept.labels(handler=handler, concept=concept).set( + concept_avg_acc + ) def _send_meta_confidence_metric(annotations: List[Dict], handler: str) -> None: if annotations and annotations[0].get("meta_anns", None): - avg_conf = statistics.mean([meta_value["confidence"] for annotation in annotations for _, meta_value in annotation["meta_anns"].items()]) + avg_conf = statistics.mean( + [ + meta_value["confidence"] + for annotation in annotations + for _, meta_value in annotation["meta_anns"].items() + ] + ) cms_avg_meta_anno_conf_per_doc.labels(handler=handler).set(avg_conf) diff --git a/app/api/routers/metacat_training.py b/app/api/routers/metacat_training.py index 1e16e20..ad4cfa6 100644 --- a/app/api/routers/metacat_training.py +++ b/app/api/routers/metacat_training.py @@ -1,17 +1,18 @@ -import tempfile -import uuid import json import logging +import tempfile +import uuid from typing import List, Union -from typing_extensions import Annotated -from fastapi import APIRouter, Depends, UploadFile, Query, Request, File +from fastapi 
import APIRouter, Depends, File, Query, Request, UploadFile from fastapi.responses import JSONResponse from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE +from typing_extensions import Annotated + +from domain import Tags import api.globals as cms_globals from api.dependencies import validate_tracking_id -from domain import Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports @@ -19,19 +20,35 @@ logger = logging.getLogger("cms") -@router.post("/train_metacat", - status_code=HTTP_202_ACCEPTED, - response_class=JSONResponse, - tags=[Tags.Training.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Upload one or more trainer export files and trigger the metacat training") -async def train_metacat(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, - log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, - description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None, - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +@router.post( + "/train_metacat", + status_code=HTTP_202_ACCEPTED, + response_class=JSONResponse, + tags=[Tags.Training.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Upload one or more trainer export files and trigger the metacat training", +) +async def train_metacat( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ], + epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, + log_frequency: Annotated[ + int, + Query( + description=( + "The number of processed documents after which training metrics will be logged" + ), + ge=1, + ), + ] = 1, + description: Annotated[ + Union[str, None], Query(description="The description on the training or change logs") + ] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: files = [] file_names = [] for te in trainer_export: @@ -42,7 +59,9 @@ async def train_metacat(request: Request, files.append(temp_te) file_names.append("" if te.filename is None else te.filename) try: - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) logger.debug("Training exports concatenated") finally: for file in files: @@ -53,14 +72,16 @@ async def train_metacat(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_metacat(data_file, - epochs, - log_frequency, - training_id, - ",".join(file_names), - raw_data_files=files, - synchronised=False, - description=description) + training_accepted = model_service.train_metacat( + data_file, + epochs, + log_frequency, + training_id, + ",".join(file_names), + raw_data_files=files, + synchronised=False, + description=description, + ) finally: for file in files: file.close() @@ 
-71,7 +92,18 @@ async def train_metacat(request: Request, def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={"message": "Your training started successfully.", "training_id": training_id}, + status_code=HTTP_202_ACCEPTED, + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": ( + "Another training or evaluation on this model is still active." + " Please retry your training later." + ) + }, + status_code=HTTP_503_SERVICE_UNAVAILABLE, + ) diff --git a/app/api/routers/model_card.py b/app/api/routers/model_card.py index 94b2f0c..68100a7 100644 --- a/app/api/routers/model_card.py +++ b/app/api/routers/model_card.py @@ -1,21 +1,25 @@ -import api.globals as cms_globals - from fastapi import APIRouter, Depends, Request + from domain import ModelCard, Tags -from model_services.base import AbstractModelService from utils import get_settings + +import api.globals as cms_globals from api.utils import get_rate_limiter +from model_services.base import AbstractModelService router = APIRouter() config = get_settings() limiter = get_rate_limiter(config) -@router.get("/info", - response_model=ModelCard, - tags=[Tags.Metadata.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Get information about the model being served") -async def get_model_card(request: Request, - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> ModelCard: +@router.get( + "/info", + response_model=ModelCard, + tags=[Tags.Metadata.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Get information about the model being served", +) +async def get_model_card( + request: Request, model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> ModelCard: return model_service.info() diff --git a/app/api/routers/preview.py b/app/api/routers/preview.py index 2def28b..9f626ba 100644 --- a/app/api/routers/preview.py +++ b/app/api/routers/preview.py @@ -1,35 +1,43 @@ -import uuid import json -import tempfile import logging +import tempfile +import uuid from io import BytesIO from typing import Union -from typing_extensions import Annotated, Dict, List -from fastapi import APIRouter, Depends, Body, UploadFile, Request, Response, File, Form, Query -from fastapi.responses import StreamingResponse, JSONResponse + +from fastapi import APIRouter, Body, Depends, File, Form, Query, Request, Response, UploadFile +from fastapi.responses import JSONResponse, StreamingResponse from spacy import displacy from starlette.status import HTTP_404_NOT_FOUND +from typing_extensions import Annotated, Dict, List + +from domain import Doc, Tags +from utils import annotations_to_entities import api.globals as cms_globals from api.dependencies import validate_tracking_id -from domain import Doc, Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports -from utils import annotations_to_entities router = APIRouter() logger = 
logging.getLogger("cms") -@router.post("/preview", - tags=[Tags.Rendering.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract the NER entities in HTML for preview") -async def get_rendered_entities_from_text(request: Request, - text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")], - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse: +@router.post( + "/preview", + tags=[Tags.Rendering.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract the NER entities in HTML for preview", +) +async def get_rendered_entities_from_text( + request: Request, + text: Annotated[ + str, Body(description="The text to be sent to the model for NER", media_type="text/plain") + ], + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> StreamingResponse: annotations = model_service.annotate(text) entities = annotations_to_entities(annotations, model_service.model_name) logger.debug("Entities extracted for previewing %s", entities) @@ -41,17 +49,38 @@ async def get_rendered_entities_from_text(request: Request, return response -@router.post("/preview_trainer_export", - tags=[Tags.Rendering.name], - response_class=StreamingResponse, - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Get existing entities in HTML from a trainer export for preview") -def get_rendered_entities_from_trainer_export(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")] = [], - trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}", - project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None, - document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None, - tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> Response: +@router.post( + "/preview_trainer_export", + tags=[Tags.Rendering.name], + response_class=StreamingResponse, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Get existing entities in HTML from a trainer export for preview", +) +def get_rendered_entities_from_trainer_export( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ] = [], + trainer_export_str: Annotated[ + str, Form(description="The trainer export raw JSON string") + ] = '{"projects": []}', + project_id: Annotated[ + Union[int, None], + Query( + description="The target project ID, and if not provided, all projects will be included" + ), + ] = None, + document_id: Annotated[ + Union[int, None], + Query( + description=( + "The target document ID, and if not provided, all documents of the target" + " project(s) will be included" + ) + ), + ] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), +) -> Response: data: Dict = {"projects": []} if trainer_export is not None: files = [] @@ -62,7 +91,11 @@ def get_rendered_entities_from_trainer_export(request: 
Request, temp_te.write(line) temp_te.flush() files.append(temp_te) - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_project_ids=True, allow_recurring_doc_ids=True) + concatenated = concat_trainer_exports( + [file.name for file in files], + allow_recurring_project_ids=True, + allow_recurring_doc_ids=True, + ) logger.debug("Training exports concatenated") finally: for file in files: @@ -79,23 +112,37 @@ def get_rendered_entities_from_trainer_export(request: Request, continue entities = [] for annotation in document["annotations"]: - entities.append({ - "start": annotation["start"], - "end": annotation["end"], - "label": f"{annotation['cui']} ({'correct' if annotation.get('correct', True) else 'incorrect'}{'; terminated' if annotation.get('deleted', False) and annotation.get('killed', False) else ''})", - "kb_id": annotation["cui"], - "kb_url": "#", - }) + correct_label = "correct" if annotation.get("correct", True) else "incorrect" + is_terminated = annotation.get("deleted", False) and annotation.get("killed", False) + terminated_label = "; terminated" if is_terminated else "" + entities.append( + { + "start": annotation["start"], + "end": annotation["end"], + "label": f"{annotation['cui']} ({correct_label}{terminated_label})", + "kb_id": annotation["cui"], + "kb_url": "#", + } + ) # Displacy cannot handle annotations out of appearance order so be this entities = sorted(entities, key=lambda e: e["start"]) logger.debug("Entities extracted for previewing %s", entities) - doc = Doc(text=document["text"], ents=entities, title=f"P{project['id']}/D{document['id']}") + doc = Doc( + text=document["text"], ents=entities, title=f"P{project['id']}/D{document['id']}" + ) htmls.append(displacy.render(doc.dict(), style="ent", manual=True)) if htmls: tracking_id = tracking_id or str(uuid.uuid4()) - response = StreamingResponse(BytesIO("
".join(htmls).encode()), media_type="application/octet-stream") - response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"' + response = StreamingResponse( + BytesIO("
".join(htmls).encode()), media_type="application/octet-stream" + ) + response.headers["Content-Disposition"] = ( + f'attachment ; filename="preview_{tracking_id}.html"' + ) else: logger.debug("Cannot find any matching documents to preview") - return JSONResponse(content={"message": "Cannot find any matching documents to preview"}, status_code=HTTP_404_NOT_FOUND) + return JSONResponse( + content={"message": "Cannot find any matching documents to preview"}, + status_code=HTTP_404_NOT_FOUND, + ) return response diff --git a/app/api/routers/stream.py b/app/api/routers/stream.py index 3ed5699..5e4194d 100644 --- a/app/api/routers/stream.py +++ b/app/api/routers/stream.py @@ -1,22 +1,23 @@ +import asyncio import json import logging -import asyncio -from starlette.status import WS_1008_POLICY_VIOLATION -from starlette.websockets import WebSocketDisconnect -from starlette.requests import ClientDisconnect +from typing import Any, AsyncGenerator, Mapping, Optional -import api.globals as cms_globals - -from typing import Any, Mapping, Optional, AsyncGenerator -from starlette.types import Receive, Scope, Send -from starlette.background import BackgroundTask from fastapi import APIRouter, Depends, Request, Response, WebSocket, WebSocketException from pydantic import ValidationError +from starlette.background import BackgroundTask +from starlette.requests import ClientDisconnect +from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.types import Receive, Scope, Send +from starlette.websockets import WebSocketDisconnect + from domain import Annotation, Tags, TextStreamItem -from model_services.base import AbstractModelService from utils import get_settings + +import api.globals as cms_globals +from api.auth.users import CmsUserManager, get_user_manager from api.utils import get_rate_limiter -from api.auth.users import get_user_manager, CmsUserManager +from model_services.base import AbstractModelService PATH_STREAM_PROCESS = "/process" PATH_WS_PROCESS = "/ws" @@ -27,31 +28,46 @@ logger = logging.getLogger("cms") -@router.post(PATH_STREAM_PROCESS, - tags=[Tags.Annotations.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Extract the NER entities from a stream of texts in the JSON Lines format") +@router.post( + PATH_STREAM_PROCESS, + tags=[Tags.Annotations.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Extract the NER entities from a stream of texts in the JSON Lines format", +) @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) -async def get_entities_stream_from_jsonlines_stream(request: Request, - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> Response: +async def get_entities_stream_from_jsonlines_stream( + request: Request, model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> Response: annotation_stream = _annotation_async_gen(request, model_service) - return _LocalStreamingResponse(annotation_stream, media_type="application/x-ndjson; charset=utf-8") + return _LocalStreamingResponse( + annotation_stream, media_type="application/x-ndjson; charset=utf-8" + ) @router.websocket(PATH_WS_PROCESS) # @limiter.limit(config.PROCESS_BULK_RATE_LIMIT) # Not supported yet -async def get_inline_annotations_from_websocket(websocket: WebSocket, - user_manager: CmsUserManager = Depends(get_user_manager), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> None: +async def get_inline_annotations_from_websocket( + websocket: WebSocket, + 
user_manager: CmsUserManager = Depends(get_user_manager), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> None: monitor_idle_task = None try: if get_settings().AUTH_USER_ENABLED == "true": cookie = websocket.cookies.get("fastapiusersauth") if cookie is None: - raise WebSocketException(code=WS_1008_POLICY_VIOLATION, reason="Authentication cookie not found") - user = await cms_globals.props.auth_backends[1].get_strategy().read_token(cookie, user_manager) + raise WebSocketException( + code=WS_1008_POLICY_VIOLATION, reason="Authentication cookie not found" + ) + user = ( + await cms_globals.props.auth_backends[1] + .get_strategy() + .read_token(cookie, user_manager) + ) if not user or not user.is_active: - raise WebSocketException(code=WS_1008_POLICY_VIOLATION, reason="User not found or not active") + raise WebSocketException( + code=WS_1008_POLICY_VIOLATION, reason="User not found or not active" + ) await websocket.accept() @@ -60,7 +76,9 @@ async def get_inline_annotations_from_websocket(websocket: WebSocket, async def _monitor_idle() -> None: while True: await asyncio.sleep(get_settings().WS_IDLE_TIMEOUT_SECONDS) - if (asyncio.get_event_loop().time() - time_of_last_seen_msg) >= get_settings().WS_IDLE_TIMEOUT_SECONDS: + if ( + asyncio.get_event_loop().time() - time_of_last_seen_msg + ) >= get_settings().WS_IDLE_TIMEOUT_SECONDS: await websocket.close() logger.debug("Connection closed due to inactivity") break @@ -75,7 +93,10 @@ async def _monitor_idle() -> None: annotated_text = "" start_index = 0 for annotation in annotations: - annotated_text += f'{text[start_index:annotation["start"]]}[{annotation["label_name"]}: {text[annotation["start"]:annotation["end"]]}]' + preface_slice = text[start_index : annotation["start"]] + annotation_slice = text[annotation["start"] : annotation["end"]] + label = annotation["label_name"] + annotated_text += f"{preface_slice}[{label}: {annotation_slice}]" start_index = annotation["end"] annotated_text += text[start_index:] except Exception as e: @@ -94,13 +115,14 @@ async def _monitor_idle() -> None: class _LocalStreamingResponse(Response): - - def __init__(self, - content: Any, - status_code: int = 200, - headers: Optional[Mapping[str, str]] = None, - media_type: Optional[str] = None, - background: Optional[BackgroundTask] = None) -> None: + def __init__( + self, + content: Any, + status_code: int = 200, + headers: Optional[Mapping[str, str]] = None, + media_type: Optional[str] = None, + background: Optional[BackgroundTask] = None, + ) -> None: self.content = content self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type @@ -112,22 +134,42 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: max_chunk_size = 1024 async for line in self.content: if not response_started: - await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) response_started = True line_bytes = line.encode("utf-8") for i in range(0, len(line_bytes), max_chunk_size): - chunk = line_bytes[i:i + max_chunk_size] + chunk = line_bytes[i : i + max_chunk_size] await send({"type": "http.response.body", "body": chunk, "more_body": True}) if not response_started: - await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) - await send({"type": "http.response.body", "body": '{"error": 
"Empty stream"}\n'.encode("utf-8"), "more_body": True}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": '{"error": "Empty stream"}\n'.encode("utf-8"), + "more_body": True, + } + ) await send({"type": "http.response.body", "body": b"", "more_body": False}) if self.background is not None: await self.background() -async def _annotation_async_gen(request: Request, model_service: AbstractModelService) -> AsyncGenerator: +async def _annotation_async_gen( + request: Request, model_service: AbstractModelService +) -> AsyncGenerator: try: buffer = "" doc_idx = 0 @@ -151,7 +193,18 @@ async def _annotation_async_gen(request: Request, model_service: AbstractModelSe except json.JSONDecodeError: yield json.dumps({"error": "Invalid JSON Line", "content": line}) + "\n" except ValidationError: - yield json.dumps({"error": f"Invalid JSON properties found. The schema should be {TextStreamItem.schema_json()}", "content": line}) + "\n" + yield ( + json.dumps( + { + "error": ( + "Invalid JSON properties found." + f" The schema should be {TextStreamItem.schema_json()}" + ), + "content": line, + } + ) + + "\n" + ) finally: doc_idx += 1 if buffer.strip(): @@ -165,7 +218,18 @@ async def _annotation_async_gen(request: Request, model_service: AbstractModelSe except json.JSONDecodeError: yield json.dumps({"error": "Invalid JSON Line", "content": buffer}) + "\n" except ValidationError: - yield json.dumps({"error": f"Invalid JSON properties found. The schema should be {TextStreamItem.schema_json()}", "content": buffer}) + "\n" + yield ( + json.dumps( + { + "error": ( + "Invalid JSON properties found." + f" The schema should be {TextStreamItem.schema_json()}" + ), + "content": buffer, + } + ) + + "\n" + ) finally: doc_idx += 1 except ClientDisconnect: diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py index fd66443..887723d 100644 --- a/app/api/routers/supervised_training.py +++ b/app/api/routers/supervised_training.py @@ -1,40 +1,69 @@ -import tempfile -import uuid import json import logging +import tempfile +import uuid from typing import List, Union -from typing_extensions import Annotated -from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form +from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile from fastapi.responses import JSONResponse from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE +from typing_extensions import Annotated + +from domain import Tags +from utils import filter_by_concept_ids import api.globals as cms_globals from api.dependencies import validate_tracking_id -from domain import Tags from model_services.base import AbstractModelService from processors.metrics_collector import concat_trainer_exports -from utils import filter_by_concept_ids router = APIRouter() logger = logging.getLogger("cms") -@router.post("/train_supervised", - status_code=HTTP_202_ACCEPTED, - response_class=JSONResponse, - tags=[Tags.Training.name], - dependencies=[Depends(cms_globals.props.current_active_user)], - description="Upload one or more trainer export files and trigger the supervised training") -async def train_supervised(request: Request, - trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")], - epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, - lr_override: 
Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None, - test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2, - log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1, - description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None, - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +@router.post( + "/train_supervised", + status_code=HTTP_202_ACCEPTED, + response_class=JSONResponse, + tags=[Tags.Training.name], + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Upload one or more trainer export files and trigger the supervised training", +) +async def train_supervised( + request: Request, + trainer_export: Annotated[ + List[UploadFile], File(description="One or more trainer export files to be uploaded") + ], + epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, + lr_override: Annotated[ + Union[float, None], Query(description="The override of the initial learning rate", gt=0.0) + ] = None, + test_size: Annotated[ + Union[float, None], + Query( + description=( + "The override of the test size in percentage. (For a 'huggingface-ner' model, a" + " negative value can be used to apply the train-validation-test split if implicitly" + " defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for" + " validation, and 'projects[2]' for testing)" + ) + ), + ] = 0.2, + log_frequency: Annotated[ + int, + Query( + description=( + "The number of processed documents after which training metrics will be logged" + ), + ge=1, + ), + ] = 1, + description: Annotated[ + Union[str, None], Form(description="The description of the training or change logs") + ] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: files = [] file_names = [] for te in trainer_export: @@ -45,7 +74,9 @@ async def train_supervised(request: Request, files.append(temp_te) file_names.append("" if te.filename is None else te.filename) - concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False) + concatenated = concat_trainer_exports( + [file.name for file in files], allow_recurring_doc_ids=False + ) logger.debug("Training exports concatenated") data_file = tempfile.NamedTemporaryFile(mode="w") concatenated = filter_by_concept_ids(concatenated, model_service.info().model_type) @@ -55,16 +86,18 @@ async def train_supervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_supervised(data_file, - epochs, - log_frequency, - training_id, - ",".join(file_names), - raw_data_files=files, - description=description, - synchronised=False, - lr_override=lr_override, - test_size=test_size) + training_accepted = model_service.train_supervised( + data_file, + epochs, + log_frequency, + training_id, + ",".join(file_names), + 
raw_data_files=files, + description=description, + synchronised=False, + lr_override=lr_override, + test_size=test_size, + ) finally: for file in files: file.close() @@ -75,7 +108,18 @@ async def train_supervised(request: Request, def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={"message": "Your training started successfully.", "training_id": training_id}, + status_code=HTTP_202_ACCEPTED, + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": ( + "Another training or evaluation on this model is still active." + " Please retry your training later." + ) + }, + status_code=HTTP_503_SERVICE_UNAVAILABLE, + ) diff --git a/app/api/routers/unsupervised_training.py b/app/api/routers/unsupervised_training.py index c3925aa..122112f 100644 --- a/app/api/routers/unsupervised_training.py +++ b/app/api/routers/unsupervised_training.py @@ -1,41 +1,69 @@ import json +import logging import tempfile import uuid -import ijson -import logging -import datasets import zipfile from typing import List, Union -from typing_extensions import Annotated -from fastapi import APIRouter, Depends, UploadFile, Query, Request, File +import datasets +import ijson +from fastapi import APIRouter, Depends, File, Query, Request, UploadFile from fastapi.responses import JSONResponse from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE +from typing_extensions import Annotated + +from domain import ModelType, Tags +from exception import ClientException, ConfigurationException +from utils import get_settings + import api.globals as cms_globals from api.dependencies import validate_tracking_id -from domain import Tags, ModelType from model_services.base import AbstractModelService -from utils import get_settings -from exception import ConfigurationException, ClientException router = APIRouter() logger = logging.getLogger("cms") -@router.post("/train_unsupervised", - status_code=HTTP_202_ACCEPTED, - response_class=JSONResponse, - tags=[Tags.Training.name], - dependencies=[Depends(cms_globals.props.current_active_user)]) -async def train_unsupervised(request: Request, - training_data: Annotated[List[UploadFile], File(description="One or more files to be uploaded and each contains a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")], - epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, - lr_override: Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None, - test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage", ge=0.0)] = 0.2, - log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, - description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: 
AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +@router.post( + "/train_unsupervised", + status_code=HTTP_202_ACCEPTED, + response_class=JSONResponse, + tags=[Tags.Training.name], + dependencies=[Depends(cms_globals.props.current_active_user)], +) +async def train_unsupervised( + request: Request, + training_data: Annotated[ + List[UploadFile], + File( + description=( + "One or more files to be uploaded and each contains a list of plain texts, in the" + ' format of ["text_1", "text_2", ..., "text_n"]' + ) + ), + ], + epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, + lr_override: Annotated[ + Union[float, None], Query(description="The override of the initial learning rate", gt=0.0) + ] = None, + test_size: Annotated[ + Union[float, None], Query(description="The override of the test size in percentage", ge=0.0) + ] = 0.2, + log_frequency: Annotated[ + int, + Query( + description=( + "The number of processed documents after which training metrics will be logged" + ), + ge=1, + ), + ] = 1000, + description: Annotated[ + Union[str, None], Query(description="The description of the training or change logs") + ] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: """ Upload one or more plain text files and trigger the unsupervised training """ @@ -65,16 +93,18 @@ async def train_unsupervised(request: Request, data_file.seek(0) training_id = tracking_id or str(uuid.uuid4()) try: - training_accepted = model_service.train_unsupervised(data_file, - epochs, - log_frequency, - training_id, - ",".join(file_names), - raw_data_files=files, - synchronised=False, - lr_override=lr_override, - test_size=test_size, - description=description) + training_accepted = model_service.train_unsupervised( + data_file, + epochs, + log_frequency, + training_id, + ",".join(file_names), + raw_data_files=files, + synchronised=False, + lr_override=lr_override, + test_size=test_size, + description=description, + ) finally: for file in files: file.close() @@ -82,32 +112,95 @@ async def train_unsupervised(request: Request, return _get_training_response(training_accepted, training_id) -@router.post("/train_unsupervised_with_hf_hub_dataset", - status_code=HTTP_202_ACCEPTED, - response_class=JSONResponse, - tags=[Tags.Training.name], - dependencies=[Depends(cms_globals.props.current_active_user)]) -async def train_unsupervised_with_hf_dataset(request: Request, - hf_dataset_repo_id: Annotated[Union[str, None], Query(description="The repository ID of the dataset to download from Hugging Face Hub, will be ignored when 'hf_dataset_package' is provided")] = None, - hf_dataset_config: Annotated[Union[str, None], Query(description="The name of the dataset configuration, will be ignored when 'hf_dataset_package' is provided")] = None, - hf_dataset_package: Annotated[Union[UploadFile, None], File(description="A ZIP file containing the dataset to be uploaded, will disable the download of 'hf_dataset_repo_id'")] = None, - trust_remote_code: Annotated[bool, Query(description="Whether to trust the remote code of the dataset")] = False, - text_column_name: Annotated[str, Query(description="The name of the text column in the dataset")] = "text", - epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, - lr_override: Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None, - 
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage will only take effect if the dataset does not have predefined validation or test splits", ge=0.0)] = 0.2, - log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1000, - description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None, - tracking_id: Union[str, None] = Depends(validate_tracking_id), - model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse: +@router.post( + "/train_unsupervised_with_hf_hub_dataset", + status_code=HTTP_202_ACCEPTED, + response_class=JSONResponse, + tags=[Tags.Training.name], + dependencies=[Depends(cms_globals.props.current_active_user)], +) +async def train_unsupervised_with_hf_dataset( + request: Request, + hf_dataset_repo_id: Annotated[ + Union[str, None], + Query( + description=( + "The repository ID of the dataset to download from Hugging Face Hub, will be" + " ignored when 'hf_dataset_package' is provided" + ) + ), + ] = None, + hf_dataset_config: Annotated[ + Union[str, None], + Query( + description=( + "The name of the dataset configuration, will be ignored when 'hf_dataset_package'" + " is provided" + ) + ), + ] = None, + hf_dataset_package: Annotated[ + Union[UploadFile, None], + File( + description=( + "A ZIP file containing the dataset to be uploaded, will disable the download of" + " 'hf_dataset_repo_id'" + ) + ), + ] = None, + trust_remote_code: Annotated[ + bool, Query(description="Whether to trust the remote code of the dataset") + ] = False, + text_column_name: Annotated[ + str, Query(description="The name of the text column in the dataset") + ] = "text", + epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1, + lr_override: Annotated[ + Union[float, None], Query(description="The override of the initial learning rate", gt=0.0) + ] = None, + test_size: Annotated[ + Union[float, None], + Query( + description=( + "The override of the test size in percentage will only take effect if the dataset" + " does not have predefined validation or test splits" + ), + ge=0.0, + ), + ] = 0.2, + log_frequency: Annotated[ + int, + Query( + description=( + "The number of processed documents after which training metrics will be logged" + ), + ge=1, + ), + ] = 1000, + description: Annotated[ + Union[str, None], Query(description="The description of the training or change logs") + ] = None, + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep), +) -> JSONResponse: """ Trigger the unsupervised training with a dataset from Hugging Face Hub """ if hf_dataset_repo_id is None and hf_dataset_package is None: - raise ClientException("Either 'hf_dataset_repo_id' or 'hf_dataset_package' must be provided") + raise ClientException( + "Either 'hf_dataset_repo_id' or 'hf_dataset_package' must be provided" + ) - if model_service.info().model_type not in [ModelType.HUGGINGFACE_NER, ModelType.MEDCAT_SNOMED, ModelType.MEDCAT_ICD10, ModelType.MEDCAT_UMLS]: - raise ConfigurationException(f"Currently this endpoint is not available for models of type: {model_service.info().model_type.value}") + model_type = model_service.info().model_type + if model_type not in [ + ModelType.HUGGINGFACE_NER, + ModelType.MEDCAT_SNOMED, + ModelType.MEDCAT_ICD10, + ModelType.MEDCAT_UMLS, + ]: + raise 
ConfigurationException( + f"Currently this endpoint is not available for models of type: {model_type.value}" + ) data_dir = tempfile.TemporaryDirectory() if hf_dataset_package is not None: @@ -119,37 +212,58 @@ async def train_unsupervised_with_hf_dataset(request: Request, logger.debug("Training dataset uploaded and extracted") else: input_file_name = hf_dataset_repo_id - hf_dataset = datasets.load_dataset(hf_dataset_repo_id, - cache_dir=get_settings().TRAINING_CACHE_DIR, - trust_remote_code=trust_remote_code, - name=hf_dataset_config) + hf_dataset = datasets.load_dataset( + hf_dataset_repo_id, + cache_dir=get_settings().TRAINING_CACHE_DIR, + trust_remote_code=trust_remote_code, + name=hf_dataset_config, + ) for split in hf_dataset.keys(): if text_column_name not in hf_dataset[split].column_names: - raise ClientException(f"The dataset does not contain a '{text_column_name}' column in the split(s)") + raise ClientException( + f"The dataset does not contain a '{text_column_name}' column in the split(s)" + ) if text_column_name != "text": - hf_dataset[split] = hf_dataset[split].map(lambda x: {"text": x[text_column_name]}, batched=True) - hf_dataset[split] = hf_dataset[split].remove_columns([col for col in hf_dataset[split].column_names if col != "text"]) + hf_dataset[split] = hf_dataset[split].map( + lambda x: {"text": x[text_column_name]}, batched=True + ) + hf_dataset[split] = hf_dataset[split].remove_columns( + [col for col in hf_dataset[split].column_names if col != "text"] + ) logger.debug("Training dataset downloaded and transformed") hf_dataset.save_to_disk(data_dir.name) training_id = tracking_id or str(uuid.uuid4()) - training_accepted = model_service.train_unsupervised(data_dir, - epochs, - log_frequency, - training_id, - input_file_name, - raw_data_files=None, - synchronised=False, - lr_override=lr_override, - test_size=test_size, - description=description) + training_accepted = model_service.train_unsupervised( + data_dir, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files=None, + synchronised=False, + lr_override=lr_override, + test_size=test_size, + description=description, + ) return _get_training_response(training_accepted, training_id) def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse: if training_accepted: logger.debug("Training accepted with ID: %s", training_id) - return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED) + return JSONResponse( + content={"message": "Your training started successfully.", "training_id": training_id}, + status_code=HTTP_202_ACCEPTED, + ) else: logger.debug("Training refused due to another active training or evaluation on this model") - return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE) + return JSONResponse( + content={ + "message": ( + "Another training or evaluation on this model is still active." + " Please retry later." 
+ ) + }, + status_code=HTTP_503_SERVICE_UNAVAILABLE, + ) diff --git a/app/api/static/images/cogstack_logo.svg b/app/api/static/images/cogstack_logo.svg index bee1402..a5047c4 100644 --- a/app/api/static/images/cogstack_logo.svg +++ b/app/api/static/images/cogstack_logo.svg @@ -1,3 +1,3 @@ - \ No newline at end of file + diff --git a/app/api/utils.py b/app/api/utils.py index 8f49784..7ab80c2 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -1,54 +1,78 @@ +import base64 +import hashlib import json import logging import re -import hashlib -import base64 from functools import lru_cache from typing import Optional -from fastapi import FastAPI, Request -from starlette.responses import JSONResponse -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_400_BAD_REQUEST, HTTP_429_TOO_MANY_REQUESTS -from slowapi.middleware import SlowAPIMiddleware, SlowAPIASGIMiddleware -from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding +from fastapi import FastAPI, Request +from fastapi_users.jwt import decode_jwt from slowapi import Limiter -from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded -from fastapi_users.jwt import decode_jwt +from slowapi.middleware import SlowAPIASGIMiddleware, SlowAPIMiddleware +from slowapi.util import get_remote_address +from starlette.responses import JSONResponse +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_429_TOO_MANY_REQUESTS, + HTTP_500_INTERNAL_SERVER_ERROR, +) + from config import Settings -from exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException +from exception import ( + AnnotationException, + ClientException, + ConfigurationException, + StartTrainingException, +) logger = logging.getLogger("cms") def add_exception_handlers(app: FastAPI) -> None: - @app.exception_handler(json.decoder.JSONDecodeError) - async def json_decoding_exception_handler(_: Request, exception: json.decoder.JSONDecodeError) -> JSONResponse: + async def json_decoding_exception_handler( + _: Request, exception: json.decoder.JSONDecodeError + ) -> JSONResponse: logger.exception(exception) return JSONResponse(status_code=HTTP_400_BAD_REQUEST, content={"message": str(exception)}) @app.exception_handler(RateLimitExceeded) async def rate_limit_exceeded_handler(_: Request, exception: RateLimitExceeded) -> JSONResponse: logger.exception(exception) - return JSONResponse(status_code=HTTP_429_TOO_MANY_REQUESTS, content={"message": "Too many requests. Please wait and try your request again."}) + return JSONResponse( + status_code=HTTP_429_TOO_MANY_REQUESTS, + content={"message": "Too many requests. 
Please wait and try your request again."}, + ) @app.exception_handler(StartTrainingException) - async def start_training_exception_handler(_: Request, exception: StartTrainingException) -> JSONResponse: + async def start_training_exception_handler( + _: Request, exception: StartTrainingException + ) -> JSONResponse: logger.exception(exception) - return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) + return JSONResponse( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)} + ) @app.exception_handler(AnnotationException) - async def annotation_exception_handler(_: Request, exception: AnnotationException) -> JSONResponse: + async def annotation_exception_handler( + _: Request, exception: AnnotationException + ) -> JSONResponse: logger.exception(exception) return JSONResponse(status_code=HTTP_400_BAD_REQUEST, content={"message": str(exception)}) @app.exception_handler(ConfigurationException) - async def configuration_exception_handler(_: Request, exception: ConfigurationException) -> JSONResponse: + async def configuration_exception_handler( + _: Request, exception: ConfigurationException + ) -> JSONResponse: logger.exception(exception) - return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) + return JSONResponse( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)} + ) @app.exception_handler(ClientException) async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse: @@ -58,7 +82,9 @@ async def client_exception_handler(_: Request, exception: ClientException) -> JS @app.exception_handler(Exception) async def unhandled_exception_handler(_: Request, exception: Exception) -> JSONResponse: logger.exception(exception) - return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) + return JSONResponse( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)} + ) def add_rate_limiter(app: FastAPI, config: Settings, streamable: bool = False) -> None: @@ -86,8 +112,14 @@ def get_user_auth(request: Request) -> str: limiter_key = re.sub(r":+", ":", re.sub(r"/+", ":", limiter_prefix + current_key)) return limiter_key - auth_user_enabled = config.AUTH_USER_ENABLED == "true" if auth_user_enabled is None else auth_user_enabled - return Limiter(key_func=get_user_auth, strategy="moving-window") if auth_user_enabled else Limiter(key_func=get_remote_address, strategy="moving-window") + auth_user_enabled = ( + config.AUTH_USER_ENABLED == "true" if auth_user_enabled is None else auth_user_enabled + ) + return ( + Limiter(key_func=get_user_auth, strategy="moving-window") + if auth_user_enabled + else Limiter(key_func=get_remote_address, strategy="moving-window") + ) def adjust_rate_limit_str(rate_limit: str) -> str: @@ -99,13 +131,21 @@ def adjust_rate_limit_str(rate_limit: str) -> str: def encrypt(raw: str, public_key_pem: str) -> str: public_key = serialization.load_pem_public_key(public_key_pem.encode(), backend=default_backend) - encrypted = public_key.encrypt(raw.encode(), # type: ignore - padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None)) + encrypted = public_key.encrypt( # type: ignore + raw.encode(), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) return base64.b64encode(encrypted).decode() def decrypt(b64_encoded: str, private_key_pem: str) -> str: 
private_key = serialization.load_pem_private_key(private_key_pem.encode(), password=None) - decrypted = private_key.decrypt(base64.b64decode(b64_encoded), # type: ignore - padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None)) + decrypted = private_key.decrypt( # type: ignore + base64.b64decode(b64_encoded), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) return decrypted.decode() diff --git a/app/cli/cli.py b/app/cli/cli.py index 605da05..2ecbdd0 100644 --- a/app/cli/cli.py +++ b/app/cli/cli.py @@ -1,11 +1,11 @@ +import inspect import json import logging.config import os +import subprocess import sys import uuid -import inspect import warnings -import subprocess current_frame = inspect.currentframe() if current_frame is None: # noqa @@ -32,33 +32,53 @@ from datasets import load_dataset # noqa from domain import ModelType, TrainingType, BuildBackend, Device # noqa from registry import model_service_registry # noqa -from api.api import get_model_server, get_stream_server # noqa +from api.api import get_model_server, get_stream_server # noqa from utils import get_settings, send_gelf_message # noqa from management.model_manager import ModelManager # noqa from api.dependencies import ModelServiceDep, ModelManagerDep # noqa from management.tracker_client import TrackerClient # noqa -cmd_app = typer.Typer(name="python cli.py", help="CLI for various CogStack ModelServe operations", add_completion=False) -stream_app = typer.Typer(name="python cli.py stream", help="This groups various stream operations", add_completion=False) +cmd_app = typer.Typer( + name="python cli.py", + help="CLI for various CogStack ModelServe operations", + add_completion=False, +) +stream_app = typer.Typer( + name="python cli.py stream", help="This groups various stream operations", add_completion=False +) cmd_app.add_typer(stream_app, name="stream") -package_app = typer.Typer(name="python cli.py package", help="This groups various package operations", add_completion=False) +package_app = typer.Typer( + name="python cli.py package", + help="This groups various package operations", + add_completion=False, +) cmd_app.add_typer(package_app, name="package") logging.config.fileConfig(os.path.join(parent_dir, "logging.ini"), disable_existing_loggers=False) @cmd_app.command("serve", help="This serves various CogStack NLP models") -def serve_model(model_type: ModelType = typer.Option(..., help="The type of the model to serve"), - model_path: str = typer.Option("", help="The file path to the model package"), - mlflow_model_uri: str = typer.Option("", help="The URI of the MLflow model to serve", metavar="models:/MODEL_NAME/ENV"), - host: str = typer.Option("127.0.0.1", help="The hostname of the server"), - port: str = typer.Option("8000", help="The port of the server"), - model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"), - streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"), - device: Device = typer.Option(Device.DEFAULT, help="The device to serve the model on"), - debug: Optional[bool] = typer.Option(None, help="Run in the debug mode")) -> None: +def serve_model( + model_type: ModelType = typer.Option(..., help="The type of the model to serve"), + model_path: str = typer.Option("", help="The file path to the model package"), + mlflow_model_uri: str = typer.Option( + "", help="The URI of the MLflow model to serve", metavar="models:/MODEL_NAME/ENV" + 
), + host: str = typer.Option("127.0.0.1", help="The hostname of the server"), + port: str = typer.Option("8000", help="The port of the server"), + model_name: Optional[str] = typer.Option( + None, help="The string representation of the model name" + ), + streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"), + device: Device = typer.Option(Device.DEFAULT, help="The device to serve the model on"), + debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), +) -> None: logger = _get_logger(debug, model_type, model_name) get_settings().DEVICE = device.value - if model_type in [ModelType.HUGGINGFACE_NER, ModelType.MEDCAT_DEID, ModelType.TRANSFORMERS_DEID]: + if model_type in [ + ModelType.HUGGINGFACE_NER, + ModelType.MEDCAT_DEID, + ModelType.TRANSFORMERS_DEID, + ]: get_settings().DISABLE_METACAT_TRAINING = "true" if "GELF_INPUT_URI" in os.environ and os.environ["GELF_INPUT_URI"]: @@ -69,7 +89,10 @@ def serve_model(model_type: ModelType = typer.Option(..., help="The type of the logger.addHandler(gelf_tcp_handler) logging.getLogger("uvicorn").addHandler(gelf_tcp_handler) except Exception: - logger.exception("$GELF_INPUT_URI is set to \"%s\" but it's not ready to receive logs", os.environ['GELF_INPUT_URI']) + logger.exception( + '$GELF_INPUT_URI is set to "%s" but it\'s not ready to receive logs', + os.environ["GELF_INPUT_URI"], + ) config = get_settings() @@ -90,7 +113,9 @@ def serve_model(model_type: ModelType = typer.Option(..., help="The type of the model_service.init_model() cms_globals.model_manager_dep = ModelManagerDep(model_service) elif mlflow_model_uri: - model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) + model_service = ModelManager.retrieve_model_service_from_uri( + mlflow_model_uri, config, dst_model_path + ) model_service.model_name = model_name if model_name is not None else "CMS model" model_service_dep.model_service = model_service cms_globals.model_manager_dep = ModelManagerDep(model_service) @@ -102,24 +127,43 @@ def serve_model(model_type: ModelType = typer.Option(..., help="The type of the logger.info('Start serving model "%s" on %s:%s', model_type, host, port) # interrupted = False # while not interrupted: - uvicorn.run(model_server_app if not streamable else get_stream_server(), host=host, port=int(port), log_config=None) + uvicorn.run( + model_server_app if not streamable else get_stream_server(), + host=host, + port=int(port), + log_config=None, + ) # interrupted = True typer.echo("Shutting down due to either keyboard interrupt or system exit") @cmd_app.command("train", help="This pretrains or fine-tunes various CogStack NLP models") -def train_model(model_type: ModelType = typer.Option(..., help="The type of the model to serve"), - base_model_path: str = typer.Option("", help="The file path to the base model package to be trained on"), - mlflow_model_uri: str = typer.Option("", help="The URI of the MLflow model to train", metavar="models:/MODEL_NAME/ENV"), - training_type: TrainingType = typer.Option(..., help="The type of training"), - data_file_path: str = typer.Option(..., help="The path to the training asset file"), - epochs: int = typer.Option(1, help="The number of training epochs"), - log_frequency: int = typer.Option(1, help="The number of processed documents after which training metrics will be logged"), - hyperparameters: str = typer.Option("{}", help="The overriding hyperparameters serialised as JSON string"), - description: Optional[str] = typer.Option(None, 
help="The description of the training or change logs"), - model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"), - device: Device = typer.Option(Device.DEFAULT, help="The device to train the model on"), - debug: Optional[bool] = typer.Option(None, help="Run in the debug mode")) -> None: +def train_model( + model_type: ModelType = typer.Option(..., help="The type of the model to serve"), + base_model_path: str = typer.Option( + "", help="The file path to the base model package to be trained on" + ), + mlflow_model_uri: str = typer.Option( + "", help="The URI of the MLflow model to train", metavar="models:/MODEL_NAME/ENV" + ), + training_type: TrainingType = typer.Option(..., help="The type of training"), + data_file_path: str = typer.Option(..., help="The path to the training asset file"), + epochs: int = typer.Option(1, help="The number of training epochs"), + log_frequency: int = typer.Option( + 1, help="The number of processed documents after which training metrics will be logged" + ), + hyperparameters: str = typer.Option( + "{}", help="The overriding hyperparameters serialised as JSON string" + ), + description: Optional[str] = typer.Option( + None, help="The description of the training or change logs" + ), + model_name: Optional[str] = typer.Option( + None, help="The string representation of the model name" + ), + device: Device = typer.Option(Device.DEFAULT, help="The device to train the model on"), + debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), +) -> None: logger = _get_logger(debug, model_type, model_name) config = get_settings() @@ -140,7 +184,9 @@ def train_model(model_type: ModelType = typer.Option(..., help="The type of the model_service.model_name = model_name if model_name is not None else "CMS model" model_service.init_model() elif mlflow_model_uri: - model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path) + model_service = ModelManager.retrieve_model_service_from_uri( + mlflow_model_uri, config, dst_model_path + ) model_service.model_name = model_name if model_name is not None else "CMS model" model_service_dep.model_service = model_service else: @@ -149,27 +195,61 @@ def train_model(model_type: ModelType = typer.Option(..., help="The type of the training_id = str(uuid.uuid4()) with open(data_file_path, "r") as data_file: - training_args = [data_file, epochs, log_frequency, training_id, data_file.name, [data_file], description, True] - if training_type == TrainingType.SUPERVISED and model_service._supervised_trainer is not None: + training_args = [ + data_file, + epochs, + log_frequency, + training_id, + data_file.name, + [data_file], + description, + True, + ] + if ( + training_type == TrainingType.SUPERVISED + and model_service._supervised_trainer is not None + ): model_service.train_supervised(*training_args, **json.loads(hyperparameters)) - elif training_type == TrainingType.UNSUPERVISED and model_service._unsupervised_trainer is not None: + elif ( + training_type == TrainingType.UNSUPERVISED + and model_service._unsupervised_trainer is not None + ): model_service.train_unsupervised(*training_args, **json.loads(hyperparameters)) - elif training_type == TrainingType.META_SUPERVISED and model_service._metacat_trainer is not None: + elif ( + training_type == TrainingType.META_SUPERVISED + and model_service._metacat_trainer is not None + ): model_service.train_metacat(*training_args, **json.loads(hyperparameters)) else: - logger.error("Training type %s is 
not supported or the corresponding trainer has not been enabled in the .env file.", training_type) + logger.error( + "Training type %s is not supported or the corresponding trainer has not been" + " enabled in the .env file.", + training_type, + ) typer.Exit(code=1) -@cmd_app.command("register", help="This pushes a pretrained NLP model to the CogStack ModelServe registry") -def register_model(model_type: ModelType = typer.Option(..., help="The type of the model to serve"), - model_path: str = typer.Option(..., help="The file path to the model package"), - model_name: str = typer.Option(..., help="The string representation of the registered model"), - training_type: Optional[str] = typer.Option(None, help="The type of training the model went through"), - model_config: Optional[str] = typer.Option(None, help="The string representation of a JSON object"), - model_metrics: Optional[str] = typer.Option(None, help="The string representation of a JSON array"), - model_tags: Optional[str] = typer.Option(None, help="The string representation of a JSON object"), - debug: Optional[bool] = typer.Option(None, help="Run in the debug mode")) -> None: +@cmd_app.command( + "register", help="This pushes a pretrained NLP model to the CogStack ModelServe registry" +) +def register_model( + model_type: ModelType = typer.Option(..., help="The type of the model to serve"), + model_path: str = typer.Option(..., help="The file path to the model package"), + model_name: str = typer.Option(..., help="The string representation of the registered model"), + training_type: Optional[str] = typer.Option( + None, help="The type of training the model went through" + ), + model_config: Optional[str] = typer.Option( + None, help="The string representation of a JSON object" + ), + model_metrics: Optional[str] = typer.Option( + None, help="The string representation of a JSON array" + ), + model_tags: Optional[str] = typer.Option( + None, help="The string representation of a JSON object" + ), + debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), +) -> None: logger = _get_logger(debug, model_type, model_name) config = get_settings() tracker_client = TrackerClient(config.MLFLOW_TRACKING_URI) @@ -186,22 +266,26 @@ def register_model(model_type: ModelType = typer.Option(..., help="The type of t t_type = training_type if training_type is not None else "" run_name = str(uuid.uuid4()) - tracker_client.save_pretrained_model(model_name=model_name, - model_path=model_path, - model_manager=ModelManager(model_service_type, config), - training_type=t_type, - run_name=run_name, - model_config=m_config, - model_metrics=m_metrics, - model_tags=m_tags) + tracker_client.save_pretrained_model( + model_name=model_name, + model_path=model_path, + model_manager=ModelManager(model_service_type, config), + training_type=t_type, + run_name=run_name, + model_config=m_config, + model_metrics=m_metrics, + model_tags=m_tags, + ) typer.echo(f"Pushed {model_path} as a new model version ({run_name})") @stream_app.command("json-lines", help="This gets NER entities as a JSON Lines stream") -def stream_jsonl_annotations(jsonl_file_path: str = typer.Option(..., help="The path to the JSON Lines file"), - base_url: str = typer.Option("http://127.0.0.1:8000", help="The CMS base url"), - timeout_in_secs: int = typer.Option(0, help="The max time to wait before disconnection"), - debug: Optional[bool] = typer.Option(None, help="Run in the debug mode")) -> None: +def stream_jsonl_annotations( + jsonl_file_path: str = typer.Option(..., help="The path 
to the JSON Lines file"), + base_url: str = typer.Option("http://127.0.0.1:8000", help="The CMS base url"), + timeout_in_secs: int = typer.Option(0, help="The max time to wait before disconnection"), + debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), +) -> None: logger = _get_logger(debug) async def get_jsonl_stream(base_url: str, jsonl_file_path: str) -> None: @@ -210,10 +294,12 @@ async def get_jsonl_stream(base_url: str, jsonl_file_path: str) -> None: try: async with aiohttp.ClientSession() as session: timeout = aiohttp.ClientTimeout(total=timeout_in_secs) - async with session.post(f"{base_url}/stream/process", - data=file.read().encode("utf-8"), - headers=headers, - timeout=timeout) as response: + async with session.post( + f"{base_url}/stream/process", + data=file.read().encode("utf-8"), + headers=headers, + timeout=timeout, + ) as response: response.raise_for_status() async for line in response.content: typer.echo(line.decode("utf-8"), nl=False) @@ -226,13 +312,17 @@ async def get_jsonl_stream(base_url: str, jsonl_file_path: str) -> None: @stream_app.command("chat", help="This gets NER entities by chatting with the model") -def chat_to_get_jsonl_annotations(base_url: str = typer.Option("ws://127.0.0.1:8000", help="The CMS base url"), - debug: Optional[bool] = typer.Option(None, help="Run in the debug mode")) -> None: +def chat_to_get_jsonl_annotations( + base_url: str = typer.Option("ws://127.0.0.1:8000", help="The CMS base url"), + debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"), +) -> None: logger = _get_logger(debug) + async def chat_with_model(base_url: str) -> None: try: chat_endpoint = f"{base_url}/stream/ws" async with websockets.connect(chat_endpoint, ping_interval=None) as websocket: + async def keep_alive() -> None: while True: try: @@ -242,10 +332,14 @@ async def keep_alive() -> None: break keep_alive_task = asyncio.create_task(keep_alive()) - logging.info("Connected to CMS. Start typing you input and press to submit:") + logging.info( + "Connected to CMS. 
Start typing you input and press to submit:" + ) try: while True: - text = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline) + text = await asyncio.get_event_loop().run_in_executor( + None, sys.stdin.readline + ) if text.strip() == "": continue try: @@ -270,14 +364,22 @@ async def keep_alive() -> None: @cmd_app.command("export-model-apis") -def generate_api_doc_per_model(model_type: ModelType = typer.Option(..., help="The type of the model to serve"), - add_training_apis: bool = typer.Option(False, help="Add training APIs to the doc"), - add_evaluation_apis: bool = typer.Option(False, help="Add evaluation APIs to the doc"), - add_previews_apis: bool = typer.Option(False, help="Add preview APIs to the doc"), - add_user_authentication: bool = typer.Option(False, help="Add user authentication APIs to the doc"), - exclude_unsupervised_training: bool = typer.Option(False, help="Exclude the unsupervised training API"), - exclude_metacat_training: bool = typer.Option(False, help="Exclude the metacat training API"), - model_name: Optional[str] = typer.Option(None, help="The string representation of the model name")) -> None: +def generate_api_doc_per_model( + model_type: ModelType = typer.Option(..., help="The type of the model to serve"), + add_training_apis: bool = typer.Option(False, help="Add training APIs to the doc"), + add_evaluation_apis: bool = typer.Option(False, help="Add evaluation APIs to the doc"), + add_previews_apis: bool = typer.Option(False, help="Add preview APIs to the doc"), + add_user_authentication: bool = typer.Option( + False, help="Add user authentication APIs to the doc" + ), + exclude_unsupervised_training: bool = typer.Option( + False, help="Exclude the unsupervised training API" + ), + exclude_metacat_training: bool = typer.Option(False, help="Exclude the metacat training API"), + model_name: Optional[str] = typer.Option( + None, help="The string representation of the model name" + ), +) -> None: """ This generates model-specific API docs for enabled endpoints """ @@ -303,15 +405,45 @@ def generate_api_doc_per_model(model_type: ModelType = typer.Option(..., help="T typer.echo(f"OpenAPI doc exported to {doc_name}") -@package_app.command("hf-model", help="This packages a remotely hosted or locally cached Hugging Face model into a model package") -def package_model(hf_repo_id: str = typer.Option("", help="The repository ID of the model to download from Hugging Face Hub, e.g., 'google-bert/bert-base-cased'"), - hf_repo_revision: str = typer.Option("", help="The revision of the model to download from Hugging Face Hub"), - cached_model_dir: str = typer.Option("", help="Path to the cached model directory, will only be used if --hf-repo-id is not provided"), - output_model_package: str = typer.Option("", help="Path to save the model package, minus any format-specific extension, e.g., './model_packages/bert-base-cased'"), - remove_cached: bool = typer.Option(False, help="Whether to remove the downloaded cache after the model package is saved"), +@package_app.command( + "hf-model", + help=( + "This packages a remotely hosted or locally cached Hugging Face model into a model package" + ), +) +def package_model( + hf_repo_id: str = typer.Option( + "", + help=( + "The repository ID of the model to download from Hugging Face Hub," + " e.g., 'google-bert/bert-base-cased'" + ), + ), + hf_repo_revision: str = typer.Option( + "", help="The revision of the model to download from Hugging Face Hub" + ), + cached_model_dir: str = typer.Option( + "", + help=( + "Path 
to the cached model directory, will only be used if --hf-repo-id is not provided" + ), + ), + output_model_package: str = typer.Option( + "", + help=( + "Path to save the model package, minus any format-specific extension," + " e.g., './model_packages/bert-base-cased'" + ), + ), + remove_cached: bool = typer.Option( + False, help="Whether to remove the downloaded cache after the model package is saved" + ), ) -> None: if hf_repo_id == "" and cached_model_dir == "": - typer.echo("ERROR: Neither the repository ID of the Hugging Face model nor the cached model directory is passed in.") + typer.echo( + "ERROR: Neither the repository ID of the Hugging Face model nor the cached model" + " directory is passed in." + ) raise typer.Exit(code=1) if output_model_package == "": @@ -339,16 +471,50 @@ def package_model(hf_repo_id: str = typer.Option("", help="The repository ID of typer.echo(f"Model package saved to {model_package_archive}.zip") -@package_app.command("hf-dataset", help="This packages a remotely hosted or locally cached Hugging Face dataset into a dataset package") -def package_dataset(hf_dataset_id: str = typer.Option("", help="The repository ID of the dataset to download from Hugging Face Hub, e.g., 'stanfordnlp/imdb'"), - hf_dataset_revision: str = typer.Option("", help="The revision of the dataset to download from Hugging Face Hub"), - cached_dataset_dir: str = typer.Option("", help="Path to the cached dataset directory, will only be used if --hf-dataset-id is not provided"), - output_dataset_package: str = typer.Option("", help="Path to save the dataset package, minus any format-specific extension, e.g., './dataset_packages/imdb'"), - remove_cached: bool = typer.Option(False, help="Whether to remove the downloaded cache after the dataset package is saved"), - trust_remote_code: bool = typer.Option(False, help="Whether to trust and use the remote script of the dataset"), +@package_app.command( + "hf-dataset", + help=( + "This packages a remotely hosted or locally cached Hugging Face dataset into a dataset" + " package" + ), +) +def package_dataset( + hf_dataset_id: str = typer.Option( + "", + help=( + "The repository ID of the dataset to download from Hugging Face Hub," + " e.g., 'stanfordnlp/imdb'" + ), + ), + hf_dataset_revision: str = typer.Option( + "", help="The revision of the dataset to download from Hugging Face Hub" + ), + cached_dataset_dir: str = typer.Option( + "", + help=( + "Path to the cached dataset directory, will only be used if --hf-dataset-id is not" + " provided" + ), + ), + output_dataset_package: str = typer.Option( + "", + help=( + "Path to save the dataset package, minus any format-specific extension," + " e.g., './dataset_packages/imdb'" + ), + ), + remove_cached: bool = typer.Option( + False, help="Whether to remove the downloaded cache after the dataset package is saved" + ), + trust_remote_code: bool = typer.Option( + False, help="Whether to trust and use the remote script of the dataset" + ), ) -> None: if hf_dataset_id == "" and cached_dataset_dir == "": - typer.echo("ERROR: Neither the repository ID of the Hugging Face dataset nor the cached dataset directory is passed in.") + typer.echo( + "ERROR: Neither the repository ID of the Hugging Face dataset nor the cached dataset" + " directory is passed in." 
+ ) raise typer.Exit(code=1) if output_dataset_package == "": typer.echo("ERROR: The dataset package path is not passed in.") @@ -362,10 +528,16 @@ def package_dataset(hf_dataset_id: str = typer.Option("", help="The repository I try: if hf_dataset_revision == "": - dataset = load_dataset(path=hf_dataset_id, cache_dir=cache_dir, trust_remote_code=trust_remote_code) + dataset = load_dataset( + path=hf_dataset_id, cache_dir=cache_dir, trust_remote_code=trust_remote_code + ) else: - dataset = load_dataset(path=hf_dataset_id, cache_dir=cache_dir, revision=hf_dataset_revision, - trust_remote_code=trust_remote_code) + dataset = load_dataset( + path=hf_dataset_id, + cache_dir=cache_dir, + revision=hf_dataset_revision, + trust_remote_code=trust_remote_code, + ) dataset.save_to_disk(cached_dataset_path) shutil.make_archive(dataset_package_archive, "zip", cached_dataset_path) @@ -380,37 +552,61 @@ def package_dataset(hf_dataset_id: str = typer.Option("", help="The repository I @cmd_app.command("build", help="This builds an OCI-compliant image to containerise CMS") -def build_image(dockerfile_path: str = typer.Option(..., help="The path to the Dockerfile"), - context_dir: str = typer.Option(..., help="The directory containing the set of files accessible to the build"), - model_name: Optional[str] = typer.Option("cms_model", help="The string representation of the model name"), - user_id: Optional[int] = typer.Option(1000, help="The ID for the non-root user"), - group_id: Optional[int] = typer.Option(1000, help="The group ID for the non-root user"), - http_proxy: Optional[str] = typer.Option("", help="The string representation of the HTTP proxy"), - https_proxy: Optional[str] = typer.Option("", help="The string representation of the HTTPS proxy"), - no_proxy: Optional[str] = typer.Option("localhost,127.0.0.1", help="The string representation of addresses by-passing proxies"), - tag: str = typer.Option(None, help="The tag of the built image"), - backend: Optional[BuildBackend] = typer.Option(BuildBackend.DOCKER, help="The backend used for building the image")) -> None: +def build_image( + dockerfile_path: str = typer.Option(..., help="The path to the Dockerfile"), + context_dir: str = typer.Option( + ..., help="The directory containing the set of files accessible to the build" + ), + model_name: Optional[str] = typer.Option( + "cms_model", help="The string representation of the model name" + ), + user_id: Optional[int] = typer.Option(1000, help="The ID for the non-root user"), + group_id: Optional[int] = typer.Option(1000, help="The group ID for the non-root user"), + http_proxy: Optional[str] = typer.Option( + "", help="The string representation of the HTTP proxy" + ), + https_proxy: Optional[str] = typer.Option( + "", help="The string representation of the HTTPS proxy" + ), + no_proxy: Optional[str] = typer.Option( + "localhost,127.0.0.1", help="The string representation of addresses by-passing proxies" + ), + tag: str = typer.Option(None, help="The tag of the built image"), + backend: Optional[BuildBackend] = typer.Option( + BuildBackend.DOCKER, help="The backend used for building the image" + ), +) -> None: assert backend is not None cmd = [ *backend.value.split(), - '-f', dockerfile_path, - '--progress=plain', - '-t', f'{model_name}:{tag}', - '--build-arg', f'CMS_MODEL_NAME={model_name}', - '--build-arg', f'CMS_UID={str(user_id)}', - '--build-arg', f'CMS_GID={str(group_id)}', - '--build-arg', f'HTTP_PROXY={http_proxy}', - '--build-arg', f'HTTPS_PROXY={https_proxy}', - '--build-arg', 
f'NO_PROXY={no_proxy}', + "-f", + dockerfile_path, + "--progress=plain", + "-t", + f"{model_name}:{tag}", + "--build-arg", + f"CMS_MODEL_NAME={model_name}", + "--build-arg", + f"CMS_UID={str(user_id)}", + "--build-arg", + f"CMS_GID={str(group_id)}", + "--build-arg", + f"HTTP_PROXY={http_proxy}", + "--build-arg", + f"HTTPS_PROXY={https_proxy}", + "--build-arg", + f"NO_PROXY={no_proxy}", context_dir, ] - with subprocess.Popen(cmd, - shell=False, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - close_fds=True, - universal_newlines=True, - bufsize=1) as process: + with subprocess.Popen( + cmd, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True, + universal_newlines=True, + bufsize=1, + ) as process: assert process is not None try: while True: @@ -437,7 +633,11 @@ def build_image(dockerfile_path: str = typer.Option(..., help="The path to the D @cmd_app.command("export-openapi-spec") -def generate_api_doc(api_title: str = typer.Option("CogStack Model Serve APIs", help="The string representation of the API title")) -> None: +def generate_api_doc( + api_title: str = typer.Option( + "CogStack Model Serve APIs", help="The string representation of the API title" + ), +) -> None: """ This generates a single API doc for all endpoints """ @@ -465,9 +665,11 @@ def generate_api_doc(api_title: str = typer.Option("CogStack Model Serve APIs", typer.echo(f"OpenAPI doc exported to {doc_name}") -def _get_logger(debug: Optional[bool] = None, - model_type: Optional[ModelType] = None, - model_name: Optional[str] = None) -> logging.Logger: +def _get_logger( + debug: Optional[bool] = None, + model_type: Optional[ModelType] = None, + model_name: Optional[str] = None, +) -> logging.Logger: if debug is not None: get_settings().DEBUG = "true" if debug else "false" if get_settings().DEBUG != "true": @@ -481,6 +683,7 @@ def log_record_factory(*args: Tuple, **kwargs: Dict[str, Any]) -> LogRecord: record.model_type = model_type record.model_name = model_name if model_name is not None else "NULL" return record + logging.setLogRecordFactory(log_record_factory) return logger diff --git a/app/config.py b/app/config.py index 3ec86f8..3b0982b 100644 --- a/app/config.py +++ b/app/config.py @@ -1,38 +1,79 @@ -import os import json +import os + from pydantic import BaseSettings class Settings(BaseSettings): - BASE_MODEL_FILE: str = "model.zip" # the base name of the model file - BASE_MODEL_FULL_PATH: str = "" # the full path to the model file - DEVICE: str = "default" # the device literal, either "default", "cpu[:X]", "cuda[:X]" or "mps[:X]" - INCLUDE_SPAN_TEXT: str = "false" # if "true", include the text of the entity in the NER output - CONCAT_SIMILAR_ENTITIES: str = "true" # if "true", merge adjacent entities of the same type into one span - ENABLE_TRAINING_APIS: str = "false" # if "true", enable the APIs for model training - DISABLE_UNSUPERVISED_TRAINING: str = "false" # if "true", disable the API for unsupervised training - DISABLE_METACAT_TRAINING: str = "true" # if "true", disable the API for metacat training - ENABLE_EVALUATION_APIS: str = "false" # if "true", enable the APIs for evaluating the model being served - ENABLE_PREVIEWS_APIS: str = "false" # if "true", enable the APIs for previewing the NER output - MLFLOW_TRACKING_URI: str = f'file:{os.path.join(os.path.abspath(os.path.dirname(__file__)), "mlruns")}' # the mlflow tracking URI - REDEPLOY_TRAINED_MODEL: str = "false" # if "true", replace the running model with the newly trained one - SKIP_SAVE_MODEL: str = "false" # if 
"true", newly trained models won't be saved but training metrics will be collected - SKIP_SAVE_TRAINING_DATASET: str = "true" # if "true", the dataset used for training won't be saved - PROCESS_RATE_LIMIT: str = "180/minute" # the rate limit on the /process route - PROCESS_BULK_RATE_LIMIT: str = "90/minute" # the rate limit on the /process_bulk route - WS_IDLE_TIMEOUT_SECONDS: int = 60 # the timeout in seconds on the WebSocket connection being idle - TYPE_UNIQUE_ID_WHITELIST: str = "" # the comma-separated TUIs used for filtering and if set to "", all TUIs are whitelisted - AUTH_USER_ENABLED: str = "false" # if "true", enable user authentication on API access - AUTH_JWT_SECRET: str = "" # the JWT secret and will be ignored if AUTH_USER_ENABLED is not "true" - AUTH_ACCESS_TOKEN_EXPIRE_SECONDS: int = 3600 # the seconds after which the JWT will expire - AUTH_DATABASE_URL: str = "sqlite+aiosqlite:///./cms-users.db" # the URL of the authentication database - TRAINING_CONCEPT_ID_WHITELIST: str = "" # the comma-separated concept IDs used for filtering annotations of interest - TRAINING_METRICS_LOGGING_INTERVAL: int = 5 # the number of steps after which training metrics will be collected - TRAINING_SAFE_MODEL_SERIALISATION: str = "false" # if "true", serialise the trained model using safe tensors - TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training - HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model - LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scrapper. Switch this on with caution due to the potentially high number of concepts - DEBUG: str = "false" # if "true", the debug mode is switched on + BASE_MODEL_FILE: str = "model.zip" # the base name of the model file + BASE_MODEL_FULL_PATH: str = "" # the full path to the model file + DEVICE: str = ( + "default" # the device literal, either "default", "cpu[:X]", "cuda[:X]" or "mps[:X]" + ) + INCLUDE_SPAN_TEXT: str = "false" # if "true", include the text of the entity in the NER output + CONCAT_SIMILAR_ENTITIES: str = ( + "true" # if "true", merge adjacent entities of the same type into one span + ) + ENABLE_TRAINING_APIS: str = "false" # if "true", enable the APIs for model training + DISABLE_UNSUPERVISED_TRAINING: str = ( + "false" # if "true", disable the API for unsupervised training + ) + DISABLE_METACAT_TRAINING: str = "true" # if "true", disable the API for metacat training + ENABLE_EVALUATION_APIS: str = ( + "false" # if "true", enable the APIs for evaluating the model being served + ) + ENABLE_PREVIEWS_APIS: str = "false" # if "true", enable the APIs for previewing the NER output + MLFLOW_TRACKING_URI: str = ( + # the mlflow tracking URI + f'file:{os.path.join(os.path.abspath(os.path.dirname(__file__)), "mlruns")}' + ) + REDEPLOY_TRAINED_MODEL: str = ( + "false" # if "true", replace the running model with the newly trained one + ) + SKIP_SAVE_MODEL: str = ( + # if "true", newly trained models won't be saved but training metrics will be collected + "false" + ) + SKIP_SAVE_TRAINING_DATASET: str = ( + "true" # if "true", the dataset used for training won't be saved + ) + PROCESS_RATE_LIMIT: str = "180/minute" # the rate limit on the /process route + PROCESS_BULK_RATE_LIMIT: str = "90/minute" # the rate limit on the /process_bulk route + WS_IDLE_TIMEOUT_SECONDS: int = ( + 60 # the timeout 
in seconds on the WebSocket connection being idle + ) + TYPE_UNIQUE_ID_WHITELIST: str = ( + "" # the comma-separated TUIs used for filtering and if set to "", all TUIs are whitelisted + ) + AUTH_USER_ENABLED: str = "false" # if "true", enable user authentication on API access + AUTH_JWT_SECRET: str = ( + "" # the JWT secret and will be ignored if AUTH_USER_ENABLED is not "true" + ) + AUTH_ACCESS_TOKEN_EXPIRE_SECONDS: int = 3600 # the seconds after which the JWT will expire + AUTH_DATABASE_URL: str = ( + "sqlite+aiosqlite:///./cms-users.db" # the URL of the authentication database + ) + TRAINING_CONCEPT_ID_WHITELIST: str = ( + "" # the comma-separated concept IDs used for filtering annotations of interest + ) + TRAINING_METRICS_LOGGING_INTERVAL: int = ( + 5 # the number of steps after which training metrics will be collected + ) + TRAINING_SAFE_MODEL_SERIALISATION: str = ( + "false" # if "true", serialise the trained model using safe tensors + ) + TRAINING_CACHE_DIR: str = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "cms_cache" + ) # the directory to cache the intermediate files created during training + HF_PIPELINE_AGGREGATION_STRATEGY: str = ( + "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model + ) + LOG_PER_CONCEPT_ACCURACIES: str = ( + # if "true", per-concept accuracies will be exposed to the metrics scrapper. + # Switch this on with caution due to the potentially high number of concepts + "false" + ) + DEBUG: str = "false" # if "true", the debug mode is switched on class Config: env_file = os.path.join(os.path.dirname(__file__), "envs", ".env") diff --git a/app/data/anno_dataset.py b/app/data/anno_dataset.py index 2176f01..5286069 100644 --- a/app/data/anno_dataset.py +++ b/app/data/anno_dataset.py @@ -1,7 +1,9 @@ -import datasets import json from pathlib import Path -from typing import List, Iterable, Tuple, Dict +from typing import Dict, Iterable, List, Tuple + +import datasets + from utils import filter_by_concept_ids @@ -10,7 +12,6 @@ class AnnotationDatasetConfig(datasets.BuilderConfig): class AnnotationDatasetBuilder(datasets.GeneratorBasedBuilder): - BUILDER_CONFIGS = [ AnnotationDatasetConfig( name="json_annotation", @@ -21,27 +22,39 @@ class AnnotationDatasetBuilder(datasets.GeneratorBasedBuilder): def _info(self) -> datasets.DatasetInfo: return datasets.DatasetInfo( - description="Annotation Dataset. This is a dataset containing flattened MedCAT Trainer export", + description=( + "Annotation Dataset. 
This is a dataset containing flattened MedCAT Trainer export" + ), features=datasets.Features( { "project": datasets.Value("string"), - "name":datasets.Value("string"), + "name": datasets.Value("string"), "text": datasets.Value("string"), - "starts": datasets.Value("string"), # Mlflow ColSpec schema does not support HF Dataset Sequence - "ends": datasets.Value("string"), # Mlflow ColSpec schema does not support HF Dataset Sequence - "labels": datasets.Value("string"), # Mlflow ColSpec schema does not support HF Dataset Sequence + "starts": datasets.Value( + "string" + ), # Mlflow ColSpec schema does not support HF Dataset Sequence + "ends": datasets.Value( + "string" + ), # Mlflow ColSpec schema does not support HF Dataset Sequence + "labels": datasets.Value( + "string" + ), # Mlflow ColSpec schema does not support HF Dataset Sequence } - ) + ), ) def _split_generators(self, _: datasets.DownloadManager) -> List[datasets.SplitGenerator]: return [ - datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": self.config.data_files["annotations"]}) + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={"filepaths": self.config.data_files["annotations"]}, + ) ] def _generate_examples(self, filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]: return generate_examples(filepaths) + def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]: id_ = 1 for filepath in filepaths: @@ -57,12 +70,15 @@ def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]: starts.append(str(annotation["start"])) ends.append(str(annotation["end"])) labels.append(annotation["cui"]) - yield str(id_), { - "project": project.get("name"), - "name": document.get("name"), - "text": document.get("text"), - "starts": ",".join(starts), - "ends": ",".join(ends), - "labels": ",".join(labels), - } + yield ( + str(id_), + { + "project": project.get("name"), + "name": document.get("name"), + "text": document.get("text"), + "starts": ",".join(starts), + "ends": ",".join(ends), + "labels": ",".join(labels), + }, + ) id_ += 1 diff --git a/app/data/doc_dataset.py b/app/data/doc_dataset.py index cc48842..b2703df 100644 --- a/app/data/doc_dataset.py +++ b/app/data/doc_dataset.py @@ -1,7 +1,8 @@ +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + import datasets import ijson -from pathlib import Path -from typing import List, Iterable, Tuple, Dict class TextDatasetConfig(datasets.BuilderConfig): @@ -9,7 +10,6 @@ class TextDatasetConfig(datasets.BuilderConfig): class TextDatasetBuilder(datasets.GeneratorBasedBuilder): - BUILDER_CONFIGS = [ TextDatasetConfig( name="free_text", @@ -20,18 +20,24 @@ class TextDatasetBuilder(datasets.GeneratorBasedBuilder): def _info(self) -> datasets.DatasetInfo: return datasets.DatasetInfo( - description="Free text Dataset. This is a dataset containing document records each of which has 'doc_name' and 'text' attributes", + description=( + "Free text Dataset. 
This is a dataset containing document records each of which has" + " 'doc_name' and 'text' attributes" + ), features=datasets.Features( { "name": datasets.Value("string"), "text": datasets.Value("string"), } - ) + ), ) def _split_generators(self, _: datasets.DownloadManager) -> List[datasets.SplitGenerator]: return [ - datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": self.config.data_files["documents"]}) + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={"filepaths": self.config.data_files["documents"]}, + ) ] def _generate_examples(self, filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]: diff --git a/app/domain.py b/app/domain.py index e20bc9b..f7bf58a 100644 --- a/app/domain.py +++ b/app/domain.py @@ -1,9 +1,9 @@ from enum import Enum -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional from fastapi import HTTPException -from starlette.status import HTTP_400_BAD_REQUEST from pydantic import BaseModel, Field, root_validator +from starlette.status import HTTP_400_BAD_REQUEST class ModelType(str, Enum): @@ -83,21 +83,34 @@ class HfTransformerBackbone(Enum): class Annotation(BaseModel): - doc_name: Optional[str] = Field(description="The name of the document to which the annotation belongs") + doc_name: Optional[str] = Field( + description="The name of the document to which the annotation belongs" + ) start: int = Field(description="The start index of the annotation span") end: int = Field(description="The first index after the annotation span") label_name: str = Field(description="The pretty name of the annotation concept") label_id: str = Field(description="The code of the annotation concept") - categories: Optional[List[str]] = Field(default=None, description="The categories to which the annotation concept belongs") - accuracy: Optional[float] = Field(default=None, description="The confidence score of the annotation") - text: Optional[str] = Field(default=None, description="The string literal of the annotation span") + categories: Optional[List[str]] = Field( + default=None, description="The categories to which the annotation concept belongs" + ) + accuracy: Optional[float] = Field( + default=None, description="The confidence score of the annotation" + ) + text: Optional[str] = Field( + default=None, description="The string literal of the annotation span" + ) meta_anns: Optional[Dict] = Field(default=None, description="The meta annotations") - athena_ids: Optional[List[Dict]] = Field(default=None, description="The OHDSI Athena concept IDs") + athena_ids: Optional[List[Dict]] = Field( + default=None, description="The OHDSI Athena concept IDs" + ) @root_validator() def _validate(cls, values: Dict[str, Any]) -> Dict[str, Any]: if values["start"] >= values["end"]: - raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="The start index should be lower than the end index") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="The start index should be lower than the end index", + ) return values @@ -136,7 +149,10 @@ class Entity(BaseModel): @root_validator() def _validate(cls, values: Dict[str, Any]) -> Dict[str, Any]: if values["start"] >= values["end"]: - raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="The start index should be lower than the end index") + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="The start index should be lower than the end index", + ) return values diff --git a/app/exception.py b/app/exception.py index 8d7bc44..cba428a 100644 
--- a/app/exception.py +++ b/app/exception.py @@ -1,22 +1,22 @@ class StartTrainingException(Exception): - """ An exception raised due to training not started""" + """An exception raised due to training not started""" class TrainingFailedException(Exception): - """ An exception raised due to failure on training""" + """An exception raised due to failure on training""" class ConfigurationException(Exception): - """ An exception raised due to configuration errors""" + """An exception raised due to configuration errors""" class AnnotationException(Exception): - """ An exception raised due to annotation errors""" + """An exception raised due to annotation errors""" class ManagedModelException(Exception): - """ An exception raised due to erroneous models""" + """An exception raised due to erroneous models""" class ClientException(Exception): - """ An exception raised due to generic client errors""" + """An exception raised due to generic client errors""" diff --git a/app/logging.ini b/app/logging.ini index 1287baa..62850a2 100644 --- a/app/logging.ini +++ b/app/logging.ini @@ -33,4 +33,4 @@ args=(sys.stdout,) format=%(asctime)s %(levelname)-6s [%(name)s$%(funcName)s] %(message)s [formatter_detailedFormatter] -format=%(asctime)s %(levelname)-6s [%(name)s$%(funcName)s] %(message)s call_trace=%(pathname)s:%(lineno)-4d \ No newline at end of file +format=%(asctime)s %(levelname)-6s [%(name)s$%(funcName)s] %(message)s call_trace=%(pathname)s:%(lineno)-4d diff --git a/app/management/README.md b/app/management/README.md index d0bdbd4..9ecf8ac 100644 --- a/app/management/README.md +++ b/app/management/README.md @@ -7,4 +7,3 @@ To enable user authentication and authorisation in MLflow, you will need to prov * MLFLOW_AUTH_CONFIG_PATH=/opt/auth/basic_auth.ini Additionally, ensure you set the appropriate values in the default [basic auth file](./../../docker/mlflow/server/auth/basic_auth.ini) before buiding the image and firing up a container based off it. For detailed information on authentication, please refer to the [official documentation](https://mlflow.org/docs/2.6.0/auth/index.html). 
- diff --git a/app/management/log_captor.py b/app/management/log_captor.py index b5c34f2..6d6d7e0 100644 --- a/app/management/log_captor.py +++ b/app/management/log_captor.py @@ -3,7 +3,6 @@ @final class LogCaptor(object): - def __init__(self, processor: Callable[[str], None]): self.buffer = "" self.log_processor = processor @@ -15,7 +14,7 @@ def write(self, buffer: str) -> None: except ValueError: self.buffer += buffer break - log = self.buffer + buffer[:newline_idx+1] + log = self.buffer + buffer[: newline_idx + 1] self.buffer = "" - buffer = buffer[newline_idx+1:] + buffer = buffer[newline_idx + 1 :] self.log_processor(log) diff --git a/app/management/model_manager.py b/app/management/model_manager.py index d69c181..a4fa156 100644 --- a/app/management/model_manager.py +++ b/app/management/model_manager.py @@ -2,46 +2,54 @@ import os import shutil import tempfile +from typing import Any, Dict, Iterator, List, Optional, Type, Union, final + import mlflow -import toml import pandas as pd -from typing import Type, Optional, Dict, Any, List, Iterator, final, Union -from pandas import DataFrame -from mlflow.pyfunc import PythonModel, PythonModelContext -from mlflow.models.signature import ModelSignature -from mlflow.types import DataType, Schema, ColSpec +import toml from mlflow.models.model import ModelInfo -from model_services.base import AbstractModelService +from mlflow.models.signature import ModelSignature +from mlflow.pyfunc import PythonModel, PythonModelContext +from mlflow.types import ColSpec, DataType, Schema +from pandas import DataFrame + from config import Settings from exception import ManagedModelException from utils import func_deprecated +from model_services.base import AbstractModelService + @final class ModelManager(PythonModel): - - input_schema = Schema([ - ColSpec(DataType.string, "name", optional=True), - ColSpec(DataType.string, "text"), - ]) - - output_schema = Schema([ - ColSpec(DataType.string, "doc_name"), - ColSpec(DataType.integer, "start"), - ColSpec(DataType.integer, "end"), - ColSpec(DataType.string, "label_name"), - ColSpec(DataType.string, "label_id"), - ColSpec(DataType.string, "categories", optional=True), - ColSpec(DataType.float, "accuracy", optional=True), - ColSpec(DataType.string, "text", optional=True), - ColSpec(DataType.string, "meta_anns", optional=True) - ]) + input_schema = Schema( + [ + ColSpec(DataType.string, "name", optional=True), + ColSpec(DataType.string, "text"), + ] + ) + + output_schema = Schema( + [ + ColSpec(DataType.string, "doc_name"), + ColSpec(DataType.integer, "start"), + ColSpec(DataType.integer, "end"), + ColSpec(DataType.string, "label_name"), + ColSpec(DataType.string, "label_id"), + ColSpec(DataType.string, "categories", optional=True), + ColSpec(DataType.float, "accuracy", optional=True), + ColSpec(DataType.string, "text", optional=True), + ColSpec(DataType.string, "meta_anns", optional=True), + ] + ) def __init__(self, model_service_type: Type, config: Settings) -> None: self._model_service_type = model_service_type self._config = config self._model_service = None - self._model_signature = ModelSignature(inputs=ModelManager.input_schema, outputs=ModelManager.output_schema, params=None) + self._model_signature = ModelSignature( + inputs=ModelManager.input_schema, outputs=ModelManager.output_schema, params=None + ) @property def model_service(self) -> AbstractModelService: @@ -56,8 +64,7 @@ def model_signature(self) -> ModelSignature: return self._model_signature @staticmethod - def 
retrieve_python_model_from_uri(mlflow_model_uri: str, - config: Settings) -> PythonModel: + def retrieve_python_model_from_uri(mlflow_model_uri: str, config: Settings) -> PythonModel: mlflow.set_tracking_uri(config.MLFLOW_TRACKING_URI) pyfunc_model = mlflow.pyfunc.load_model(model_uri=mlflow_model_uri) # In case the load_model overwrote the tracking URI @@ -65,21 +72,25 @@ def retrieve_python_model_from_uri(mlflow_model_uri: str, return pyfunc_model._model_impl.python_model @staticmethod - def retrieve_model_service_from_uri(mlflow_model_uri: str, - config: Settings, - downloaded_model_path: Optional[str] = None) -> AbstractModelService: + def retrieve_model_service_from_uri( + mlflow_model_uri: str, config: Settings, downloaded_model_path: Optional[str] = None + ) -> AbstractModelService: model_manager = ModelManager.retrieve_python_model_from_uri(mlflow_model_uri, config) model_service = model_manager.model_service config.BASE_MODEL_FULL_PATH = mlflow_model_uri model_service._config = config if downloaded_model_path: - ModelManager.download_model_package(os.path.join(mlflow_model_uri, "artifacts"), downloaded_model_path) + ModelManager.download_model_package( + os.path.join(mlflow_model_uri, "artifacts"), downloaded_model_path + ) return model_service @staticmethod def download_model_package(model_artifact_uri: str, dst_file_path: str) -> Optional[str]: with tempfile.TemporaryDirectory() as dir_downloaded: - mlflow.artifacts.download_artifacts(artifact_uri=model_artifact_uri, dst_path=dir_downloaded) + mlflow.artifacts.download_artifacts( + artifact_uri=model_artifact_uri, dst_path=dir_downloaded + ) # This assumes the model package is the sole zip file in the artifacts directory file_path = None for file_path in glob.glob(os.path.join(dir_downloaded, "**", "*.zip")): @@ -88,12 +99,14 @@ def download_model_package(model_artifact_uri: str, dst_file_path: str) -> Optio shutil.copy(file_path, dst_file_path) return dst_file_path else: - raise ManagedModelException(f"Cannot find the model .zip file inside artifacts downloaded from {model_artifact_uri}") - - def log_model(self, - model_name: str, - model_path: str, - registered_model_name: Optional[str] = None) -> ModelInfo: + raise ManagedModelException( + "Cannot find the model .zip file inside artifacts downloaded from" + f" {model_artifact_uri}" + ) + + def log_model( + self, model_name: str, model_path: str, registered_model_name: Optional[str] = None + ) -> ModelInfo: return mlflow.pyfunc.log_model( artifact_path=model_name, python_model=self, @@ -116,13 +129,22 @@ def save_model(self, local_dir: str, model_path: str) -> None: def load_context(self, context: PythonModelContext) -> None: artifact_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) - model_service = self._model_service_type(self._config, - model_parent_dir=os.path.join(artifact_root, os.path.split(context.artifacts["model_path"])[0]), - base_model_file=os.path.split(context.artifacts["model_path"])[1]) + model_service = self._model_service_type( + self._config, + model_parent_dir=os.path.join( + artifact_root, os.path.split(context.artifacts["model_path"])[0] + ), + base_model_file=os.path.split(context.artifacts["model_path"])[1], + ) model_service.init_model() self._model_service = model_service - def predict(self, context: PythonModelContext, model_input: DataFrame, params: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + def predict( + self, + context: PythonModelContext, + model_input: DataFrame, + params: Optional[Dict[str, Any]] = None, + 
) -> pd.DataFrame: output = [] for idx, row in model_input.iterrows(): annotations = self._model_service.annotate(row["text"]) # type: ignore @@ -133,7 +155,12 @@ def predict(self, context: PythonModelContext, model_input: DataFrame, params: O df = df.iloc[:, df.columns.isin(ModelManager.output_schema.input_names())] return df - def predict_stream(self, context: PythonModelContext, model_input: DataFrame, params: Optional[Dict[str, Any]] = None) -> Iterator[Dict[str, Any]]: + def predict_stream( + self, + context: PythonModelContext, + model_input: DataFrame, + params: Optional[Dict[str, Any]] = None, + ) -> Iterator[Dict[str, Any]]: for idx, row in model_input.iterrows(): annotations = self._model_service.annotate(row["text"]) # type: ignore output = [] @@ -169,11 +196,22 @@ def _get_pip_requirements() -> str: @staticmethod def _get_pip_requirements_from_file() -> Union[List[str], str]: - if os.path.exists(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "pyproject.toml"))): - with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "pyproject.toml")), "r") as file: + if os.path.exists( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "pyproject.toml")) + ): + with open( + os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "pyproject.toml") + ), + "r", + ) as file: pyproject = toml.load(file) return pyproject.get("project", {}).get("dependencies", []) - elif os.path.exists(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "requirements.txt"))): - return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "requirements.txt")) + elif os.path.exists( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "requirements.txt")) + ): + return os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "requirements.txt") + ) else: raise ManagedModelException("Cannot find pip requirements.") diff --git a/app/management/prometheus_metrics.py b/app/management/prometheus_metrics.py index a3993d7..ef020ac 100644 --- a/app/management/prometheus_metrics.py +++ b/app/management/prometheus_metrics.py @@ -1,7 +1,23 @@ -from prometheus_client import Histogram, Gauge +from prometheus_client import Gauge, Histogram -cms_doc_annotations = Histogram("cms_doc_annotations", "Number of annotations extracted from a document", ["handler"]) -cms_avg_anno_acc_per_doc = Gauge("cms_avg_anno_acc_per_doc", "The average accuracy of annotations extracted from a document", ["handler"]) -cms_avg_anno_acc_per_concept = Gauge("cms_avg_anno_acc_per_concept", "The average accuracy of annotations for a specific concept", ["handler", "concept"]) -cms_avg_meta_anno_conf_per_doc = Gauge("cms_avg_meta_anno_conf_per_doc", "The average confidence of meta annotations extracted from a document", ["handler"]) -cms_bulk_processed_docs = Histogram("cms_bulk_processed_docs", "Number of bulk-processed documents", ["handler"]) +cms_doc_annotations = Histogram( + "cms_doc_annotations", "Number of annotations extracted from a document", ["handler"] +) +cms_avg_anno_acc_per_doc = Gauge( + "cms_avg_anno_acc_per_doc", + "The average accuracy of annotations extracted from a document", + ["handler"], +) +cms_avg_anno_acc_per_concept = Gauge( + "cms_avg_anno_acc_per_concept", + "The average accuracy of annotations for a specific concept", + ["handler", "concept"], +) +cms_avg_meta_anno_conf_per_doc = Gauge( + "cms_avg_meta_anno_conf_per_doc", + "The average confidence of meta annotations extracted from a document", + ["handler"], +) 
+cms_bulk_processed_docs = Histogram( + "cms_bulk_processed_docs", "Number of bulk-processed documents", ["handler"] +) diff --git a/app/management/tracker_client.py b/app/management/tracker_client.py index b0a259c..4b0b823 100644 --- a/app/management/tracker_client.py +++ b/app/management/tracker_client.py @@ -1,18 +1,21 @@ +import json +import logging import os import socket -import mlflow import tempfile -import json -import logging +from typing import Dict, List, Optional, Tuple, Union, final + import datasets +import mlflow import pandas as pd -from typing import Dict, Tuple, List, Optional, Union, final -from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME -from mlflow.entities import RunStatus, Metric +from mlflow.entities import Metric, RunStatus from mlflow.tracking import MlflowClient -from management.model_manager import ModelManager +from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME + from exception import StartTrainingException +from management.model_manager import ModelManager + logger = logging.getLogger("cms") urllib3_logger = logging.getLogger("urllib3") urllib3_logger.setLevel(logging.CRITICAL) @@ -20,20 +23,21 @@ @final class TrackerClient(object): - def __init__(self, mlflow_tracking_uri: str) -> None: mlflow.set_tracking_uri(mlflow_tracking_uri) self.mlflow_client = MlflowClient(mlflow_tracking_uri) @staticmethod - def start_tracking(model_name: str, - input_file_name: str, - base_model_original: str, - training_type: str, - training_params: Dict, - run_name: str, - log_frequency: int, - description: Optional[str] = None) -> Tuple[str, str]: + def start_tracking( + model_name: str, + input_file_name: str, + base_model_original: str, + training_type: str, + training_params: Dict, + run_name: str, + log_frequency: int, + description: Optional[str] = None, + ) -> Tuple[str, str]: experiment_name = TrackerClient.get_experiment_name(model_name, training_type) experiment_id = TrackerClient._get_experiment_id(experiment_name) try: @@ -41,16 +45,18 @@ def start_tracking(model_name: str, except Exception: logger.exception("Cannot start a new training") raise StartTrainingException("Cannot start a new training") - mlflow.set_tags({ - MLFLOW_SOURCE_NAME: socket.gethostname(), - "mlflow.runName": run_name, - "mlflow.note.content": description or "", - "training.mlflow.run_id": active_run.info.run_id, - "training.input_data.filename": input_file_name, - "training.base_model.origin": base_model_original, - "training.is.tracked": "True", - "training.metrics.log_frequency": log_frequency, - }) + mlflow.set_tags( + { + MLFLOW_SOURCE_NAME: socket.gethostname(), + "mlflow.runName": run_name, + "mlflow.note.content": description or "", + "training.mlflow.run_id": active_run.info.run_id, + "training.input_data.filename": input_file_name, + "training.base_model.origin": base_model_original, + "training.is.tracked": "True", + "training.metrics.log_frequency": log_frequency, + } + ) mlflow.log_params(training_params) return experiment_id, active_run.info.run_id @@ -76,28 +82,25 @@ def send_hf_metrics_logs(logs: Dict, step: int) -> None: mlflow.log_metrics(logs, step) @staticmethod - def save_model_local(local_dir: str, - filepath: str, - model_manager: ModelManager) -> None: + def save_model_local(local_dir: str, filepath: str, model_manager: ModelManager) -> None: model_manager.save_model(local_dir, filepath) @staticmethod - def save_model_artifact(filepath: str, - model_name: str) -> None: + def save_model_artifact(filepath: str, model_name: str) -> None: model_name = 
model_name.replace(" ", "_") mlflow.log_artifact(filepath, artifact_path=os.path.join(model_name, "artifacts")) @staticmethod - def save_raw_artifact(filepath: str, - model_name: str) -> None: + def save_raw_artifact(filepath: str, model_name: str) -> None: model_name = model_name.replace(" ", "_") mlflow.log_artifact(filepath, artifact_path=os.path.join(model_name, "artifacts", "raw")) @staticmethod - def save_processed_artifact(filepath: str, - model_name: str) -> None: + def save_processed_artifact(filepath: str, model_name: str) -> None: model_name = model_name.replace(" ", "_") - mlflow.log_artifact(filepath, artifact_path=os.path.join(model_name, "artifacts", "processed")) + mlflow.log_artifact( + filepath, artifact_path=os.path.join(model_name, "artifacts", "processed") + ) @staticmethod def save_dataframe_as_csv(file_name: str, data_frame: pd.DataFrame, model_name: str) -> None: @@ -125,7 +128,9 @@ def save_plot(file_name: str, model_name: str) -> None: @staticmethod def save_table_dict(table_dict: Dict, model_name: str, file_name: str) -> None: model_name = model_name.replace(" ", "_") - mlflow.log_table(data=table_dict, artifact_file=os.path.join(model_name, "tables", file_name)) + mlflow.log_table( + data=table_dict, artifact_file=os.path.join(model_name, "tables", file_name) + ) @staticmethod def save_train_dataset(dataset: datasets.Dataset) -> None: @@ -165,14 +170,16 @@ def log_model_config(config: Dict[str, str]) -> None: mlflow.log_params(config) @staticmethod - def save_pretrained_model(model_name: str, - model_path: str, - model_manager: ModelManager, - training_type: Optional[str] = "", - run_name: Optional[str] = "", - model_config: Optional[Dict] = None, - model_metrics: Optional[List[Dict]] = None, - model_tags: Optional[Dict] = None, ) -> None: + def save_pretrained_model( + model_name: str, + model_path: str, + model_manager: ModelManager, + training_type: Optional[str] = "", + run_name: Optional[str] = "", + model_config: Optional[Dict] = None, + model_metrics: Optional[List[Dict]] = None, + model_tags: Optional[Dict] = None, + ) -> None: experiment_name = TrackerClient.get_experiment_name(model_name, training_type) experiment_id = TrackerClient._get_experiment_id(experiment_name) active_run = mlflow.start_run(experiment_id=experiment_id) @@ -205,9 +212,15 @@ def save_pretrained_model(model_name: str, @staticmethod def get_experiment_name(model_name: str, training_type: Optional[str] = "") -> str: - return f"{model_name} {training_type}".replace(" ", "_") if training_type else model_name.replace(" ", "_") - - def send_batched_model_stats(self, aggregated_metrics: List[Dict], run_id: str, batch_size: int = 1000) -> None: + return ( + f"{model_name} {training_type}".replace(" ", "_") + if training_type + else model_name.replace(" ", "_") + ) + + def send_batched_model_stats( + self, aggregated_metrics: List[Dict], run_id: str, batch_size: int = 1000 + ) -> None: if batch_size <= 0: return batch = [] @@ -220,11 +233,13 @@ def send_batched_model_stats(self, aggregated_metrics: List[Dict], run_id: str, if batch: self.mlflow_client.log_batch(run_id=run_id, metrics=batch) - def save_model(self, - filepath: str, - model_name: str, - model_manager: ModelManager, - validation_status: str = "pending") -> str: + def save_model( + self, + filepath: str, + model_name: str, + model_manager: ModelManager, + validation_status: str = "pending", + ) -> str: model_name = model_name.replace(" ", "_") mlflow.set_tag("training.output.package", os.path.basename(filepath)) @@ -232,10 
+247,12 @@ def save_model(self, if not mlflow.get_tracking_uri().startswith("file:/"): model_manager.log_model(model_name, filepath, model_name) versions = self.mlflow_client.search_model_versions(f"name='{model_name}'") - self.mlflow_client.set_model_version_tag(name=model_name, - version=versions[0].version, - key="validation_status", - value=validation_status) + self.mlflow_client.set_model_version_tag( + name=model_name, + version=versions[0].version, + key="validation_status", + value=validation_status, + ) else: model_manager.log_model(model_name, filepath) @@ -244,4 +261,8 @@ def save_model(self, @staticmethod def _get_experiment_id(experiment_name: str) -> str: experiment = mlflow.get_experiment_by_name(experiment_name) - return mlflow.create_experiment(name=experiment_name) if experiment is None else experiment.experiment_id + return ( + mlflow.create_experiment(name=experiment_name) + if experiment is None + else experiment.experiment_id + ) diff --git a/app/model_services/base.py b/app/model_services/base.py index fceb9a8..fb8d7eb 100644 --- a/app/model_services/base.py +++ b/app/model_services/base.py @@ -1,12 +1,12 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, List, Iterable, Tuple, Dict, final +from typing import Any, Dict, Iterable, List, Tuple, final + from config import Settings from domain import ModelCard class AbstractModelService(ABC): - @abstractmethod def __init__(self, config: Settings, *args: Tuple, **kwargs: Dict[str, Any]) -> None: self._config = config diff --git a/app/model_services/huggingface_ner_model.py b/app/model_services/huggingface_ner_model.py index 2c8b0d4..c6f3f93 100644 --- a/app/model_services/huggingface_ner_model.py +++ b/app/model_services/huggingface_ner_model.py @@ -1,10 +1,10 @@ -import os import logging +import os import zipfile -import pandas as pd - from functools import partial -from typing import Dict, List, Optional, Tuple, Any, TextIO +from typing import Any, Dict, List, Optional, TextIO, Tuple + +import pandas as pd from transformers import ( AutoModelForTokenClassification, AutoTokenizer, @@ -13,35 +13,49 @@ pipeline, ) from transformers.pipelines import Pipeline -from exception import ConfigurationException -from model_services.base import AbstractModelService -from trainers.huggingface_ner_trainer import HuggingFaceNerUnsupervisedTrainer, HuggingFaceNerSupervisedTrainer -from domain import ModelCard, ModelType + from config import Settings -from utils import get_settings, non_default_device_is_available, get_hf_pipeline_device_id +from domain import ModelCard, ModelType +from exception import ConfigurationException +from utils import get_hf_pipeline_device_id, get_settings, non_default_device_is_available +from model_services.base import AbstractModelService +from trainers.huggingface_ner_trainer import ( + HuggingFaceNerSupervisedTrainer, + HuggingFaceNerUnsupervisedTrainer, +) logger = logging.getLogger("cms") class HuggingFaceNerModel(AbstractModelService): - - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: self._config = config - self._model_parent_dir = model_parent_dir or 
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model")) - self._model_pack_path = os.path.join(self._model_parent_dir, config.BASE_MODEL_FILE if base_model_file is None else base_model_file) - self._enable_trainer = enable_trainer if enable_trainer is not None else config.ENABLE_TRAINING_APIS == "true" + self._model_parent_dir = model_parent_dir or os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "model") + ) + self._model_pack_path = os.path.join( + self._model_parent_dir, + config.BASE_MODEL_FILE if base_model_file is None else base_model_file, + ) + self._enable_trainer = ( + enable_trainer if enable_trainer is not None else config.ENABLE_TRAINING_APIS == "true" + ) self._supervised_trainer = None self._unsupervised_trainer = None self._model: PreTrainedModel = None self._tokenizer: PreTrainedTokenizerBase = None self._ner_pipeline: Pipeline = None - self._whitelisted_tuis = set([tui.strip() for tui in config.TYPE_UNIQUE_ID_WHITELIST.split(",")]) + self._whitelisted_tuis = set( + [tui.strip() for tui in config.TYPE_UNIQUE_ID_WHITELIST.split(",")] + ) self.model_name = model_name or "Hugging Face NER model" @property @@ -73,30 +87,45 @@ def api_version(self) -> str: return "0.0.1" @classmethod - def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase) -> "HuggingFaceNerModel": + def from_model( + cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase + ) -> "HuggingFaceNerModel": model_service = cls(get_settings(), enable_trainer=False) model_service.model = model model_service.tokenizer = tokenizer - _pipeline = partial(pipeline, - task="ner", - model=model_service.model, - tokenizer=model_service.tokenizer, - stride=10, - aggregation_strategy=get_settings().HF_PIPELINE_AGGREGATION_STRATEGY) + _pipeline = partial( + pipeline, + task="ner", + model=model_service.model, + tokenizer=model_service.tokenizer, + stride=10, + aggregation_strategy=get_settings().HF_PIPELINE_AGGREGATION_STRATEGY, + ) if non_default_device_is_available(get_settings().DEVICE): - model_service._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(get_settings().DEVICE)) + model_service._ner_pipeline = _pipeline( + device=get_hf_pipeline_device_id(get_settings().DEVICE) + ) else: model_service._ner_pipeline = _pipeline() return model_service @staticmethod - def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: - model_path = os.path.join(os.path.dirname(model_file_path), os.path.basename(model_file_path).split(".")[0]) + def load_model( + model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any] + ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + model_path = os.path.join( + os.path.dirname(model_file_path), os.path.basename(model_file_path).split(".")[0] + ) with zipfile.ZipFile(model_file_path, "r") as f: f.extractall(model_path) try: model = AutoModelForTokenClassification.from_pretrained(model_path) - tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=model.config.max_position_embeddings, add_special_tokens=False, do_lower_case=False) + tokenizer = AutoTokenizer.from_pretrained( + model_path, + model_max_length=model.config.max_position_embeddings, + add_special_tokens=False, + do_lower_case=False, + ) logger.info("Model package loaded from %s", os.path.normpath(model_file_path)) return model, tokenizer except ValueError as e: @@ -104,21 +133,29 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> raise 
ConfigurationException("Model package is not valid or not supported") def init_model(self) -> None: - if all([hasattr(self, "_model"), + if all( + [ + hasattr(self, "_model"), hasattr(self, "_tokenizer"), isinstance(self._model, PreTrainedModel), - isinstance(self._tokenizer, PreTrainedTokenizerBase)]): + isinstance(self._tokenizer, PreTrainedTokenizerBase), + ] + ): logger.warning("Model service is already initialised and can be initialised only once") else: self._model, self._tokenizer = self.load_model(self._model_pack_path) - _pipeline = partial(pipeline, - task="ner", - model=self._model, - tokenizer=self._tokenizer, - stride=10, - aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY) + _pipeline = partial( + pipeline, + task="ner", + model=self._model, + tokenizer=self._tokenizer, + stride=10, + aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY, + ) if non_default_device_is_available(get_settings().DEVICE): - self._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(get_settings().DEVICE)) + self._ner_pipeline = _pipeline( + device=get_hf_pipeline_device_id(get_settings().DEVICE) + ) else: self._ner_pipeline = _pipeline() if self._enable_trainer: @@ -126,10 +163,12 @@ def init_model(self) -> None: self._unsupervised_trainer = HuggingFaceNerUnsupervisedTrainer(self) def info(self) -> ModelCard: - return ModelCard(model_description=self.model_name, - model_type=ModelType.HUGGINGFACE_NER, - api_version=self.api_version, - model_card=self._model.config.to_dict()) + return ModelCard( + model_description=self.model_name, + model_type=ModelType.HUGGINGFACE_NER, + api_version=self.api_version, + model_card=self._model.config.to_dict(), + ) def annotate(self, text: str) -> Dict: entities = self._ner_pipeline(text) @@ -145,32 +184,58 @@ def annotate(self, text: str) -> Dict: return records def batch_annotate(self, texts: List[str]) -> List[Dict]: - raise NotImplementedError("Batch annotation is not yet implemented for Hugging Face NER models") - - def train_supervised(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + raise NotImplementedError( + "Batch annotation is not yet implemented for Hugging Face NER models" + ) + + def train_supervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") - return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) - - def train_unsupervised(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + return self._supervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) + + def train_unsupervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + 
training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._unsupervised_trainer is None: raise ConfigurationException("The unsupervised trainer is not enabled") - return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) + return self._unsupervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) diff --git a/app/model_services/medcat_model.py b/app/model_services/medcat_model.py index 214414a..8d8d1fb 100644 --- a/app/model_services/medcat_model.py +++ b/app/model_services/medcat_model.py @@ -1,38 +1,50 @@ -import os import logging -import pandas as pd - +import os from multiprocessing import cpu_count -from typing import Dict, List, Optional, TextIO, Tuple, Any +from typing import Any, Dict, List, Optional, TextIO, Tuple + +import pandas as pd from medcat.cat import CAT + +from config import Settings +from domain import ModelCard +from exception import ConfigurationException +from utils import TYPE_ID_TO_NAME_PATCH, get_settings, non_default_device_is_available + from model_services.base import AbstractModelService from trainers.medcat_trainer import MedcatSupervisedTrainer, MedcatUnsupervisedTrainer from trainers.metacat_trainer import MetacatTrainer -from domain import ModelCard -from config import Settings -from utils import get_settings, TYPE_ID_TO_NAME_PATCH, non_default_device_is_available -from exception import ConfigurationException logger = logging.getLogger("cms") class MedCATModel(AbstractModelService): - - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: self._model: CAT = None self._config = config - self._model_parent_dir = model_parent_dir or os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model")) - self._model_pack_path = os.path.join(self._model_parent_dir, config.BASE_MODEL_FILE if base_model_file is None else base_model_file) - self._enable_trainer = enable_trainer if enable_trainer is not None else config.ENABLE_TRAINING_APIS == "true" + self._model_parent_dir = model_parent_dir or os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "model") + ) + self._model_pack_path = os.path.join( + self._model_parent_dir, + config.BASE_MODEL_FILE if base_model_file is None else base_model_file, + ) + self._enable_trainer = ( + enable_trainer if enable_trainer is not None else config.ENABLE_TRAINING_APIS == "true" + ) self._supervised_trainer = None self._unsupervised_trainer = None self._metacat_trainer = None - self._whitelisted_tuis = set([tui.strip() for tui in config.TYPE_UNIQUE_ID_WHITELIST.split(",")]) + self._whitelisted_tuis = set( + [tui.strip() for tui in config.TYPE_UNIQUE_ID_WHITELIST.split(",")] + ) self.model_name = model_name or "MedCAT model" @property @@ -73,14 +85,19 @@ def _retrieve_meta_annotations(df: pd.DataFrame) -> pd.DataFrame: meta_annotations.append(meta_dict) df["new_meta_anns"] = 
meta_annotations - return pd.concat([df.drop(["new_meta_anns"], axis=1), df["new_meta_anns"].apply(pd.Series)], axis=1) + return pd.concat( + [df.drop(["new_meta_anns"], axis=1), df["new_meta_anns"].apply(pd.Series)], axis=1 + ) def init_model(self) -> None: if hasattr(self, "_model") and isinstance(self._model, CAT): logger.warning("Model service is already initialised and can be initialised only once") else: if non_default_device_is_available(get_settings().DEVICE): - self._model = self.load_model(self._model_pack_path, meta_cat_config_dict={"general": {"device": get_settings().DEVICE}}) + self._model = self.load_model( + self._model_pack_path, + meta_cat_config_dict={"general": {"device": get_settings().DEVICE}}, + ) self._model.config.general["device"] = get_settings().DEVICE else: self._model = self.load_model(self._model_pack_path) @@ -94,63 +111,102 @@ def info(self) -> ModelCard: raise NotImplementedError def annotate(self, text: str) -> Dict: - doc = self.model.get_entities(text, - addl_info=["cui2icd10", "cui2ontologies", "cui2snomed", "cui2athena_ids"]) + doc = self.model.get_entities( + text, addl_info=["cui2icd10", "cui2ontologies", "cui2snomed", "cui2athena_ids"] + ) return self.get_records_from_doc(doc) def batch_annotate(self, texts: List[str]) -> List[Dict]: batch_size_chars = 500000 - docs = self.model.multiprocessing(self._data_iterator(texts), - batch_size_chars=batch_size_chars, - nproc=max(int(cpu_count() / 2), 1), - addl_info=["cui2icd10", "cui2ontologies", "cui2snomed", "cui2athena_ids"]) + docs = self.model.multiprocessing( + self._data_iterator(texts), + batch_size_chars=batch_size_chars, + nproc=max(int(cpu_count() / 2), 1), + addl_info=["cui2icd10", "cui2ontologies", "cui2snomed", "cui2athena_ids"], + ) annotations_list = [] for _, doc in docs.items(): annotations_list.append(self.get_records_from_doc(doc)) return annotations_list - def train_supervised(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + def train_supervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._supervised_trainer is None: raise ConfigurationException("The supervised trainer is not enabled") - return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) - - def train_unsupervised(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + return self._supervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) + + def train_unsupervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._unsupervised_trainer 
is None: raise ConfigurationException("The unsupervised trainer is not enabled") - return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) - - def train_metacat(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + return self._unsupervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) + + def train_metacat( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._metacat_trainer is None: raise ConfigurationException("The metacat trainer is not enabled") - return self._metacat_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) + return self._metacat_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) def get_records_from_doc(self, doc: Dict) -> Dict: df = pd.DataFrame(doc["entities"].values()) @@ -160,11 +216,32 @@ def get_records_from_doc(self, doc: Dict) -> Dict: else: for idx, row in df.iterrows(): if "athena_ids" in row and row["athena_ids"]: - df.loc[idx, "athena_ids"] = [athena_id["code"] for athena_id in row["athena_ids"]] + df.loc[idx, "athena_ids"] = [ + athena_id["code"] for athena_id in row["athena_ids"] + ] if self._config.INCLUDE_SPAN_TEXT == "true": - df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "source_value": "text", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True) + df.rename( + columns={ + "pretty_name": "label_name", + "cui": "label_id", + "source_value": "text", + "types": "categories", + "acc": "accuracy", + "athena_ids": "athena_ids", + }, + inplace=True, + ) else: - df.rename(columns={"pretty_name": "label_name", "cui": "label_id", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True) + df.rename( + columns={ + "pretty_name": "label_name", + "cui": "label_id", + "types": "categories", + "acc": "accuracy", + "athena_ids": "athena_ids", + }, + inplace=True, + ) df = self._retrieve_meta_annotations(df) records = df.to_dict("records") return records @@ -178,7 +255,9 @@ def _set_tuis_filtering(self) -> None: model_tuis = set(tuis2cuis.keys()) if self._whitelisted_tuis == {""}: return - assert self._whitelisted_tuis.issubset(model_tuis), f"Unrecognisable Type Unique Identifier(s): {self._whitelisted_tuis - model_tuis}" + assert self._whitelisted_tuis.issubset( + model_tuis + ), f"Unrecognisable Type Unique Identifier(s): {self._whitelisted_tuis - model_tuis}" whitelisted_cuis = set() for tui in self._whitelisted_tuis: whitelisted_cuis.update(tuis2cuis.get(tui, {})) diff --git a/app/model_services/medcat_model_deid.py b/app/model_services/medcat_model_deid.py index deba5fe..635b9e2 100644 --- a/app/model_services/medcat_model_deid.py +++ b/app/model_services/medcat_model_deid.py @@ -1,34 +1,44 @@ -import logging import inspect +import logging import 
threading -import torch -from typing import Dict, List, TextIO, Optional, Any, final, Callable from functools import partial -from transformers import pipeline +from typing import Any, Callable, Dict, List, Optional, TextIO, final + +import torch from medcat.cat import CAT +from transformers import pipeline + from config import Settings -from model_services.medcat_model import MedCATModel -from trainers.medcat_deid_trainer import MedcatDeIdentificationSupervisedTrainer from domain import ModelCard, ModelType -from utils import non_default_device_is_available, get_hf_pipeline_device_id from exception import ConfigurationException +from utils import get_hf_pipeline_device_id, non_default_device_is_available + +from model_services.medcat_model import MedCATModel +from trainers.medcat_deid_trainer import MedcatDeIdentificationSupervisedTrainer logger = logging.getLogger("cms") @final class MedCATModelDeIdentification(MedCATModel): - CHUNK_SIZE = 500 LEFT_CONTEXT_WORDS = 5 - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: - super().__init__(config, model_parent_dir=model_parent_dir, enable_trainer=enable_trainer, model_name=model_name, base_model_file=base_model_file) + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: + super().__init__( + config, + model_parent_dir=model_parent_dir, + enable_trainer=enable_trainer, + model_name=model_name, + base_model_file=base_model_file, + ) self.model_name = model_name or "De-Identification MedCAT model" self._lock = threading.RLock() @@ -39,10 +49,12 @@ def api_version(self) -> str: def info(self) -> ModelCard: model_card = self.model.get_model_card(as_dict=True) model_card["Basic CDB Stats"]["Average training examples per concept"] = 0 - return ModelCard(model_description=self.model_name, - model_type=ModelType.ANONCAT, - api_version=self.api_version, - model_card=model_card) + return ModelCard( + model_description=self.model_name, + model_type=ModelType.ANONCAT, + api_version=self.api_version, + model_card=model_card, + ) def annotate(self, text: str) -> Dict: doc = self.model.get_entities(text) @@ -56,7 +68,9 @@ def annotate_with_local_chunking(self, text: str) -> Dict: tokenizer = self.model._addl_ner[0].tokenizer.hf_tokenizer leading_ws_len = len(text) - len(text.lstrip()) text = text.lstrip() - tokenized = self._with_lock(tokenizer, text, return_offsets_mapping=True, add_special_tokens=False) + tokenized = self._with_lock( + tokenizer, text, return_offsets_mapping=True, add_special_tokens=False + ) input_ids = tokenized["input_ids"] offset_mapping = tokenized["offset_mapping"] chunk = [] @@ -70,7 +84,7 @@ def annotate_with_local_chunking(self, text: str) -> Dict: last_token_start_idx = 0 window_overlap_start_idx = 0 number_of_seen_words = 0 - for i in range(MedCATModelDeIdentification.CHUNK_SIZE-1, -1, -1): + for i in range(MedCATModelDeIdentification.CHUNK_SIZE - 1, -1, -1): if " " in tokenizer.decode([chunk[i][0]], skip_special_tokens=True): if last_token_start_idx == 0: last_token_start_idx = i @@ -79,9 +93,15 @@ def annotate_with_local_chunking(self, text: str) -> Dict: else: break number_of_seen_words += 1 - c_text = text[chunk[:last_token_start_idx][0][1][0]:chunk[:last_token_start_idx][-1][1][1]] + c_text = text[ + 
chunk[:last_token_start_idx][0][1][0] : chunk[:last_token_start_idx][-1][1][1] + ] doc = self._with_lock(self.model.get_entities, c_text) - doc["entities"] = {_id: entity for _id, entity in doc["entities"].items() if (entity["end"] + processed_char_len) < chunk[window_overlap_start_idx][1][0]} + doc["entities"] = { + _id: entity + for _id, entity in doc["entities"].items() + if (entity["end"] + processed_char_len) < chunk[window_overlap_start_idx][1][0] + } for entity in doc["entities"].values(): entity["start"] += processed_char_len entity["end"] += processed_char_len @@ -91,7 +111,7 @@ def annotate_with_local_chunking(self, text: str) -> Dict: processed_char_len = chunk[:window_overlap_start_idx][-1][1][1] + leading_ws_len + 1 chunk = chunk[window_overlap_start_idx:] if chunk: - c_text = text[chunk[0][1][0]:chunk[-1][1][1]] + c_text = text[chunk[0][1][0] : chunk[-1][1][1]] doc = self.model.get_entities(c_text) if doc["entities"]: for entity in doc["entities"].values(): @@ -102,7 +122,10 @@ def annotate_with_local_chunking(self, text: str) -> Dict: ent_key += 1 processed_char_len += len(c_text) - assert processed_char_len == (len(text) + leading_ws_len), f"{len(text) + leading_ws_len - processed_char_len} characters were not processed:\n{text}" + total_char_len = len(text) + leading_ws_len + assert ( + processed_char_len == total_char_len + ), f"{total_char_len - processed_char_len} characters were not processed:\n{text}" return self.get_records_from_doc({"entities": aggregated_entities}) @@ -113,46 +136,75 @@ def batch_annotate(self, texts: List[str]) -> List[Dict]: return annotation_list def init_model(self) -> None: - if hasattr(self, "_model") and isinstance(self._model, CAT): # type: ignore + if hasattr(self, "_model") and isinstance(self._model, CAT): # type: ignore logger.warning("Model service is already initialised and can be initialised only once") else: self._model = self.load_model(self._model_pack_path) - self._model._addl_ner[0].tokenizer.hf_tokenizer._in_target_context_manager = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "_in_target_context_manager", False) - self._model._addl_ner[0].tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None) - self._model._addl_ner[0].tokenizer.hf_tokenizer.split_special_tokens = getattr(self._model._addl_ner[0].tokenizer.hf_tokenizer, "split_special_tokens", False) + self._model._addl_ner[0].tokenizer.hf_tokenizer._in_target_context_manager = getattr( + self._model._addl_ner[0].tokenizer.hf_tokenizer, "_in_target_context_manager", False + ) + self._model._addl_ner[0].tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr( + self._model._addl_ner[0].tokenizer.hf_tokenizer, + "clean_up_tokenization_spaces", + None, + ) + self._model._addl_ner[0].tokenizer.hf_tokenizer.split_special_tokens = getattr( + self._model._addl_ner[0].tokenizer.hf_tokenizer, "split_special_tokens", False + ) if non_default_device_is_available(self._config.DEVICE): self._model.config.general["device"] = self._config.DEVICE self._model._addl_ner[0].model.to(torch.device(self._config.DEVICE)) - self._model._addl_ner[0].ner_pipe = pipeline(model=self._model._addl_ner[0].model, - framework="pt", - task="ner", - tokenizer=self._model._addl_ner[0].tokenizer.hf_tokenizer, - device=get_hf_pipeline_device_id(self._config.DEVICE), - aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY) + self._model._addl_ner[0].ner_pipe = pipeline( + 
model=self._model._addl_ner[0].model, + framework="pt", + task="ner", + tokenizer=self._model._addl_ner[0].tokenizer.hf_tokenizer, + device=get_hf_pipeline_device_id(self._config.DEVICE), + aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY, + ) else: if self._config.DEVICE != "default": - logger.warning("DEVICE is set to '%s' but it is not available. Using 'default' instead.", self._config.DEVICE) + logger.warning( + "DEVICE is set to '%s' but it is not available. Using 'default' instead.", + self._config.DEVICE, + ) _save_pretrained = self._model._addl_ner[0].model.save_pretrained - if ("safe_serialization" in inspect.signature(_save_pretrained).parameters): - self._model._addl_ner[0].model.save_pretrained = partial(_save_pretrained, safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true")) + if "safe_serialization" in inspect.signature(_save_pretrained).parameters: + self._model._addl_ner[0].model.save_pretrained = partial( + _save_pretrained, + safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"), + ) if self._enable_trainer: self._supervised_trainer = MedcatDeIdentificationSupervisedTrainer(self) - def train_supervised(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + def train_supervised( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: if self._supervised_trainer is None: raise ConfigurationException("Trainers are not enabled") - return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams) + return self._supervised_trainer.train( + data_file, + epochs, + log_frequency, + training_id, + input_file_name, + raw_data_files, + description, + synchronised, + **hyperparams, + ) def _with_lock(self, func: Callable, *args: Any, **kwargs: Any) -> Any: - # Temporarily tackle https://github.com/huggingface/tokenizers/issues/537 but it reduces parallelism + # Temporarily tackle https://github.com/huggingface/tokenizers/issues/537 + # but it reduces parallelism with self._lock: return func(*args, **kwargs) diff --git a/app/model_services/medcat_model_icd10.py b/app/model_services/medcat_model_icd10.py index 1873823..8668f1a 100644 --- a/app/model_services/medcat_model_icd10.py +++ b/app/model_services/medcat_model_icd10.py @@ -1,25 +1,35 @@ import logging -import pandas as pd from typing import Dict, Optional, final -from model_services.medcat_model import MedCATModel + +import pandas as pd + from config import Settings from domain import ModelCard, ModelType +from model_services.medcat_model import MedCATModel + logger = logging.getLogger("cms") @final class MedCATModelIcd10(MedCATModel): - ICD10_KEY = "icd10" - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: - super().__init__(config, model_parent_dir=model_parent_dir, enable_trainer=enable_trainer, model_name=model_name, base_model_file=base_model_file) + def __init__( + self, + config: Settings, 
+ model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: + super().__init__( + config, + model_parent_dir=model_parent_dir, + enable_trainer=enable_trainer, + model_name=model_name, + base_model_file=base_model_file, + ) self.model_name = model_name or "ICD-10 MedCAT model" @property @@ -27,10 +37,12 @@ def api_version(self) -> str: return "0.0.1" def info(self) -> ModelCard: - return ModelCard(model_description=self.model_name, - model_type=ModelType.MEDCAT_ICD10, - api_version=self.api_version, - model_card=self.model.get_model_card(as_dict=True)) + return ModelCard( + model_description=self.model_name, + model_type=ModelType.MEDCAT_ICD10, + api_version=self.api_version, + model_card=self.model.get_model_card(as_dict=True), + ) def get_records_from_doc(self, doc: Dict) -> Dict: df = pd.DataFrame(doc["entities"].values()) @@ -55,11 +67,22 @@ def get_records_from_doc(self, doc: Dict) -> Dict: else: logger.error("Unknown format for the ICD-10 code(s): %s", icd10) if "athena_ids" in output_row and output_row["athena_ids"]: - output_row["athena_ids"] = [athena_id["code"] for athena_id in output_row["athena_ids"]] + output_row["athena_ids"] = [ + athena_id["code"] for athena_id in output_row["athena_ids"] + ] new_rows.append(output_row) if new_rows: df = pd.DataFrame(new_rows) - df.rename(columns={"pretty_name": "label_name", self.ICD10_KEY: "label_id", "types": "categories", "acc": "accuracy", "athena_ids": "athena_ids"}, inplace=True) + df.rename( + columns={ + "pretty_name": "label_name", + self.ICD10_KEY: "label_id", + "types": "categories", + "acc": "accuracy", + "athena_ids": "athena_ids", + }, + inplace=True, + ) df = self._retrieve_meta_annotations(df) else: df = pd.DataFrame(columns=["label_name", "label_id", "start", "end", "accuracy"]) diff --git a/app/model_services/medcat_model_snomed.py b/app/model_services/medcat_model_snomed.py index 01c776d..3d64759 100644 --- a/app/model_services/medcat_model_snomed.py +++ b/app/model_services/medcat_model_snomed.py @@ -1,22 +1,31 @@ import logging from typing import Optional, final -from model_services.medcat_model import MedCATModel + from config import Settings from domain import ModelCard, ModelType +from model_services.medcat_model import MedCATModel + logger = logging.getLogger("cms") @final class MedCATModelSnomed(MedCATModel): - - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: - super().__init__(config, model_parent_dir=model_parent_dir, enable_trainer=enable_trainer, model_name=model_name, base_model_file=base_model_file) + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: + super().__init__( + config, + model_parent_dir=model_parent_dir, + enable_trainer=enable_trainer, + model_name=model_name, + base_model_file=base_model_file, + ) self.model_name = model_name or "SNOMED MedCAT model" @property @@ -24,7 +33,9 @@ def api_version(self) -> str: return "0.0.1" def info(self) -> ModelCard: - return ModelCard(model_description=self.model_name, - model_type=ModelType.MEDCAT_SNOMED, - api_version=self.api_version, - model_card=self.model.get_model_card(as_dict=True)) + return ModelCard( + 
model_description=self.model_name, + model_type=ModelType.MEDCAT_SNOMED, + api_version=self.api_version, + model_card=self.model.get_model_card(as_dict=True), + ) diff --git a/app/model_services/medcat_model_umls.py b/app/model_services/medcat_model_umls.py index 02d9674..abf09a6 100644 --- a/app/model_services/medcat_model_umls.py +++ b/app/model_services/medcat_model_umls.py @@ -1,19 +1,28 @@ from typing import Optional, final + from config import Settings -from model_services.medcat_model import MedCATModel from domain import ModelCard, ModelType +from model_services.medcat_model import MedCATModel + @final class MedCATModelUmls(MedCATModel): - - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - enable_trainer: Optional[bool] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: - super().__init__(config, model_parent_dir=model_parent_dir, enable_trainer=enable_trainer, model_name=model_name, base_model_file=base_model_file) + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + enable_trainer: Optional[bool] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: + super().__init__( + config, + model_parent_dir=model_parent_dir, + enable_trainer=enable_trainer, + model_name=model_name, + base_model_file=base_model_file, + ) self.model_name = model_name or "UMLS MedCAT model" @property @@ -21,7 +30,9 @@ def api_version(self) -> str: return "0.0.1" def info(self) -> ModelCard: - return ModelCard(model_description=self.model_name, - model_type=ModelType.MEDCAT_UMLS, - api_version=self.api_version, - model_card=self.model.get_model_card(as_dict=True)) + return ModelCard( + model_description=self.model_name, + model_type=ModelType.MEDCAT_UMLS, + api_version=self.api_version, + model_card=self.model.get_model_card(as_dict=True), + ) diff --git a/app/model_services/trf_model_deid.py b/app/model_services/trf_model_deid.py index a85b7f6..93acc4c 100644 --- a/app/model_services/trf_model_deid.py +++ b/app/model_services/trf_model_deid.py @@ -1,34 +1,48 @@ +import logging import os import shutil -import logging -import torch +from typing import Dict, Iterable, List, Optional, Tuple, final + import numpy as np -from typing import Tuple, List, Dict, Iterable, Optional, final +import torch +from medcat.tokenizers.transformers_ner import TransformersTokenizerNER from scipy.special import softmax from transformers import AutoModelForTokenClassification, PreTrainedModel -from medcat.tokenizers.transformers_ner import TransformersTokenizerNER -from model_services.base import AbstractModelService -from domain import ModelCard, ModelType + from config import Settings +from domain import ModelCard, ModelType from utils import cls_deprecated, non_default_device_is_available +from model_services.base import AbstractModelService + logger = logging.getLogger("cms") -@cls_deprecated("TransformersModelDeIdentification has been deprecated. Use MedCATModelDeIdentification instead.") +@cls_deprecated( + "TransformersModelDeIdentification has been deprecated." + " Use MedCATModelDeIdentification instead." 
+) @final class TransformersModelDeIdentification(AbstractModelService): - - def __init__(self, - config: Settings, - model_parent_dir: Optional[str] = None, - model_name: Optional[str] = None, - base_model_file: Optional[str] = None) -> None: + def __init__( + self, + config: Settings, + model_parent_dir: Optional[str] = None, + model_name: Optional[str] = None, + base_model_file: Optional[str] = None, + ) -> None: super().__init__(config) self._config = config - model_parent_dir = model_parent_dir or os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model")) - self._model_parent_dir = model_parent_dir or os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "model")) - self._model_file_path = os.path.join(self._model_parent_dir, config.BASE_MODEL_FILE if base_model_file is None else base_model_file) + model_parent_dir = model_parent_dir or os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "model") + ) + self._model_parent_dir = model_parent_dir or os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "model") + ) + self._model_file_path = os.path.join( + self._model_parent_dir, + config.BASE_MODEL_FILE if base_model_file is None else base_model_file, + ) if non_default_device_is_available(config.DEVICE): self._device = torch.device(config.DEVICE) self.model_name = model_name or "De-identification model" @@ -53,9 +67,11 @@ def api_version(self) -> str: return "0.0.1" def info(self) -> ModelCard: - return ModelCard(model_description=self.model_name, - model_type=ModelType.TRANSFORMERS_DEID, - api_version=self.api_version) + return ModelCard( + model_description=self.model_name, + model_type=ModelType.TRANSFORMERS_DEID, + api_version=self.api_version, + ) @staticmethod def load_model(model_file_path: str) -> Tuple[TransformersTokenizerNER, PreTrainedModel]: @@ -98,8 +114,10 @@ def _get_annotations(self, text: str) -> List[Dict]: annotations: List[Dict] = [] for dataset, offset_mappings in self._get_chunked_tokens(text): - predictions = self._model(torch.tensor([dataset["input_ids"]]).to(device), - torch.tensor([dataset["attention_mask"]]).to(device)) + predictions = self._model( + torch.tensor([dataset["input_ids"]]).to(device), + torch.tensor([dataset["attention_mask"]]).to(device), + ) predictions = softmax(predictions.logits.detach().numpy()[0], axis=-1) predictions = np.argmax(predictions, axis=-1) @@ -119,11 +137,19 @@ def _get_annotations(self, text: str) -> List[Dict]: annotation["text"] = t_text if annotations: token_type = self._tokenizer.id2type.get(input_ids[t_idx]) - if any([self._should_expand_with_partial(cur_cui_id, token_type, annotation, annotations), - self._should_expand_with_whole(cas, annotation, annotations)]): + if any( + [ + self._should_expand_with_partial( + cur_cui_id, token_type, annotation, annotations + ), + self._should_expand_with_whole(cas, annotation, annotations), + ] + ): annotations[-1]["end"] = annotation["end"] if ist: - annotations[-1]["text"] = text[annotations[-1]["start"]:annotations[-1]["end"]] + annotations[-1]["text"] = text[ + annotations[-1]["start"] : annotations[-1]["end"] + ] del annotation continue elif cur_cui_id != 1: @@ -137,32 +163,52 @@ def _get_annotations(self, text: str) -> List[Dict]: return annotations def _get_chunked_tokens(self, text: str) -> Iterable[Tuple[Dict, List[Tuple]]]: - tokens = self._tokenizer.hf_tokenizer(text, return_offsets_mapping=True, add_special_tokens=False) + tokens = self._tokenizer.hf_tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + 
) model_max_length = self._tokenizer.max_len pad_token_id = self._tokenizer.hf_tokenizer.pad_token_id partial = len(tokens["input_ids"]) % model_max_length for i in range(0, len(tokens["input_ids"]) - partial, model_max_length): dataset = { - "input_ids": tokens["input_ids"][i:i+model_max_length], - "attention_mask": tokens["attention_mask"][i:i+model_max_length], + "input_ids": tokens["input_ids"][i : i + model_max_length], + "attention_mask": tokens["attention_mask"][i : i + model_max_length], } - offset_mappings = tokens["offset_mapping"][i:i+model_max_length] + offset_mappings = tokens["offset_mapping"][i : i + model_max_length] yield dataset, offset_mappings if partial: dataset = { - "input_ids": tokens["input_ids"][-partial:] + [pad_token_id]*(model_max_length-partial), - "attention_mask": tokens["attention_mask"][-partial:] + [0]*(model_max_length-partial), + "input_ids": tokens["input_ids"][-partial:] + + [pad_token_id] * (model_max_length - partial), + "attention_mask": tokens["attention_mask"][-partial:] + + [0] * (model_max_length - partial), } - offset_mappings = (tokens["offset_mapping"][-partial:] + [(tokens["offset_mapping"][-1][1]+i, tokens["offset_mapping"][-1][1]+i+1) for i in range(model_max_length-partial)]) + offset_mappings = tokens["offset_mapping"][-partial:] + [ + (tokens["offset_mapping"][-1][1] + i, tokens["offset_mapping"][-1][1] + i + 1) + for i in range(model_max_length - partial) + ] yield dataset, offset_mappings @staticmethod - def _should_expand_with_partial(cur_cui_id: int, - cur_token_type: str, - annotation: Dict, - annotations: List[Dict]) -> bool: - return all([cur_cui_id == 1, cur_token_type == "sub", (annotation["start"] - annotations[-1]["end"]) in [0, 1]]) + def _should_expand_with_partial( + cur_cui_id: int, cur_token_type: str, annotation: Dict, annotations: List[Dict] + ) -> bool: + return all( + [ + cur_cui_id == 1, + cur_token_type == "sub", + (annotation["start"] - annotations[-1]["end"]) in [0, 1], + ] + ) @staticmethod - def _should_expand_with_whole(is_enabled: bool, annotation: Dict, annotations: List[Dict]) -> bool: - return all([is_enabled, annotation["label_id"] == annotations[-1]["label_id"], (annotation["start"] - annotations[-1]["end"]) in [0, 1]]) + def _should_expand_with_whole( + is_enabled: bool, annotation: Dict, annotations: List[Dict] + ) -> bool: + return all( + [ + is_enabled, + annotation["label_id"] == annotations[-1]["label_id"], + (annotation["start"] - annotations[-1]["end"]) in [0, 1], + ] + ) diff --git a/app/processors/data_batcher.py b/app/processors/data_batcher.py index fdbe410..d603310 100644 --- a/app/processors/data_batcher.py +++ b/app/processors/data_batcher.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Any +from typing import Any, Iterable, List def mini_batch(data: Iterable[Any], batch_size: Any) -> Iterable[List[Any]]: diff --git a/app/processors/metrics_collector.py b/app/processors/metrics_collector.py index d71edae..7a01430 100644 --- a/app/processors/metrics_collector.py +++ b/app/processors/metrics_collector.py @@ -1,13 +1,15 @@ -import json import hashlib -import pandas as pd -from typing import Tuple, Dict, List, Set, Union, Optional, TextIO +import json from collections import defaultdict +from typing import Dict, List, Optional, Set, TextIO, Tuple, Union + +import pandas as pd from sklearn.metrics import cohen_kappa_score from tqdm.autonotebook import tqdm -from model_services.base import AbstractModelService + from exception import AnnotationException +from model_services.base import 
AbstractModelService ANCHOR_DELIMITER = ";" DOC_SPAN_DELIMITER = "_" @@ -15,10 +17,12 @@ META_STATE_MISSING = hashlib.sha1("{}".encode("utf-8")).hexdigest() -def sanity_check_model_with_trainer_export(trainer_export: Union[str, TextIO, Dict], - model_service: AbstractModelService, - return_df: bool = False, - include_anchors: bool = False) -> Union[pd.DataFrame, Tuple[float, float, float, Dict, Dict, Dict, Dict, Optional[Dict]]]: +def sanity_check_model_with_trainer_export( + trainer_export: Union[str, TextIO, Dict], + model_service: AbstractModelService, + return_df: bool = False, + include_anchors: bool = False, +) -> Union[pd.DataFrame, Tuple[float, float, float, Dict, Dict, Dict, Dict, Optional[Dict]]]: if isinstance(trainer_export, str): with open(trainer_export, "r") as file: data = json.load(file) @@ -36,7 +40,9 @@ def sanity_check_model_with_trainer_export(trainer_export: Union[str, TextIO, Di if entry["correct"]: if document["id"] not in correct_cuis[project["id"]]: correct_cuis[project["id"]][document["id"]] = [] - correct_cuis[project["id"]][document["id"]].append([entry["start"], entry["end"], entry["cui"]]) + correct_cuis[project["id"]][document["id"]].append( + [entry["start"], entry["end"], entry["cui"]] + ) true_positives: Dict = {} false_positives: Dict = {} @@ -45,41 +51,60 @@ def sanity_check_model_with_trainer_export(trainer_export: Union[str, TextIO, Di concept_anchors: Dict = {} true_positive_count, false_positive_count, false_negative_count = 0, 0, 0 - for project in tqdm(data["projects"], desc="Evaluating projects", total=len(data["projects"]), leave=False): + for project in tqdm( + data["projects"], desc="Evaluating projects", total=len(data["projects"]), leave=False + ): predictions: Dict = {} documents = project["documents"] true_positives[project["id"]] = {} false_positives[project["id"]] = {} false_negatives[project["id"]] = {} - for document in tqdm(documents, desc="Evaluating documents", total=len(documents), leave=False): - true_positives[project["id"]][document["id"]] = {} - false_positives[project["id"]][document["id"]] = {} - false_negatives[project["id"]][document["id"]] = {} + for document in tqdm( + documents, desc="Evaluating documents", total=len(documents), leave=False + ): + p_id, d_id = project["id"], document["id"] + true_positives[p_id][d_id] = {} + false_positives[p_id][d_id] = {} + false_negatives[p_id][d_id] = {} annotations = model_service.annotate(document["text"]) - predictions[document["id"]] = [] + predictions[d_id] = [] for annotation in annotations: - predictions[document["id"]].append([annotation["start"], annotation["end"], annotation["label_id"]]) + predictions[d_id].append( + [annotation["start"], annotation["end"], annotation["label_id"]] + ) concept_names[annotation["label_id"]] = annotation["label_name"] - concept_anchors[annotation["label_id"]] = concept_anchors.get(annotation["label_id"], []) - concept_anchors[annotation["label_id"]].append(f"P{project['id']}/D{document['id']}/S{annotation['start']}/E{ annotation['end']}") - - predicted = {tuple(x) for x in predictions[document["id"]]} - actual = {tuple(x) for x in correct_cuis[project["id"]][document["id"]]} + concept_anchors[annotation["label_id"]] = concept_anchors.get( + annotation["label_id"], [] + ) + concept_anchors[annotation["label_id"]].append( + f"P{p_id}/D{document['id']}/S{annotation['start']}/E{ annotation['end']}" + ) + + predicted = {tuple(x) for x in predictions[d_id]} + actual = {tuple(x) for x in correct_cuis[p_id][d_id]} doc_tps = 
list(predicted.intersection(actual)) doc_fps = list(predicted.difference(actual)) doc_fns = list(actual.difference(predicted)) - true_positives[project["id"]][document["id"]] = doc_tps - false_positives[project["id"]][document["id"]] = doc_fps - false_negatives[project["id"]][document["id"]] = doc_fns + true_positives[p_id][d_id] = doc_tps + false_positives[p_id][d_id] = doc_fps + false_negatives[p_id][d_id] = doc_fns true_positive_count += len(doc_tps) false_positive_count += len(doc_fps) false_negative_count += len(doc_fns) - precision = true_positive_count / (true_positive_count + false_positive_count) if (true_positive_count + false_positive_count) != 0 else 0 - recall = true_positive_count / (true_positive_count + false_negative_count) if (true_positive_count + false_negative_count) != 0 else 0 - f1 = 2*((precision*recall) / (precision + recall)) if (precision + recall) != 0 else 0 + precision = ( + true_positive_count / (true_positive_count + false_positive_count) + if (true_positive_count + false_positive_count) != 0 + else 0 + ) + recall = ( + true_positive_count / (true_positive_count + false_negative_count) + if (true_positive_count + false_negative_count) != 0 + else 0 + ) + f1 = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) != 0 else 0 fp_counts: Dict = defaultdict(int) fn_counts: Dict = defaultdict(int) @@ -108,29 +133,44 @@ def sanity_check_model_with_trainer_export(trainer_export: Union[str, TextIO, Di for cui in tp_counts.keys(): per_cui_prec[cui] = tp_counts[cui] / (tp_counts[cui] + fp_counts[cui]) per_cui_rec[cui] = tp_counts[cui] / (tp_counts[cui] + fn_counts[cui]) - per_cui_f1[cui] = 2*(per_cui_prec[cui]*per_cui_rec[cui]) / (per_cui_prec[cui] + per_cui_rec[cui]) + per_cui_f1[cui] = ( + 2 * (per_cui_prec[cui] * per_cui_rec[cui]) / (per_cui_prec[cui] + per_cui_rec[cui]) + ) per_cui_name[cui] = concept_names[cui] per_cui_anchors[cui] = ANCHOR_DELIMITER.join(concept_anchors[cui]) if return_df: - df = pd.DataFrame({ - "concept": per_cui_prec.keys(), - "name": per_cui_name.values(), - "precision": per_cui_prec.values(), - "recall": per_cui_rec.values(), - "f1": per_cui_f1.values(), - }) + df = pd.DataFrame( + { + "concept": per_cui_prec.keys(), + "name": per_cui_name.values(), + "precision": per_cui_prec.values(), + "recall": per_cui_rec.values(), + "f1": per_cui_f1.values(), + } + ) if include_anchors: df["anchors"] = per_cui_anchors.values() return df else: - return precision, recall, f1, per_cui_prec, per_cui_rec, per_cui_f1, per_cui_name, per_cui_anchors if include_anchors else None - - -def concat_trainer_exports(data_file_paths: List[str], - combined_data_file_path: Optional[str] = None, - allow_recurring_project_ids: bool = False, - allow_recurring_doc_ids: bool = True) -> Union[Dict, str]: + return ( + precision, + recall, + f1, + per_cui_prec, + per_cui_rec, + per_cui_f1, + per_cui_name, + per_cui_anchors if include_anchors else None, + ) + + +def concat_trainer_exports( + data_file_paths: List[str], + combined_data_file_path: Optional[str] = None, + allow_recurring_project_ids: bool = False, + allow_recurring_doc_ids: bool = True, +) -> Union[Dict, str]: combined: Dict = {"projects": []} project_ids = [] for path in data_file_paths: @@ -138,13 +178,17 @@ def concat_trainer_exports(data_file_paths: List[str], data = json.load(f) for project in data["projects"]: if project["id"] in project_ids and not allow_recurring_project_ids: - raise AnnotationException(f'Found multiple projects share the same ID: {project["id"]}') + raise 
AnnotationException( + f'Found multiple projects share the same ID: {project["id"]}' + ) project_ids.append(project["id"]) combined["projects"].extend(data["projects"]) document_ids = [doc["id"] for project in combined["projects"] for doc in project["documents"]] if not allow_recurring_doc_ids and len(document_ids) > len(set(document_ids)): - recurring_ids = list(set([doc_id for doc_id in document_ids if document_ids.count(doc_id) > 1])) - raise AnnotationException(f'Found multiple documents share the same ID(s): {recurring_ids}') + recurring_ids = list( + set([doc_id for doc_id in document_ids if document_ids.count(doc_id) > 1]) + ) + raise AnnotationException(f"Found multiple documents share the same ID(s): {recurring_ids}") if isinstance(combined_data_file_path, str): with open(combined_data_file_path, "w") as f: @@ -155,8 +199,9 @@ def concat_trainer_exports(data_file_paths: List[str], return combined -def get_stats_from_trainer_export(trainer_export: Union[str, TextIO, Dict], - return_df: bool = False) -> Union[pd.DataFrame, Tuple[Dict[str, int], Dict[str, int], Dict[str, int], int]]: +def get_stats_from_trainer_export( + trainer_export: Union[str, TextIO, Dict], return_df: bool = False +) -> Union[pd.DataFrame, Tuple[Dict[str, int], Dict[str, int], Dict[str, int], int]]: if isinstance(trainer_export, str): with open(trainer_export, "r") as file: data = json.load(file) @@ -177,32 +222,42 @@ def get_stats_from_trainer_export(trainer_export: Union[str, TextIO, Dict], elif isinstance(doc["annotations"], dict): annotations = list(doc["annotations"].values()) for annotation in annotations: - if any([not annotation.get("validated", True), - annotation.get("deleted", False), - annotation.get("killed", False), - annotation.get("irrelevant", False)]): + if any( + [ + not annotation.get("validated", True), + annotation.get("deleted", False), + annotation.get("killed", False), + annotation.get("irrelevant", False), + ] + ): cui_ignorance_counts[annotation["cui"]] += 1 - cui_values[annotation["cui"]].append(doc["text"][annotation["start"]:annotation["end"]].lower()) + cui_values[annotation["cui"]].append( + doc["text"][annotation["start"] : annotation["end"]].lower() + ) num_of_docs += 1 cui_counts = {cui: len(values) for cui, values in cui_values.items()} cui_unique_counts = {cui: len(set(values)) for cui, values in cui_values.items()} if return_df: - return pd.DataFrame({ - "concept": cui_counts.keys(), - "anno_count": cui_counts.values(), - "anno_unique_counts": cui_unique_counts.values(), - "anno_ignorance_counts": [cui_ignorance_counts[c] for c in cui_counts.keys()], - }) + return pd.DataFrame( + { + "concept": cui_counts.keys(), + "anno_count": cui_counts.values(), + "anno_unique_counts": cui_unique_counts.values(), + "anno_ignorance_counts": [cui_ignorance_counts[c] for c in cui_counts.keys()], + } + ) else: return cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs -def get_iaa_scores_per_concept(export_file: Union[str, TextIO], - project_id: int, - another_project_id: int, - return_df: bool = False) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: +def get_iaa_scores_per_concept( + export_file: Union[str, TextIO], + project_id: int, + another_project_id: int, + return_df: bool = False, +) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: project_a, project_b = _extract_project_pair(export_file, project_id, another_project_id) filtered_projects = _filter_common_docs([project_a, project_b]) @@ -215,7 +270,9 @@ def get_iaa_scores_per_concept(export_file: Union[str, TextIO], docspan_key = 
_get_docspan_key(document, annotation) docspan2cui_a[docspan_key] = annotation["cui"] docspan2state_proj_a[docspan_key] = _get_hashed_annotation_state(annotation, state_keys) - docspan2metastate_proj_a[docspan_key] = _get_hashed_meta_annotation_state(annotation["meta_anns"]) + docspan2metastate_proj_a[docspan_key] = _get_hashed_meta_annotation_state( + annotation["meta_anns"] + ) docspan2cui_b = {} docspan2state_proj_b = {} @@ -225,47 +282,90 @@ def get_iaa_scores_per_concept(export_file: Union[str, TextIO], docspan_key = _get_docspan_key(document, annotation) docspan2cui_b[docspan_key] = annotation["cui"] docspan2state_proj_b[docspan_key] = _get_hashed_annotation_state(annotation, state_keys) - docspan2metastate_proj_b[docspan_key] = _get_hashed_meta_annotation_state(annotation["meta_anns"]) + docspan2metastate_proj_b[docspan_key] = _get_hashed_meta_annotation_state( + annotation["meta_anns"] + ) cui_states = {} cui_metastates = {} cuis = set(docspan2cui_a.values()).union(set(docspan2cui_b.values())) for cui in cuis: - docspans = set(_filter_docspan_by_value(docspan2cui_a, cui).keys()).union(set(_filter_docspan_by_value(docspan2cui_b, cui).keys())) - cui_states[cui] = [(docspan2state_proj_a.get(docspan, STATE_MISSING), docspan2state_proj_b.get(docspan, STATE_MISSING)) for docspan in docspans] - cui_metastates[cui] = [(docspan2metastate_proj_a.get(docspan, META_STATE_MISSING), docspan2metastate_proj_b.get(docspan, META_STATE_MISSING)) for docspan in docspans] + docspans = set(_filter_docspan_by_value(docspan2cui_a, cui).keys()).union( + set(_filter_docspan_by_value(docspan2cui_b, cui).keys()) + ) + cui_states[cui] = [ + ( + docspan2state_proj_a.get(docspan, STATE_MISSING), + docspan2state_proj_b.get(docspan, STATE_MISSING), + ) + for docspan in docspans + ] + cui_metastates[cui] = [ + ( + docspan2metastate_proj_a.get(docspan, META_STATE_MISSING), + docspan2metastate_proj_b.get(docspan, META_STATE_MISSING), + ) + for docspan in docspans + ] per_cui_anno_iia_pct = {} per_cui_anno_cohens_kappa = {} for cui, cui_state_pairs in cui_states.items(): - per_cui_anno_iia_pct[cui] = len([1 for csp in cui_state_pairs if csp[0] == csp[1]]) / len(cui_state_pairs) * 100 - per_cui_anno_cohens_kappa[cui] = _get_cohens_kappa_coefficient(*map(list, zip(*cui_state_pairs))) + per_cui_anno_iia_pct[cui] = ( + len([1 for csp in cui_state_pairs if csp[0] == csp[1]]) / len(cui_state_pairs) * 100 + ) + per_cui_anno_cohens_kappa[cui] = _get_cohens_kappa_coefficient( + *map(list, zip(*cui_state_pairs)) + ) per_cui_metaanno_iia_pct = {} per_cui_metaanno_cohens_kappa = {} for cui, cui_metastate_pairs in cui_metastates.items(): - per_cui_metaanno_iia_pct[cui] = len([1 for cmp in cui_metastate_pairs if cmp[0] == cmp[1]]) / len(cui_metastate_pairs) * 100 - per_cui_metaanno_cohens_kappa[cui] = _get_cohens_kappa_coefficient(*map(list, zip(*cui_metastate_pairs))) + per_cui_metaanno_iia_pct[cui] = ( + len([1 for cmp in cui_metastate_pairs if cmp[0] == cmp[1]]) + / len(cui_metastate_pairs) + * 100 + ) + per_cui_metaanno_cohens_kappa[cui] = _get_cohens_kappa_coefficient( + *map(list, zip(*cui_metastate_pairs)) + ) if return_df: - df = pd.DataFrame({ - "concept": per_cui_anno_iia_pct.keys(), - "iaa_percentage": per_cui_anno_iia_pct.values(), - "cohens_kappa": per_cui_anno_cohens_kappa.values(), - "iaa_percentage_meta": per_cui_metaanno_iia_pct.values(), - "cohens_kappa_meta": per_cui_metaanno_cohens_kappa.values() - }).sort_values(["concept"], ascending=True) + df = pd.DataFrame( + { + "concept": per_cui_anno_iia_pct.keys(), + 
"iaa_percentage": per_cui_anno_iia_pct.values(), + "cohens_kappa": per_cui_anno_cohens_kappa.values(), + "iaa_percentage_meta": per_cui_metaanno_iia_pct.values(), + "cohens_kappa_meta": per_cui_metaanno_cohens_kappa.values(), + } + ).sort_values(["concept"], ascending=True) return df.fillna("NaN") else: - return per_cui_anno_iia_pct, per_cui_anno_cohens_kappa, per_cui_metaanno_iia_pct, per_cui_metaanno_cohens_kappa - - -def get_iaa_scores_per_doc(export_file: Union[str, TextIO], - project_id: int, - another_project_id: int, - return_df: bool = False) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: + return ( + per_cui_anno_iia_pct, + per_cui_anno_cohens_kappa, + per_cui_metaanno_iia_pct, + per_cui_metaanno_cohens_kappa, + ) + + +def get_iaa_scores_per_doc( + export_file: Union[str, TextIO], + project_id: int, + another_project_id: int, + return_df: bool = False, +) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: project_a, project_b = _extract_project_pair(export_file, project_id, another_project_id) filtered_projects = _filter_common_docs([project_a, project_b]) - state_keys = {"validated", "correct", "deleted", "alternative", "killed", "manually_created", "cui"} + state_keys = { + "validated", + "correct", + "deleted", + "alternative", + "killed", + "manually_created", + "cui", + } docspan2doc_id_a = {} docspan2state_proj_a = {} @@ -275,7 +375,9 @@ def get_iaa_scores_per_doc(export_file: Union[str, TextIO], docspan_key = _get_docspan_key(document, annotation) docspan2doc_id_a[docspan_key] = document["id"] docspan2state_proj_a[docspan_key] = _get_hashed_annotation_state(annotation, state_keys) - docspan2metastate_proj_a[docspan_key] = _get_hashed_meta_annotation_state(annotation["meta_anns"]) + docspan2metastate_proj_a[docspan_key] = _get_hashed_meta_annotation_state( + annotation["meta_anns"] + ) docspan2doc_id_b = {} docspan2state_proj_b = {} @@ -285,48 +387,90 @@ def get_iaa_scores_per_doc(export_file: Union[str, TextIO], docspan_key = _get_docspan_key(document, annotation) docspan2doc_id_b[docspan_key] = document["id"] docspan2state_proj_b[docspan_key] = _get_hashed_annotation_state(annotation, state_keys) - docspan2metastate_proj_b[docspan_key] = _get_hashed_meta_annotation_state(annotation["meta_anns"]) + docspan2metastate_proj_b[docspan_key] = _get_hashed_meta_annotation_state( + annotation["meta_anns"] + ) doc_states = {} doc_metastates = {} doc_ids = sorted(set(docspan2doc_id_a.values()).union(set(docspan2doc_id_b.values()))) for doc_id in doc_ids: docspans = set(_filter_docspan_by_value(docspan2doc_id_a, doc_id).keys()).union( - set(_filter_docspan_by_value(docspan2doc_id_b, doc_id).keys())) - doc_states[doc_id] = [(docspan2state_proj_a.get(docspan, STATE_MISSING), docspan2state_proj_b.get(docspan, STATE_MISSING)) for docspan in docspans] - doc_metastates[doc_id] = [(docspan2metastate_proj_a.get(docspan, META_STATE_MISSING), docspan2metastate_proj_b.get(docspan, META_STATE_MISSING)) for docspan in docspans] + set(_filter_docspan_by_value(docspan2doc_id_b, doc_id).keys()) + ) + doc_states[doc_id] = [ + ( + docspan2state_proj_a.get(docspan, STATE_MISSING), + docspan2state_proj_b.get(docspan, STATE_MISSING), + ) + for docspan in docspans + ] + doc_metastates[doc_id] = [ + ( + docspan2metastate_proj_a.get(docspan, META_STATE_MISSING), + docspan2metastate_proj_b.get(docspan, META_STATE_MISSING), + ) + for docspan in docspans + ] per_doc_anno_iia_pct = {} per_doc_anno_cohens_kappa = {} for doc_id, doc_state_pairs in doc_states.items(): - per_doc_anno_iia_pct[str(doc_id)] = len([1 for dsp in 
doc_state_pairs if dsp[0] == dsp[1]]) / len(doc_state_pairs) * 100 - per_doc_anno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient(*map(list, zip(*doc_state_pairs))) + per_doc_anno_iia_pct[str(doc_id)] = ( + len([1 for dsp in doc_state_pairs if dsp[0] == dsp[1]]) / len(doc_state_pairs) * 100 + ) + per_doc_anno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient( + *map(list, zip(*doc_state_pairs)) + ) per_doc_metaanno_iia_pct = {} per_doc_metaanno_cohens_kappa = {} for doc_id, doc_metastate_pairs in doc_metastates.items(): - per_doc_metaanno_iia_pct[str(doc_id)] = len([1 for dmp in doc_metastate_pairs if dmp[0] == dmp[1]]) / len(doc_metastate_pairs) * 100 - per_doc_metaanno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient(*map(list, zip(*doc_metastate_pairs))) + per_doc_metaanno_iia_pct[str(doc_id)] = ( + len([1 for dmp in doc_metastate_pairs if dmp[0] == dmp[1]]) + / len(doc_metastate_pairs) + * 100 + ) + per_doc_metaanno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient( + *map(list, zip(*doc_metastate_pairs)) + ) if return_df: - df = pd.DataFrame({ - "doc_id": per_doc_anno_iia_pct.keys(), - "iaa_percentage": per_doc_anno_iia_pct.values(), - "cohens_kappa": per_doc_anno_cohens_kappa.values(), - "iaa_percentage_meta": per_doc_metaanno_iia_pct.values(), - "cohens_kappa_meta": per_doc_metaanno_cohens_kappa.values() - }).sort_values(["doc_id"], ascending=True) + df = pd.DataFrame( + { + "doc_id": per_doc_anno_iia_pct.keys(), + "iaa_percentage": per_doc_anno_iia_pct.values(), + "cohens_kappa": per_doc_anno_cohens_kappa.values(), + "iaa_percentage_meta": per_doc_metaanno_iia_pct.values(), + "cohens_kappa_meta": per_doc_metaanno_cohens_kappa.values(), + } + ).sort_values(["doc_id"], ascending=True) return df.fillna("NaN") else: - return per_doc_anno_iia_pct, per_doc_anno_cohens_kappa, per_doc_metaanno_iia_pct, per_doc_metaanno_cohens_kappa - - -def get_iaa_scores_per_span(export_file: Union[str, TextIO], - project_id: int, - another_project_id: int, - return_df: bool = False) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: + return ( + per_doc_anno_iia_pct, + per_doc_anno_cohens_kappa, + per_doc_metaanno_iia_pct, + per_doc_metaanno_cohens_kappa, + ) + + +def get_iaa_scores_per_span( + export_file: Union[str, TextIO], + project_id: int, + another_project_id: int, + return_df: bool = False, +) -> Union[pd.DataFrame, Tuple[Dict, Dict]]: project_a, project_b = _extract_project_pair(export_file, project_id, another_project_id) filtered_projects = _filter_common_docs([project_a, project_b]) - state_keys = {"validated", "correct", "deleted", "alternative", "killed", "manually_created", "cui"} + state_keys = { + "validated", + "correct", + "deleted", + "alternative", + "killed", + "manually_created", + "cui", + } docspan2state_proj_a = {} docspan2statemeta_proj_a = {} @@ -334,7 +478,11 @@ def get_iaa_scores_per_span(export_file: Union[str, TextIO], for annotation in document["annotations"]: docspan_key = _get_docspan_key(document, annotation) docspan2state_proj_a[docspan_key] = [str(annotation.get(key)) for key in state_keys] - docspan2statemeta_proj_a[docspan_key] = [str(meta_ann) for meta_ann in annotation["meta_anns"].items()] if annotation["meta_anns"] else [META_STATE_MISSING] + docspan2statemeta_proj_a[docspan_key] = ( + [str(meta_ann) for meta_ann in annotation["meta_anns"].items()] + if annotation["meta_anns"] + else [META_STATE_MISSING] + ) docspan2state_proj_b = {} docspan2statemeta_proj_b = {} @@ -342,48 +490,106 @@ def get_iaa_scores_per_span(export_file: 
Union[str, TextIO], for annotation in document["annotations"]: docspan_key = _get_docspan_key(document, annotation) docspan2state_proj_b[docspan_key] = [str(annotation.get(key)) for key in state_keys] - docspan2statemeta_proj_b[docspan_key] = [str(meta_ann) for meta_ann in annotation["meta_anns"].items()] if annotation["meta_anns"] else [META_STATE_MISSING] + docspan2statemeta_proj_b[docspan_key] = ( + [str(meta_ann) for meta_ann in annotation["meta_anns"].items()] + if annotation["meta_anns"] + else [META_STATE_MISSING] + ) docspans = set(docspan2state_proj_a.keys()).union(set(docspan2state_proj_b.keys())) - docspan_states = {docspan: (docspan2state_proj_a.get(docspan, [STATE_MISSING]*len(state_keys)), docspan2state_proj_b.get(docspan, [STATE_MISSING]*len(state_keys))) for docspan in docspans} + docspan_states = { + docspan: ( + docspan2state_proj_a.get(docspan, [STATE_MISSING] * len(state_keys)), + docspan2state_proj_b.get(docspan, [STATE_MISSING] * len(state_keys)), + ) + for docspan in docspans + } docspan_metastates = {} for docspan in docspans: if docspan in docspan2statemeta_proj_a and docspan not in docspan2statemeta_proj_b: - docspan_metastates[docspan] = (docspan2statemeta_proj_a[docspan], [STATE_MISSING] * len(docspan2statemeta_proj_a[docspan])) + docspan_metastates[docspan] = ( + docspan2statemeta_proj_a[docspan], + [STATE_MISSING] * len(docspan2statemeta_proj_a[docspan]), + ) elif docspan not in docspan2statemeta_proj_a and docspan in docspan2statemeta_proj_b: - docspan_metastates[docspan] = ([STATE_MISSING] * len(docspan2statemeta_proj_b[docspan]), docspan2statemeta_proj_b[docspan]) + docspan_metastates[docspan] = ( + [STATE_MISSING] * len(docspan2statemeta_proj_b[docspan]), + docspan2statemeta_proj_b[docspan], + ) else: - docspan_metastates[docspan] = (docspan2statemeta_proj_a[docspan], docspan2statemeta_proj_b[docspan]) + docspan_metastates[docspan] = ( + docspan2statemeta_proj_a[docspan], + docspan2statemeta_proj_b[docspan], + ) per_span_anno_iia_pct = {} per_span_anno_cohens_kappa = {} for docspan, docspan_state_pairs in docspan_states.items(): - per_span_anno_iia_pct[docspan] = len([1 for state_a, state_b in zip(docspan_state_pairs[0], docspan_state_pairs[1]) if state_a == state_b]) / len(state_keys) * 100 - per_span_anno_cohens_kappa[docspan] = _get_cohens_kappa_coefficient(docspan_state_pairs[0], docspan_state_pairs[1]) + per_span_anno_iia_pct[docspan] = ( + len( + [ + 1 + for state_a, state_b in zip(docspan_state_pairs[0], docspan_state_pairs[1]) + if state_a == state_b + ] + ) + / len(state_keys) + * 100 + ) + per_span_anno_cohens_kappa[docspan] = _get_cohens_kappa_coefficient( + docspan_state_pairs[0], docspan_state_pairs[1] + ) per_doc_metaanno_iia_pct = {} per_doc_metaanno_cohens_kappa = {} for docspan, docspan_metastate_pairs in docspan_metastates.items(): - per_doc_metaanno_iia_pct[docspan] = len([1 for state_a, state_b in zip(docspan_metastate_pairs[0], docspan_metastate_pairs[1]) if state_a == state_b]) / len(docspan_metastate_pairs[0]) * 100 - per_doc_metaanno_cohens_kappa[docspan] = _get_cohens_kappa_coefficient(docspan_metastate_pairs[0], docspan_metastate_pairs[1]) + per_doc_metaanno_iia_pct[docspan] = ( + len( + [ + 1 + for state_a, state_b in zip( + docspan_metastate_pairs[0], docspan_metastate_pairs[1] + ) + if state_a == state_b + ] + ) + / len(docspan_metastate_pairs[0]) + * 100 + ) + per_doc_metaanno_cohens_kappa[docspan] = _get_cohens_kappa_coefficient( + docspan_metastate_pairs[0], docspan_metastate_pairs[1] + ) if return_df: - df = pd.DataFrame({ 
- "doc_id": [int(key.split(DOC_SPAN_DELIMITER)[0]) for key in per_span_anno_iia_pct.keys()], - "span_start": [int(key.split(DOC_SPAN_DELIMITER)[1]) for key in per_span_anno_iia_pct.keys()], - "span_end": [int(key.split(DOC_SPAN_DELIMITER)[2]) for key in per_span_anno_iia_pct.keys()], - "iaa_percentage": per_span_anno_iia_pct.values(), - "cohens_kappa": per_span_anno_cohens_kappa.values(), - "iaa_percentage_meta": per_doc_metaanno_iia_pct.values(), - "cohens_kappa_meta": per_doc_metaanno_cohens_kappa.values() - }).sort_values(["doc_id", "span_start", "span_end"], ascending=[True, True, True]) + df = pd.DataFrame( + { + "doc_id": [ + int(key.split(DOC_SPAN_DELIMITER)[0]) for key in per_span_anno_iia_pct.keys() + ], + "span_start": [ + int(key.split(DOC_SPAN_DELIMITER)[1]) for key in per_span_anno_iia_pct.keys() + ], + "span_end": [ + int(key.split(DOC_SPAN_DELIMITER)[2]) for key in per_span_anno_iia_pct.keys() + ], + "iaa_percentage": per_span_anno_iia_pct.values(), + "cohens_kappa": per_span_anno_cohens_kappa.values(), + "iaa_percentage_meta": per_doc_metaanno_iia_pct.values(), + "cohens_kappa_meta": per_doc_metaanno_cohens_kappa.values(), + } + ).sort_values(["doc_id", "span_start", "span_end"], ascending=[True, True, True]) return df.fillna("NaN") else: - return per_span_anno_iia_pct, per_span_anno_cohens_kappa, per_doc_metaanno_iia_pct, per_doc_metaanno_cohens_kappa + return ( + per_span_anno_iia_pct, + per_span_anno_cohens_kappa, + per_doc_metaanno_iia_pct, + per_doc_metaanno_cohens_kappa, + ) -def _extract_project_pair(export_file: Union[str, TextIO], - project_id: int, - another_project_id: int) -> Tuple[Dict, Dict]: +def _extract_project_pair( + export_file: Union[str, TextIO], project_id: int, another_project_id: int +) -> Tuple[Dict, Dict]: if isinstance(export_file, str): with open(export_file, "r") as file: data = json.load(file) @@ -405,7 +611,8 @@ def _extract_project_pair(export_file: Union[str, TextIO], def _get_docspan_key(document: Dict, annotation: Dict) -> str: - return f"{document['id']}{DOC_SPAN_DELIMITER}{annotation.get('start')}{DOC_SPAN_DELIMITER}{annotation.get('end')}" + start, end = annotation.get("start"), annotation.get("end") + return f"{document['id']}{DOC_SPAN_DELIMITER}{start}{DOC_SPAN_DELIMITER}{end}" def _filter_common_docs(projects: List[Dict]) -> List[Dict]: @@ -425,13 +632,21 @@ def _filter_docspan_by_value(docspan2value: Dict, value: str) -> Dict: def _get_hashed_annotation_state(annotation: Dict, state_keys: Set[str]) -> str: - return hashlib.sha1("_".join([str(annotation.get(key)) for key in state_keys]).encode("utf-8")).hexdigest() + return hashlib.sha1( + "_".join([str(annotation.get(key)) for key in state_keys]).encode("utf-8") + ).hexdigest() def _get_hashed_meta_annotation_state(meta_anno: Dict) -> str: - meta_anno = {key: val for key, val in sorted(meta_anno.items(), key=lambda item: item[0])} # may not be necessary + meta_anno = { + key: val for key, val in sorted(meta_anno.items(), key=lambda item: item[0]) + } # may not be necessary return hashlib.sha1(str(meta_anno).encode("utf=8")).hexdigest() def _get_cohens_kappa_coefficient(y1_labels: List, y2_labels: List) -> float: - return cohen_kappa_score(y1_labels, y2_labels) if len(set(y1_labels).union(set(y2_labels))) != 1 else 1.0 + return ( + cohen_kappa_score(y1_labels, y2_labels) + if len(set(y1_labels).union(set(y2_labels))) != 1 + else 1.0 + ) diff --git a/app/registry.py b/app/registry.py index e5e68bb..e82addf 100644 --- a/app/registry.py +++ b/app/registry.py @@ -1,10 +1,11 @@ from 
domain import ModelType -from model_services.trf_model_deid import TransformersModelDeIdentification + +from model_services.huggingface_ner_model import HuggingFaceNerModel +from model_services.medcat_model_deid import MedCATModelDeIdentification +from model_services.medcat_model_icd10 import MedCATModelIcd10 from model_services.medcat_model_snomed import MedCATModelSnomed from model_services.medcat_model_umls import MedCATModelUmls -from model_services.medcat_model_icd10 import MedCATModelIcd10 -from model_services.medcat_model_deid import MedCATModelDeIdentification -from model_services.huggingface_ner_model import HuggingFaceNerModel +from model_services.trf_model_deid import TransformersModelDeIdentification model_service_registry = { ModelType.MEDCAT_SNOMED.value: MedCATModelSnomed, diff --git a/app/trainers/base.py b/app/trainers/base.py index 2cc22fe..9264e3f 100644 --- a/app/trainers/base.py +++ b/app/trainers/base.py @@ -1,26 +1,27 @@ import asyncio -import threading -import shutil -import os import logging +import os +import shutil import tempfile -import datasets - +import threading from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import TextIO, Callable, Dict, Optional, Any, List, Union, final +from typing import Any, Callable, Dict, List, Optional, TextIO, Union, final + +import datasets + from config import Settings -from management.tracker_client import TrackerClient -from data import doc_dataset, anno_dataset from domain import TrainingType +from data import anno_dataset, doc_dataset +from management.tracker_client import TrackerClient + logger = logging.getLogger("cms") logging.getLogger("asyncio").setLevel(logging.ERROR) class TrainerCommon(object): - def __init__(self, config: Settings, model_name: str) -> None: self._config = config self._model_name = model_name @@ -38,17 +39,19 @@ def model_name(self, model_name: str) -> None: self._model_name = model_name @final - def start_training(self, - run: Callable, - training_type: str, - training_params: Dict, - data_file: Union[TextIO, tempfile.TemporaryDirectory], - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False) -> bool: + def start_training( + self, + run: Callable, + training_type: str, + training_params: Dict, + data_file: Union[TextIO, tempfile.TemporaryDirectory], + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + ) -> bool: with self._training_lock: if self._training_in_progress: return False @@ -73,27 +76,33 @@ def start_training(self, self._tracker_client.save_processed_artifact(data_file.name, self._model_name) dataset = None - if training_type == TrainingType.UNSUPERVISED.value and isinstance(data_file, tempfile.TemporaryDirectory): + if training_type == TrainingType.UNSUPERVISED.value and isinstance( + data_file, tempfile.TemporaryDirectory + ): dataset = datasets.load_from_disk(data_file.name) self._tracker_client.save_train_dataset(dataset) elif training_type == TrainingType.UNSUPERVISED.value: try: - dataset = datasets.load_dataset(doc_dataset.__file__, - data_files={"documents": data_file.name}, - split="train", - cache_dir=self._config.TRAINING_CACHE_DIR, - trust_remote_code=True) + dataset = datasets.load_dataset( + doc_dataset.__file__, + data_files={"documents": 
data_file.name}, + split="train", + cache_dir=self._config.TRAINING_CACHE_DIR, + trust_remote_code=True, + ) self._tracker_client.save_train_dataset(dataset) finally: if dataset is not None: dataset.cleanup_cache_files() elif training_type == TrainingType.SUPERVISED.value: try: - dataset = datasets.load_dataset(anno_dataset.__file__, - data_files={"annotations": data_file.name}, - split="train", - cache_dir=self._config.TRAINING_CACHE_DIR, - trust_remote_code=True) + dataset = datasets.load_dataset( + anno_dataset.__file__, + data_files={"annotations": data_file.name}, + split="train", + cache_dir=self._config.TRAINING_CACHE_DIR, + trust_remote_code=True, + ) self._tracker_client.save_train_dataset(dataset) finally: if dataset is not None: @@ -101,10 +110,24 @@ def start_training(self, else: raise ValueError(f"Unknown training type: {training_type}") - logger.info("Starting training job: %s with experiment ID: %s", training_id, experiment_id) + logger.info( + "Starting training job: %s with experiment ID: %s", training_id, experiment_id + ) self._training_in_progress = True - training_task = asyncio.ensure_future(loop.run_in_executor(self._executor, - partial(run, self, training_params, data_file, log_frequency, run_id, description))) + training_task = asyncio.ensure_future( + loop.run_in_executor( + self._executor, + partial( + run, + self, + training_params, + data_file, + log_frequency, + run_id, + description, + ), + ) + ) if synchronised: loop.run_until_complete(training_task) @@ -148,85 +171,95 @@ def _clean_up_training_cache(self) -> None: class SupervisedTrainer(ABC, TrainerCommon): - def __init__(self, config: Settings, model_name: str) -> None: super().__init__(config, model_name) - def train(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + def train( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: training_type = TrainingType.SUPERVISED.value training_params = { "data_path": data_file.name, "nepochs": epochs, **hyperparams, } - return self.start_training(run=self.run, - training_type=training_type, - training_params=training_params, - data_file=data_file, - log_frequency=log_frequency, - training_id=training_id, - input_file_name=input_file_name, - raw_data_files=raw_data_files, - description=description, - synchronised=synchronised) + return self.start_training( + run=self.run, + training_type=training_type, + training_params=training_params, + data_file=data_file, + log_frequency=log_frequency, + training_id=training_id, + input_file_name=input_file_name, + raw_data_files=raw_data_files, + description=description, + synchronised=synchronised, + ) @staticmethod @abstractmethod - def run(trainer: "SupervisedTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "SupervisedTrainer", + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: raise NotImplementedError class UnsupervisedTrainer(ABC, TrainerCommon): - def __init__(self, config: Settings, model_name: str) 
-> None: super().__init__(config, model_name) - def train(self, - data_file: TextIO, - epochs: int, - log_frequency: int, - training_id: str, - input_file_name: str, - raw_data_files: Optional[List[TextIO]] = None, - description: Optional[str] = None, - synchronised: bool = False, - **hyperparams: Dict[str, Any]) -> bool: + def train( + self, + data_file: TextIO, + epochs: int, + log_frequency: int, + training_id: str, + input_file_name: str, + raw_data_files: Optional[List[TextIO]] = None, + description: Optional[str] = None, + synchronised: bool = False, + **hyperparams: Dict[str, Any], + ) -> bool: training_type = TrainingType.UNSUPERVISED.value training_params = { "nepochs": epochs, **hyperparams, } - return self.start_training(run=self.run, - training_type=training_type, - training_params=training_params, - data_file=data_file, - log_frequency=log_frequency, - training_id=training_id, - input_file_name=input_file_name, - raw_data_files=raw_data_files, - description=description, - synchronised=synchronised) + return self.start_training( + run=self.run, + training_type=training_type, + training_params=training_params, + data_file=data_file, + log_frequency=log_frequency, + training_id=training_id, + input_file_name=input_file_name, + raw_data_files=raw_data_files, + description=description, + synchronised=synchronised, + ) @staticmethod @abstractmethod - def run(trainer: "UnsupervisedTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "UnsupervisedTrainer", + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: raise NotImplementedError diff --git a/app/trainers/huggingface_ner_trainer.py b/app/trainers/huggingface_ner_trainer.py index 6aa8e6e..60dc1c8 100644 --- a/app/trainers/huggingface_ner_trainer.py +++ b/app/trainers/huggingface_ner_trainer.py @@ -1,69 +1,75 @@ -import os -import logging -import torch import gc import json -import shutil -import datasets +import logging +import os import random +import shutil import tempfile +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, TextIO, Tuple, Union, final + +import datasets import numpy as np import pandas as pd -from functools import partial -from typing import final, Dict, TextIO, Optional, Any, List, Iterable, Tuple, Union -from torch import nn -from sklearn.metrics import precision_recall_fscore_support, accuracy_score +import torch +from evaluate.visualization import radar_plot from scipy.special import softmax -from transformers import __version__ as transformers_version +from sklearn.metrics import accuracy_score, precision_recall_fscore_support +from torch import nn from transformers import ( AutoModelForMaskedLM, DataCollatorForLanguageModeling, - TrainingArguments, - Trainer, + EvalPrediction, PreTrainedModel, PreTrainedTokenizerBase, PreTrainedTokenizerFast, + Trainer, TrainerCallback, - TrainerState, TrainerControl, - EvalPrediction, + TrainerState, + TrainingArguments, ) -from evaluate.visualization import radar_plot +from transformers import __version__ as transformers_version + +from domain import DatasetSplit, Device, HfTransformerBackbone, ModelType +from exception import AnnotationException +from utils import filter_by_concept_ids, non_default_device_is_available, reset_random_seed + from management.model_manager import ModelManager from management.tracker_client import TrackerClient from 
model_services.base import AbstractModelService -from processors.metrics_collector import get_stats_from_trainer_export, sanity_check_model_with_trainer_export -from utils import filter_by_concept_ids, reset_random_seed, non_default_device_is_available -from trainers.base import UnsupervisedTrainer, SupervisedTrainer -from domain import ModelType, DatasetSplit, HfTransformerBackbone, Device -from exception import AnnotationException - +from processors.metrics_collector import ( + get_stats_from_trainer_export, + sanity_check_model_with_trainer_export, +) +from trainers.base import SupervisedTrainer, UnsupervisedTrainer logger = logging.getLogger("cms") @final class HuggingFaceNerUnsupervisedTrainer(UnsupervisedTrainer): - def __init__(self, model_service: AbstractModelService) -> None: UnsupervisedTrainer.__init__(self, model_service._config, model_service.model_name) self._model_service = model_service self._model_name = model_service.model_name self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", - self._model_name.replace(" ", "_")) + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_") + ) self._model_manager = ModelManager(type(model_service), model_service._config) self._max_length = model_service.model.config.max_position_embeddings os.makedirs(self._retrained_models_dir, exist_ok=True) - @staticmethod - def run(trainer: "HuggingFaceNerUnsupervisedTrainer", - training_params: Dict, - data_file: Union[TextIO, tempfile.TemporaryDirectory], - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "HuggingFaceNerUnsupervisedTrainer", + training_params: Dict, + data_file: Union[TextIO, tempfile.TemporaryDirectory], + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: copied_model_pack_path = None train_dataset = None eval_dataset = None @@ -81,7 +87,9 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", if non_default_device_is_available(trainer._config.DEVICE): mlm_model.to(trainer._config.DEVICE) - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + test_size = ( + 0.2 if training_params.get("test_size") is None else training_params["test_size"] + ) if isinstance(data_file, tempfile.TemporaryDirectory): raw_dataset = datasets.load_from_disk(data_file.name) if DatasetSplit.VALIDATION.value in raw_dataset.keys(): @@ -93,37 +101,57 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", else: lines = raw_dataset[DatasetSplit.TRAIN.value]["text"] random.shuffle(lines) - train_texts = [line.strip() for line in lines[:int(len(lines) * (1 - test_size))]] - eval_texts = [line.strip() for line in lines[int(len(lines) * (1 - test_size)):]] + train_texts = [ + line.strip() for line in lines[: int(len(lines) * (1 - test_size))] + ] + eval_texts = [ + line.strip() for line in lines[int(len(lines) * (1 - test_size)) :] + ] else: with open(data_file.name, "r") as f: lines = json.load(f) random.shuffle(lines) - train_texts = [line.strip() for line in lines[:int(len(lines) * (1-test_size))]] - eval_texts = [line.strip() for line in lines[int(len(lines) * (1-test_size)):]] - - dataset_features = datasets.Features({ - "input_ids": datasets.Sequence(datasets.Value("int32")), - "attention_mask": datasets.Sequence(datasets.Value("int32")), - "special_tokens_mask": datasets.Sequence(datasets.Value("int32")), - 
"token_type_ids": datasets.Sequence(datasets.Value("int32")) - }) + train_texts = [ + line.strip() for line in lines[: int(len(lines) * (1 - test_size))] + ] + eval_texts = [ + line.strip() for line in lines[int(len(lines) * (1 - test_size)) :] + ] + + dataset_features = datasets.Features( + { + "input_ids": datasets.Sequence(datasets.Value("int32")), + "attention_mask": datasets.Sequence(datasets.Value("int32")), + "special_tokens_mask": datasets.Sequence(datasets.Value("int32")), + "token_type_ids": datasets.Sequence(datasets.Value("int32")), + } + ) train_dataset = datasets.Dataset.from_generator( trainer._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"texts": train_texts, "tokenizer": tokenizer, "max_length": trainer._max_length}, - cache_dir=trainer._model_service._config.TRAINING_CACHE_DIR + gen_kwargs={ + "texts": train_texts, + "tokenizer": tokenizer, + "max_length": trainer._max_length, + }, + cache_dir=trainer._model_service._config.TRAINING_CACHE_DIR, ) eval_dataset = datasets.Dataset.from_generator( trainer._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"texts": eval_texts, "tokenizer": tokenizer, "max_length": trainer._max_length}, - cache_dir = trainer._model_service._config.TRAINING_CACHE_DIR + gen_kwargs={ + "texts": eval_texts, + "tokenizer": tokenizer, + "max_length": trainer._max_length, + }, + cache_dir=trainer._model_service._config.TRAINING_CACHE_DIR, ) train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.2) + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=True, mlm_probability=0.2 + ) training_args = TrainingArguments( output_dir=results_path, @@ -139,7 +167,9 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", save_steps=1000, load_best_model_at_end=True, save_total_limit=3, - use_cpu=trainer._config.DEVICE.lower() == Device.CPU.value if non_default_device_is_available(trainer._config.DEVICE) else False, + use_cpu=trainer._config.DEVICE.lower() == Device.CPU.value + if non_default_device_is_available(trainer._config.DEVICE) + else False, ) if training_params.get("lr_override") is not None: @@ -151,7 +181,7 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, - callbacks=[MLflowLoggingCallback(trainer._tracker_client)] + callbacks=[MLflowLoggingCallback(trainer._tracker_client)], ) trainer._tracker_client.log_model_config(model.config.to_dict()) @@ -161,10 +191,21 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", model = trainer._get_final_model(model, mlm_model) if not skip_save_model: - retrained_model_pack_path = os.path.join(trainer._retrained_models_dir, f"{ModelType.HUGGINGFACE_NER.value}_{run_id}.zip") - model.save_pretrained(copied_model_directory, safe_serialization=(trainer._config.TRAINING_SAFE_MODEL_SERIALISATION == "true")) - shutil.make_archive(retrained_model_pack_path.replace(".zip", ""), "zip", copied_model_directory) - model_uri = trainer._tracker_client.save_model(retrained_model_pack_path, trainer._model_name, trainer._model_manager) + retrained_model_pack_path = os.path.join( + trainer._retrained_models_dir, f"{ModelType.HUGGINGFACE_NER.value}_{run_id}.zip" + ) + model.save_pretrained( + copied_model_directory, + safe_serialization=( + trainer._config.TRAINING_SAFE_MODEL_SERIALISATION == 
"true" + ), + ) + shutil.make_archive( + retrained_model_pack_path.replace(".zip", ""), "zip", copied_model_directory + ) + model_uri = trainer._tracker_client.save_model( + retrained_model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info(f"Retrained model saved: {model_uri}") else: logger.info("Skipped saving on the retrained model") @@ -197,9 +238,11 @@ def run(trainer: "HuggingFaceNerUnsupervisedTrainer", trainer._housekeep_file(copied_model_pack_path) @staticmethod - def deploy_model(model_service: AbstractModelService, - model: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase) -> None: + def deploy_model( + model_service: AbstractModelService, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + ) -> None: del model_service.model del model_service.tokenizer gc.collect() @@ -233,18 +276,23 @@ def _get_final_model(model: PreTrainedModel, mlm_model: PreTrainedModel) -> PreT return model @staticmethod - def _tokenize_and_chunk(texts: Iterable[str], tokenizer: PreTrainedTokenizerBase, max_length: int) -> Iterable[Dict[str, Any]]: + def _tokenize_and_chunk( + texts: Iterable[str], tokenizer: PreTrainedTokenizerBase, max_length: int + ) -> Iterable[Dict[str, Any]]: for text in texts: encoded = tokenizer(text, truncation=False, return_special_tokens_mask=True) for i in range(0, len(encoded["input_ids"]), max_length): - chunked_input_ids = encoded["input_ids"][i:i + max_length] + chunked_input_ids = encoded["input_ids"][i : i + max_length] padding_length = max(0, max_length - len(chunked_input_ids)) chunked_input_ids += [tokenizer.pad_token_id] * padding_length - chunked_attention_mask = encoded["attention_mask"][i:i + max_length] + [0] * padding_length - chunked_special_tokens = tokenizer.get_special_tokens_mask(chunked_input_ids, - already_has_special_tokens=True) + chunked_attention_mask = ( + encoded["attention_mask"][i : i + max_length] + [0] * padding_length + ) + chunked_special_tokens = tokenizer.get_special_tokens_mask( + chunked_input_ids, already_has_special_tokens=True + ) token_type_ids = [0] * len(chunked_input_ids) yield { @@ -257,7 +305,6 @@ def _tokenize_and_chunk(texts: Iterable[str], tokenizer: PreTrainedTokenizerBase @final class HuggingFaceNerSupervisedTrainer(SupervisedTrainer): - MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT = 5 MAX_CONCEPTS_TO_TRACK = 20 PAD_LABEL_ID = -100 @@ -271,23 +318,42 @@ def __init__(self, model_service: AbstractModelService) -> None: self._model_service = model_service self._model_name = model_service.model_name self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", - self._model_name.replace(" ", "_")) + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_") + ) self._model_manager = ModelManager(type(model_service), model_service._config) self._max_length = model_service.model.config.max_position_embeddings os.makedirs(self._retrained_models_dir, exist_ok=True) class _LocalDataCollator: - def __init__(self, max_length: int, pad_token_id: int) -> None: self.max_length = max_length self.pad_token_id = pad_token_id def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: return { - "input_ids": torch.tensor([self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) for f in features], dtype=torch.long), - "labels": torch.tensor([self._add_padding(f["labels"], self.max_length, HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID) for f in 
features], dtype=torch.long), - "attention_mask": torch.tensor([self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], dtype=torch.long), + "input_ids": torch.tensor( + [ + self._add_padding(f["input_ids"], self.max_length, self.pad_token_id) + for f in features + ], + dtype=torch.long, + ), + "labels": torch.tensor( + [ + self._add_padding( + f["labels"], + self.max_length, + HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID, + ) + for f in features + ], + dtype=torch.long, + ), + "attention_mask": torch.tensor( + [self._add_padding(f["attention_mask"], self.max_length, 0) for f in features], + dtype=torch.long, + ), } @staticmethod @@ -297,12 +363,14 @@ def _add_padding(target: List[int], max_length: int, pad_token_id: int) -> List[ return target + paddings @staticmethod - def run(trainer: "HuggingFaceNerSupervisedTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "HuggingFaceNerSupervisedTrainer", + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: copied_model_pack_path = None redeploy = trainer._config.REDEPLOY_TRAINED_MODEL == "true" skip_save_model = trainer._config.SKIP_SAVE_MODEL == "true" @@ -314,54 +382,99 @@ def run(trainer: "HuggingFaceNerSupervisedTrainer", if not eval_mode: try: logger.info("Loading a new model copy for training...") - copied_model_pack_path = trainer._make_model_file_copy(trainer._model_pack_path, run_id) + copied_model_pack_path = trainer._make_model_file_copy( + trainer._model_pack_path, run_id + ) model, tokenizer = trainer._model_service.load_model(copied_model_pack_path) copied_model_directory = copied_model_pack_path.replace(".zip", "") if non_default_device_is_available(trainer._config.DEVICE): model.to(trainer._config.DEVICE) - filtered_training_data, filtered_concepts = trainer._filter_training_data_and_concepts(data_file) + filtered_training_data, filtered_concepts = ( + trainer._filter_training_data_and_concepts(data_file) + ) logger.debug(f"Filtered concepts: {filtered_concepts}") model = trainer._update_model_with_concepts(model, filtered_concepts) - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] + test_size = ( + 0.2 + if training_params.get("test_size") is None + else training_params["test_size"] + ) if test_size < 0: - logger.info("Using pre-defined train-validation-test split in trainer export...") + logger.info( + "Using pre-defined train-validation-test split in trainer export..." 
+ ) if len(filtered_training_data["projects"]) < 2: - raise AnnotationException("Not enough projects in the training data to provide a train-validation-test split") + raise AnnotationException( + "Not enough projects in the training data to provide a" + " train-validation-test split" + ) train_documents = filtered_training_data["projects"][0]["documents"] random.shuffle(train_documents) eval_documents = filtered_training_data["projects"][1]["documents"] else: - documents = [document for project in filtered_training_data["projects"] for document in project["documents"]] + documents = [ + document + for project in filtered_training_data["projects"] + for document in project["documents"] + ] random.shuffle(documents) - test_size = 0.2 if training_params.get("test_size") is None else training_params["test_size"] - train_documents = [document for document in documents[:int(len(documents) * (1 - test_size))]] - eval_documents = [document for document in documents[int(len(documents) * (1 - test_size)):]] - - dataset_features = datasets.Features({ - "input_ids": datasets.Sequence(datasets.Value("int32")), - "labels": datasets.Sequence(datasets.Value("int32")), - "attention_mask": datasets.Sequence(datasets.Value("int32")), - }) + test_size = ( + 0.2 + if training_params.get("test_size") is None + else training_params["test_size"] + ) + train_documents = [ + document for document in documents[: int(len(documents) * (1 - test_size))] + ] + eval_documents = [ + document for document in documents[int(len(documents) * (1 - test_size)) :] + ] + + dataset_features = datasets.Features( + { + "input_ids": datasets.Sequence(datasets.Value("int32")), + "labels": datasets.Sequence(datasets.Value("int32")), + "attention_mask": datasets.Sequence(datasets.Value("int32")), + } + ) train_dataset = datasets.Dataset.from_generator( trainer._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": train_documents, "tokenizer": tokenizer, "max_length": trainer._max_length, "model": model}, - cache_dir=trainer._config.TRAINING_CACHE_DIR + gen_kwargs={ + "documents": train_documents, + "tokenizer": tokenizer, + "max_length": trainer._max_length, + "model": model, + }, + cache_dir=trainer._config.TRAINING_CACHE_DIR, ) eval_dataset = datasets.Dataset.from_generator( trainer._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": eval_documents, "tokenizer": tokenizer, "max_length": trainer._max_length, "model": model}, - cache_dir = trainer._config.TRAINING_CACHE_DIR + gen_kwargs={ + "documents": eval_documents, + "tokenizer": tokenizer, + "max_length": trainer._max_length, + "model": model, + }, + cache_dir=trainer._config.TRAINING_CACHE_DIR, + ) + train_dataset.set_format( + type=None, columns=["input_ids", "labels", "attention_mask"] + ) + eval_dataset.set_format( + type=None, columns=["input_ids", "labels", "attention_mask"] ) - train_dataset.set_format(type=None, columns=["input_ids", "labels", "attention_mask"]) - eval_dataset.set_format(type=None, columns=["input_ids", "labels", "attention_mask"]) - data_collator = trainer._LocalDataCollator(max_length=trainer._max_length, pad_token_id=tokenizer.pad_token_id) - training_args = trainer._get_training_args(results_path, logs_path, training_params, log_frequency) + data_collator = trainer._LocalDataCollator( + max_length=trainer._max_length, pad_token_id=tokenizer.pad_token_id + ) + training_args = trainer._get_training_args( + results_path, logs_path, training_params, log_frequency + ) if training_params.get("lr_override") is not 
None: training_args.learning_rate = training_params["lr_override"] @@ -371,8 +484,13 @@ def run(trainer: "HuggingFaceNerSupervisedTrainer", data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, - compute_metrics=partial(trainer._compute_token_level_metrics, id2label=model.config.id2label, tracker_client=trainer._tracker_client, model_name=trainer._model_name), - callbacks=[MLflowLoggingCallback(trainer._tracker_client)] + compute_metrics=partial( + trainer._compute_token_level_metrics, + id2label=model.config.id2label, + tracker_client=trainer._tracker_client, + model_name=trainer._model_name, + ), + callbacks=[MLflowLoggingCallback(trainer._tracker_client)], ) trainer._tracker_client.log_model_config(model.config.to_dict()) @@ -381,19 +499,30 @@ def run(trainer: "HuggingFaceNerSupervisedTrainer", logger.info("Performing supervised training...") hf_trainer.train() - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) - trainer._save_trained_concepts(cui_counts, cui_unique_counts, cui_ignorance_counts, model) + trainer._save_trained_concepts( + cui_counts, cui_unique_counts, cui_ignorance_counts, model + ) trainer._tracker_client.log_classes_and_names(model.config.id2label) - trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service.from_model(model, tokenizer)) + trainer._sanity_check_model_and_save_results( + data_file.name, trainer._model_service.from_model(model, tokenizer) + ) if not skip_save_model: - retrained_model_pack_path = os.path.join(trainer._retrained_models_dir, - f"{ModelType.HUGGINGFACE_NER.value}_{run_id}.zip") + retrained_model_pack_path = os.path.join( + trainer._retrained_models_dir, + f"{ModelType.HUGGINGFACE_NER.value}_{run_id}.zip", + ) model.save_pretrained(copied_model_directory) - shutil.make_archive(retrained_model_pack_path.replace(".zip", ""), "zip", copied_model_directory) - model_uri = trainer._tracker_client.save_model(retrained_model_pack_path, trainer._model_name, - trainer._model_manager) + shutil.make_archive( + retrained_model_pack_path.replace(".zip", ""), "zip", copied_model_directory + ) + model_uri = trainer._tracker_client.save_model( + retrained_model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info(f"Retrained model saved: {model_uri}") else: logger.info("Skipped saving on the retrained model") @@ -419,38 +548,65 @@ def run(trainer: "HuggingFaceNerSupervisedTrainer", else: try: logger.info("Evaluating the running model...") - trainer._tracker_client.log_model_config(trainer._model_service._model.config.to_dict()) + trainer._tracker_client.log_model_config( + trainer._model_service._model.config.to_dict() + ) trainer._tracker_client.log_trainer_version(transformers_version) with open(data_file.name, "r") as f: eval_data = json.load(f) - eval_documents = [document for project in eval_data["projects"] for document in project["documents"]] - dataset_features = datasets.Features({ - "input_ids": datasets.Sequence(datasets.Value("int32")), - "labels": datasets.Sequence(datasets.Value("int32")), - "attention_mask": datasets.Sequence(datasets.Value("int32")), - }) + eval_documents = [ + document + for project in eval_data["projects"] + for document in project["documents"] + ] + dataset_features = datasets.Features( + { + "input_ids": 
datasets.Sequence(datasets.Value("int32")), + "labels": datasets.Sequence(datasets.Value("int32")), + "attention_mask": datasets.Sequence(datasets.Value("int32")), + } + ) eval_dataset = datasets.Dataset.from_generator( trainer._tokenize_and_chunk, features=dataset_features, - gen_kwargs={"documents": eval_documents, "tokenizer": trainer._model_service.tokenizer, "max_length": trainer._max_length, "model": trainer._model_service._model}, - cache_dir=trainer._config.TRAINING_CACHE_DIR + gen_kwargs={ + "documents": eval_documents, + "tokenizer": trainer._model_service.tokenizer, + "max_length": trainer._max_length, + "model": trainer._model_service._model, + }, + cache_dir=trainer._config.TRAINING_CACHE_DIR, + ) + eval_dataset.set_format( + type=None, columns=["input_ids", "labels", "attention_mask"] + ) + data_collator = trainer._LocalDataCollator( + max_length=trainer._max_length, + pad_token_id=trainer._model_service.tokenizer.pad_token_id, + ) + training_args = trainer._get_training_args( + results_path, logs_path, training_params, log_frequency ) - eval_dataset.set_format(type=None, columns=["input_ids", "labels", "attention_mask"]) - data_collator = trainer._LocalDataCollator(max_length=trainer._max_length, pad_token_id=trainer._model_service.tokenizer.pad_token_id) - training_args = trainer._get_training_args(results_path, logs_path, training_params, log_frequency) hf_trainer = Trainer( model=trainer._model_service.model, args=training_args, data_collator=data_collator, train_dataset=None, eval_dataset=None, - compute_metrics=partial(trainer._compute_token_level_metrics, id2label=trainer._model_service.model.config.id2label, tracker_client=trainer._tracker_client, model_name=trainer._model_name), + compute_metrics=partial( + trainer._compute_token_level_metrics, + id2label=trainer._model_service.model.config.id2label, + tracker_client=trainer._tracker_client, + model_name=trainer._model_name, + ), tokenizer=None, ) eval_metrics = hf_trainer.evaluate(eval_dataset) logger.debug("Evaluation metrics: %s", eval_metrics) trainer._tracker_client.send_hf_metrics_logs(eval_metrics, 0) - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service) trainer._tracker_client.end_with_success() @@ -469,47 +625,76 @@ def _filter_training_data_and_concepts(data_file: TextIO) -> Tuple[Dict, List]: with open(data_file.name, "r") as f: training_data = json.load(f) te_stats_df = get_stats_from_trainer_export(training_data, return_df=True) - rear_concept_ids = te_stats_df[(te_stats_df["anno_count"] - te_stats_df["anno_ignorance_counts"]) < HuggingFaceNerSupervisedTrainer.MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT]["concept"].unique() - logger.debug(f"The following concept(s) will be excluded due to the low example count(s): {rear_concept_ids}") - filtered_training_data = filter_by_concept_ids(training_data, ModelType.HUGGINGFACE_NER, extra_excluded=rear_concept_ids) - filtered_concepts = get_stats_from_trainer_export(filtered_training_data, return_df=True)["concept"].unique() + rear_concept_ids = te_stats_df[ + (te_stats_df["anno_count"] - te_stats_df["anno_ignorance_counts"]) + < HuggingFaceNerSupervisedTrainer.MIN_EXAMPLE_COUNT_FOR_TRAINABLE_CONCEPT + ]["concept"].unique() + logger.debug( + "The 
following concept(s) will be excluded due to the low example count(s): %s", + rear_concept_ids, + ) + filtered_training_data = filter_by_concept_ids( + training_data, ModelType.HUGGINGFACE_NER, extra_excluded=rear_concept_ids + ) + filtered_concepts = get_stats_from_trainer_export( + filtered_training_data, return_df=True + )["concept"].unique() return filtered_training_data, filtered_concepts @staticmethod def _update_model_with_concepts(model: PreTrainedModel, concepts: List[str]) -> PreTrainedModel: if model.config.label2id == {"LABEL_0": 0, "LABEL_1": 1}: logger.debug("Cannot find existing labels and IDs, creating new ones...") - model.config.label2id = {"O": HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID, "X": HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID} - model.config.id2label = {HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID: "O", HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID: "X"} + model.config.label2id = { + "O": HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID, + "X": HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID, + } + model.config.id2label = { + HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID: "O", + HuggingFaceNerSupervisedTrainer.CONTINUING_TOKEN_LABEL_ID: "X", + } avg_weight = torch.mean(model.classifier.weight, dim=0, keepdim=True) avg_bias = torch.mean(model.classifier.bias, dim=0, keepdim=True) for concept in concepts: if concept not in model.config.label2id.keys(): model.config.label2id[concept] = len(model.config.label2id) model.config.id2label[len(model.config.id2label)] = concept - model.classifier.weight = nn.Parameter(torch.cat((model.classifier.weight, avg_weight), 0)) - model.classifier.bias = nn.Parameter(torch.cat((model.classifier.bias, avg_bias), 0)) + model.classifier.weight = nn.Parameter( + torch.cat((model.classifier.weight, avg_weight), 0) + ) + model.classifier.bias = nn.Parameter( + torch.cat((model.classifier.bias, avg_bias), 0) + ) model.classifier.out_features += 1 model.num_labels += 1 return model @staticmethod - def _tokenize_and_chunk(documents: List[Dict], tokenizer: PreTrainedTokenizerBase, max_length: int, model: PreTrainedModel) -> Iterable[Dict[str, Any]]: + def _tokenize_and_chunk( + documents: List[Dict], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + model: PreTrainedModel, + ) -> Iterable[Dict[str, Any]]: for document in documents: encoded = tokenizer(document["text"], truncation=False, return_offsets_mapping=True) - document["annotations"] = sorted(document["annotations"], key=lambda annotation: annotation["start"]) + document["annotations"] = sorted( + document["annotations"], key=lambda annotation: annotation["start"] + ) for i in range(0, len(encoded["input_ids"]), max_length): - chunked_input_ids = encoded["input_ids"][i:i + max_length] - chunked_offsets_mapping = encoded["offset_mapping"][i:i + max_length] + chunked_input_ids = encoded["input_ids"][i : i + max_length] + chunked_offsets_mapping = encoded["offset_mapping"][i : i + max_length] chunked_labels = [0] * len(chunked_input_ids) for annotation in document["annotations"]: start = annotation["start"] end = annotation["end"] - label_id = model.config.label2id.get(annotation["cui"], HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID) + label_id = model.config.label2id.get( + annotation["cui"], HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID + ) for idx, offset_mapping in enumerate(chunked_offsets_mapping): if start <= offset_mapping[0] and offset_mapping[1] <= end: chunked_labels[idx] = label_id - chunked_attention_mask = 
encoded["attention_mask"][i:i + max_length] + chunked_attention_mask = encoded["attention_mask"][i : i + max_length] yield { "input_ids": chunked_input_ids, @@ -518,15 +703,31 @@ def _tokenize_and_chunk(documents: List[Dict], tokenizer: PreTrainedTokenizerBas } @staticmethod - def _compute_token_level_metrics(eval_pred: EvalPrediction, id2label: Dict[int, str], tracker_client: TrackerClient, model_name: str) -> Dict[str, Any]: + def _compute_token_level_metrics( + eval_pred: EvalPrediction, + id2label: Dict[int, str], + tracker_client: TrackerClient, + model_name: str, + ) -> Dict[str, Any]: predictions = np.argmax(softmax(eval_pred.predictions, axis=2), axis=2) label_ids = eval_pred.label_ids non_padding_indices = np.where(label_ids != HuggingFaceNerSupervisedTrainer.PAD_LABEL_ID) non_padding_predictions = predictions[non_padding_indices].flatten() non_padding_label_ids = label_ids[non_padding_indices].flatten() labels = list(id2label.values()) - precision, recall, f1, support = precision_recall_fscore_support(non_padding_label_ids, non_padding_predictions, labels=list(id2label.keys()), average=None) - filtered_predictions, filtered_label_ids = zip(*[(a, b) for a, b in zip(non_padding_predictions, non_padding_label_ids) if not (a == b == HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID)]) + precision, recall, f1, support = precision_recall_fscore_support( + non_padding_label_ids, + non_padding_predictions, + labels=list(id2label.keys()), + average=None, + ) + filtered_predictions, filtered_label_ids = zip( + *[ + (a, b) + for a, b in zip(non_padding_predictions, non_padding_label_ids) + if not (a == b == HuggingFaceNerSupervisedTrainer.DEFAULT_LABEL_ID) + ] + ) accuracy = accuracy_score(filtered_label_ids, filtered_predictions) metrics = { "accuracy": accuracy, @@ -539,11 +740,15 @@ def _compute_token_level_metrics(eval_pred: EvalPrediction, id2label: Dict[int, aggregated_metrics = [] # limit the number of labels to avoid excessive metrics logging - for idx, (label, precision, recall, f1, support) in enumerate(zip(labels[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - precision[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - recall[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - f1[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2], - support[2:HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK+2])): + for idx, (label, precision, recall, f1, support) in enumerate( + zip( + labels[2 : HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK + 2], + precision[2 : HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK + 2], + recall[2 : HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK + 2], + f1[2 : HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK + 2], + support[2 : HuggingFaceNerSupervisedTrainer.MAX_CONCEPTS_TO_TRACK + 2], + ) + ): if support == 0: # the concept has no true labels continue metrics[f"{label}/precision"] = precision if precision is not None else 0.0 @@ -552,21 +757,24 @@ def _compute_token_level_metrics(eval_pred: EvalPrediction, id2label: Dict[int, metrics[f"{label}/support"] = support if support is not None else 0.0 aggregated_labels.append(label) - aggregated_metrics.append({ - "per_concept_p": metrics[f"{label}/precision"], - "per_concept_r": metrics[f"{label}/recall"], - "per_concept_f1": metrics[f"{label}/f1"], - }) + aggregated_metrics.append( + { + "per_concept_p": metrics[f"{label}/precision"], + "per_concept_r": metrics[f"{label}/recall"], + "per_concept_f1": metrics[f"{label}/f1"], + } + ) - 
HuggingFaceNerSupervisedTrainer._save_metrics_plot(aggregated_metrics, aggregated_labels, tracker_client, model_name) + HuggingFaceNerSupervisedTrainer._save_metrics_plot( + aggregated_metrics, aggregated_labels, tracker_client, model_name + ) logger.debug("Evaluation metrics: %s", metrics) return metrics @staticmethod - def _save_metrics_plot(metrics: List[Dict], - concepts: List[str], - tracker_client: TrackerClient, - model_name: str) -> None: + def _save_metrics_plot( + metrics: List[Dict], concepts: List[str], tracker_client: TrackerClient, model_name: str + ) -> None: try: plot = radar_plot(data=metrics, model_names=concepts) with tempfile.TemporaryDirectory() as d: @@ -578,7 +786,9 @@ def _save_metrics_plot(metrics: List[Dict], logger.error("Error occurred while plotting the metrics") logger.exception(e) - def _get_training_args(self, results_path: str, logs_path: str, training_params: Dict, log_frequency: int) -> TrainingArguments: + def _get_training_args( + self, results_path: str, logs_path: str, training_params: Dict, log_frequency: int + ) -> TrainingArguments: return TrainingArguments( output_dir=results_path, logging_dir=logs_path, @@ -599,25 +809,36 @@ def _get_training_args(self, results_path: str, logs_path: str, training_params: metric_for_best_model="eval_f1_avg", load_best_model_at_end=True, save_total_limit=3, - use_cpu=self._config.DEVICE.lower() == Device.CPU.value if non_default_device_is_available(self._config.DEVICE) else False, + use_cpu=self._config.DEVICE.lower() == Device.CPU.value + if non_default_device_is_available(self._config.DEVICE) + else False, ) - def _save_trained_concepts(self, - training_concepts: Dict, - training_unique_concepts: Dict, - training_ignorance_counts: Dict, - model: PreTrainedModel) -> None: + def _save_trained_concepts( + self, + training_concepts: Dict, + training_unique_concepts: Dict, + training_ignorance_counts: Dict, + model: PreTrainedModel, + ) -> None: if len(training_concepts.keys()) != 0: unknown_concepts = set(training_concepts.keys()) - set(model.config.label2id.keys()) - unknown_concept_pct = round(len(unknown_concepts) / len(training_concepts.keys()) * 100, 2) - self._tracker_client.send_model_stats({ - "unknown_concept_count": len(unknown_concepts), - "unknown_concept_pct": unknown_concept_pct, - }, 0) + unknown_concept_pct = round( + len(unknown_concepts) / len(training_concepts.keys()) * 100, 2 + ) + self._tracker_client.send_model_stats( + { + "unknown_concept_count": len(unknown_concepts), + "unknown_concept_pct": unknown_concept_pct, + }, + 0, + ) if unknown_concepts: - self._tracker_client.save_dataframe_as_csv("unknown_concepts.csv", - pd.DataFrame({"concept": list(unknown_concepts)}), - self._model_name) + self._tracker_client.save_dataframe_as_csv( + "unknown_concepts.csv", + pd.DataFrame({"concept": list(unknown_concepts)}), + self._model_name, + ) annotation_count = [] annotation_unique_count = [] annotation_ignorance_count = [] @@ -626,22 +847,29 @@ def _save_trained_concepts(self, annotation_count.append(training_concepts[c]) annotation_unique_count.append(training_unique_concepts[c]) annotation_ignorance_count.append(training_ignorance_counts[c]) - self._tracker_client.save_dataframe_as_csv("trained_concepts.csv", - pd.DataFrame({ - "concept": concepts, - "anno_count": annotation_count, - "anno_unique_count": annotation_unique_count, - "anno_ignorance_count": annotation_ignorance_count, - }), - self._model_name) - - def _sanity_check_model_and_save_results(self, data_file_path: str, model_service: 
AbstractModelService) -> None: - self._tracker_client.save_dataframe_as_csv("sanity_check_result.csv", - sanity_check_model_with_trainer_export(data_file_path, - model_service, - return_df=True, - include_anchors=True), - self._model_name) + self._tracker_client.save_dataframe_as_csv( + "trained_concepts.csv", + pd.DataFrame( + { + "concept": concepts, + "anno_count": annotation_count, + "anno_unique_count": annotation_unique_count, + "anno_ignorance_count": annotation_ignorance_count, + } + ), + self._model_name, + ) + + def _sanity_check_model_and_save_results( + self, data_file_path: str, model_service: AbstractModelService + ) -> None: + self._tracker_client.save_dataframe_as_csv( + "sanity_check_result.csv", + sanity_check_model_with_trainer_export( + data_file_path, model_service, return_df=True, include_anchors=True + ), + self._model_name, + ) @final @@ -650,12 +878,14 @@ def __init__(self, tracker_client: TrackerClient) -> None: self.tracker_client = tracker_client self.epoch = 0 - def on_log(self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - logs: Dict[str, float], - **kwargs: Dict[str, Any]) -> None: + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Dict[str, float], + **kwargs: Dict[str, Any], + ) -> None: if logs is not None: self.tracker_client.send_hf_metrics_logs(logs, self.epoch) self.epoch += 1 diff --git a/app/trainers/medcat_deid_trainer.py b/app/trainers/medcat_deid_trainer.py index 0333f7f..7c8a925 100644 --- a/app/trainers/medcat_deid_trainer.py +++ b/app/trainers/medcat_deid_trainer.py @@ -1,41 +1,52 @@ -import os +import gc +import inspect import logging +import os import shutil -import gc -import mlflow import tempfile -import inspect -import torch -import pandas as pd -import numpy as np from collections import defaultdict from functools import partial -from typing import Dict, TextIO, Any, Optional, List, final +from typing import Any, Dict, List, Optional, TextIO, final + +import mlflow +import numpy as np +import pandas as pd +import torch from evaluate.visualization import radar_plot -from transformers import pipeline from medcat import __version__ as medcat_version from medcat.ner.transformers_ner import TransformersNER -from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, PreTrainedModel, Trainer -from utils import get_settings, non_default_device_is_available, get_hf_pipeline_device_id +from transformers import ( + PreTrainedModel, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, + pipeline, +) + +from utils import get_hf_pipeline_device_id, get_settings, non_default_device_is_available + from management import tracker_client -from trainers.medcat_trainer import MedcatSupervisedTrainer from processors.metrics_collector import get_stats_from_trainer_export +from trainers.medcat_trainer import MedcatSupervisedTrainer logger = logging.getLogger("cms") class MetricsCallback(TrainerCallback): - def __init__(self, trainer: Trainer) -> None: self._trainer = trainer self._step = 0 self._interval = get_settings().TRAINING_METRICS_LOGGING_INTERVAL - def on_step_end(self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs: Dict[str, Any]) -> None: + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs: Dict[str, Any], + ) -> None: if self._step == 0: self._step += 1 return @@ -46,18 +57,19 @@ def 
on_step_end(self, class LabelCountCallback(TrainerCallback): - def __init__(self, trainer: Trainer) -> None: self._trainer = trainer self._label_counts: Dict = defaultdict(int) self._interval = get_settings().TRAINING_METRICS_LOGGING_INTERVAL - def on_step_end(self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - model: Optional[PreTrainedModel] = None, - **kwargs: Dict[str, Any]) -> None: + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + model: Optional[PreTrainedModel] = None, + **kwargs: Dict[str, Any], + ) -> None: step = state.global_step train_dataset = self._trainer.train_dataset batch_ids = train_dataset[step]["labels"] @@ -72,14 +84,15 @@ def on_step_end(self, @final class MedcatDeIdentificationSupervisedTrainer(MedcatSupervisedTrainer): - @staticmethod - def run(trainer: "MedcatDeIdentificationSupervisedTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "MedcatDeIdentificationSupervisedTrainer", + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: model_pack_path = None cdb_config_path = None copied_model_pack_path = None @@ -90,17 +103,38 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", if not eval_mode: try: logger.info("Loading a new model copy for training...") - copied_model_pack_path = trainer._make_model_file_copy(trainer._model_pack_path, run_id) + copied_model_pack_path = trainer._make_model_file_copy( + trainer._model_pack_path, run_id + ) model = trainer._model_service.load_model(copied_model_pack_path) ner = model._addl_ner[0] - ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr(ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False) - ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None) - ner.tokenizer.hf_tokenizer.split_special_tokens = getattr(ner.tokenizer.hf_tokenizer, "split_special_tokens", False) + ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr( + ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False + ) + ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr( + ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None + ) + ner.tokenizer.hf_tokenizer.split_special_tokens = getattr( + ner.tokenizer.hf_tokenizer, "split_special_tokens", False + ) _save_pretrained = ner.model.save_pretrained - if ("safe_serialization" in inspect.signature(_save_pretrained).parameters): - ner.model.save_pretrained = partial(_save_pretrained, safe_serialization=(trainer._config.TRAINING_SAFE_MODEL_SERIALISATION == "true")) - ner_config = {f"transformers.cat_config.{arg}": str(val) for arg, val in ner.config.general.dict().items()} - ner_config.update({f"transformers.training.{arg}": str(val) for arg, val in ner.training_arguments.to_dict().items()}) + if "safe_serialization" in inspect.signature(_save_pretrained).parameters: + ner.model.save_pretrained = partial( + _save_pretrained, + safe_serialization=( + trainer._config.TRAINING_SAFE_MODEL_SERIALISATION == "true" + ), + ) + ner_config = { + f"transformers.cat_config.{arg}": str(val) + for arg, val in ner.config.general.dict().items() + } + ner_config.update( + { + f"transformers.training.{arg}": str(val) + for arg, val in ner.training_arguments.to_dict().items() + } + ) for key, val in 
ner_config.items(): ner_config[key] = "" if val == "" else val trainer._tracker_client.log_model_config(ner_config) @@ -129,21 +163,34 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", dataset["train"] = dataset["train"].shuffle() dataset["test"] = dataset["test"].shuffle() - ner = MedcatDeIdentificationSupervisedTrainer._customise_training_device(ner, trainer._config.DEVICE) - eval_results, examples, dataset = ner.train(data_file.name, - ignore_extra_labels=True, - dataset=dataset, - # trainer_callbacks=[MetricsCallback, LabelCountCallback] - ) + ner = MedcatDeIdentificationSupervisedTrainer._customise_training_device( + ner, trainer._config.DEVICE + ) + eval_results, examples, dataset = ner.train( + data_file.name, + ignore_extra_labels=True, + dataset=dataset, + # trainer_callbacks=[MetricsCallback, LabelCountCallback] + ) if (training + 1) % log_frequency == 0: for _, row in eval_results.iterrows(): normalised_name = row["name"].replace(" ", "_").lower() grouped_metrics = { - f"{normalised_name}/precision": row["p"] if row["p"] is not None else np.nan, - f"{normalised_name}/recall": row["r"] if row["r"] is not None else np.nan, - f"{normalised_name}/f1": row["f1"] if row["f1"] is not None else np.nan, - f"{normalised_name}/p_merged": row["p_merged"] if row["p_merged"] is not None else np.nan, - f"{normalised_name}/r_merged": row["r_merged"] if row["r_merged"] is not None else np.nan, + f"{normalised_name}/precision": row["p"] + if row["p"] is not None + else np.nan, + f"{normalised_name}/recall": row["r"] + if row["r"] is not None + else np.nan, + f"{normalised_name}/f1": row["f1"] + if row["f1"] is not None + else np.nan, + f"{normalised_name}/p_merged": row["p_merged"] + if row["p_merged"] is not None + else np.nan, + f"{normalised_name}/r_merged": row["r_merged"] + if row["r_merged"] is not None + else np.nan, } trainer._tracker_client.send_model_stats(grouped_metrics, training) @@ -162,34 +209,56 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", for _, row in eval_results.iterrows(): if row["support"] == 0: # the concept has not been used for annotation continue - aggregated_metrics.append({ - "per_concept_p": row["p"] if row["p"] is not None else 0.0, - "per_concept_r": row["r"] if row["r"] is not None else 0.0, - "per_concept_f1": row["f1"] if row["f1"] is not None else 0.0, - "per_concept_support": row["support"] if row["support"] is not None else 0.0, - "per_concept_p_merged": row["p_merged"] if row["p_merged"] is not None else 0.0, - "per_concept_r_merged": row["r_merged"] if row["r_merged"] is not None else 0.0, - }) + aggregated_metrics.append( + { + "per_concept_p": row["p"] if row["p"] is not None else 0.0, + "per_concept_r": row["r"] if row["r"] is not None else 0.0, + "per_concept_f1": row["f1"] if row["f1"] is not None else 0.0, + "per_concept_support": row["support"] + if row["support"] is not None + else 0.0, + "per_concept_p_merged": row["p_merged"] + if row["p_merged"] is not None + else 0.0, + "per_concept_r_merged": row["r_merged"] + if row["r_merged"] is not None + else 0.0, + } + ) cui2names[row["cui"]] = model.cdb.get_name(row["cui"]) - MedcatDeIdentificationSupervisedTrainer._save_metrics_plot(aggregated_metrics, - list(cui2names.values()), - trainer._tracker_client, - trainer._model_name) + MedcatDeIdentificationSupervisedTrainer._save_metrics_plot( + aggregated_metrics, + list(cui2names.values()), + trainer._tracker_client, + trainer._model_name, + ) trainer._tracker_client.send_batched_model_stats(aggregated_metrics, run_id) 
trainer._save_examples(examples, ["tp", "tn"]) trainer._tracker_client.log_classes_and_names(cui2names) - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) - trainer._save_trained_concepts(cui_counts, cui_unique_counts, cui_ignorance_counts, model) - trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service.from_model(model)) + trainer._save_trained_concepts( + cui_counts, cui_unique_counts, cui_ignorance_counts, model + ) + trainer._sanity_check_model_and_save_results( + data_file.name, trainer._model_service.from_model(model) + ) if not skip_save_model: - model_pack_path = trainer.save_model_pack(model, trainer._retrained_models_dir, description) + model_pack_path = trainer.save_model_pack( + model, trainer._retrained_models_dir, description + ) cdb_config_path = model_pack_path.replace(".zip", "_config.json") model.cdb.config.save(cdb_config_path) - model_uri = trainer._tracker_client.save_model(model_pack_path, trainer._model_name, trainer._model_manager) + model_uri = trainer._tracker_client.save_model( + model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info("Retrained model saved: %s", model_uri) - trainer._tracker_client.save_model_artifact(cdb_config_path, trainer._model_name) + trainer._tracker_client.save_model_artifact( + cdb_config_path, trainer._model_name + ) else: logger.info("Skipped saving on the retrained model") if redeploy: @@ -202,7 +271,9 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", trainer._tracker_client.end_with_success() # Remove intermediate results folder on successful training - results_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "results")) + results_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "results") + ) if results_path and os.path.isdir(results_path): shutil.rmtree(results_path) except Exception as e: @@ -220,12 +291,20 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", else: try: logger.info("Evaluating the running model...") - trainer._tracker_client.log_model_config(trainer.get_flattened_config(trainer._model_service._model)) + trainer._tracker_client.log_model_config( + trainer.get_flattened_config(trainer._model_service._model) + ) trainer._tracker_client.log_trainer_version(medcat_version) ner = trainer._model_service._model._addl_ner[0] - ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr(ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False) - ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr(ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None) - ner.tokenizer.hf_tokenizer.split_special_tokens = getattr(ner.tokenizer.hf_tokenizer, "split_special_tokens", False) + ner.tokenizer.hf_tokenizer._in_target_context_manager = getattr( + ner.tokenizer.hf_tokenizer, "_in_target_context_manager", False + ) + ner.tokenizer.hf_tokenizer.clean_up_tokenization_spaces = getattr( + ner.tokenizer.hf_tokenizer, "clean_up_tokenization_spaces", None + ) + ner.tokenizer.hf_tokenizer.split_special_tokens = getattr( + ner.tokenizer.hf_tokenizer, "split_special_tokens", False + ) eval_results, examples = ner.eval(data_file.name) cui2names = {} eval_results.sort_values(by=["cui"]) @@ -233,19 +312,29 @@ def run(trainer: 
"MedcatDeIdentificationSupervisedTrainer", for _, row in eval_results.iterrows(): if row["support"] == 0: # the concept has not been used for annotation continue - aggregated_metrics.append({ - "per_concept_p": row["p"] if row["p"] is not None else 0.0, - "per_concept_r": row["r"] if row["r"] is not None else 0.0, - "per_concept_f1": row["f1"] if row["f1"] is not None else 0.0, - "per_concept_support": row["support"] if row["support"] is not None else 0.0, - "per_concept_p_merged": row["p_merged"] if row["p_merged"] is not None else 0.0, - "per_concept_r_merged": row["r_merged"] if row["r_merged"] is not None else 0.0, - }) + aggregated_metrics.append( + { + "per_concept_p": row["p"] if row["p"] is not None else 0.0, + "per_concept_r": row["r"] if row["r"] is not None else 0.0, + "per_concept_f1": row["f1"] if row["f1"] is not None else 0.0, + "per_concept_support": row["support"] + if row["support"] is not None + else 0.0, + "per_concept_p_merged": row["p_merged"] + if row["p_merged"] is not None + else 0.0, + "per_concept_r_merged": row["r_merged"] + if row["r_merged"] is not None + else 0.0, + } + ) cui2names[row["cui"]] = trainer._model_service._model.cdb.get_name(row["cui"]) trainer._tracker_client.send_batched_model_stats(aggregated_metrics, run_id) trainer._save_examples(examples, ["tp", "tn"]) trainer._tracker_client.log_classes_and_names(cui2names) - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service) logger.info("Model evaluation finished") @@ -260,10 +349,12 @@ def run(trainer: "MedcatDeIdentificationSupervisedTrainer", trainer._training_in_progress = False @staticmethod - def _save_metrics_plot(metrics: List[Dict], - concepts: List[str], - tracker_client: tracker_client.TrackerClient, - model_name: str) -> None: + def _save_metrics_plot( + metrics: List[Dict], + concepts: List[str], + tracker_client: tracker_client.TrackerClient, + model_name: str, + ) -> None: try: plot = radar_plot(data=metrics, model_names=concepts) with tempfile.TemporaryDirectory() as d: @@ -279,12 +370,17 @@ def _save_metrics_plot(metrics: List[Dict], def _customise_training_device(ner: TransformersNER, device_name: str) -> TransformersNER: if non_default_device_is_available(device_name): ner.model.to(torch.device(device_name)) - ner.ner_pipe = pipeline(model=ner.model, - framework="pt", - task="ner", - tokenizer=ner.tokenizer.hf_tokenizer, - device=get_hf_pipeline_device_id(device_name)) + ner.ner_pipe = pipeline( + model=ner.model, + framework="pt", + task="ner", + tokenizer=ner.tokenizer.hf_tokenizer, + device=get_hf_pipeline_device_id(device_name), + ) else: if device_name != "default": - logger.warning("DEVICE is set to '%s' but it is not available. Using 'default' instead.", device_name) + logger.warning( + "DEVICE is set to '%s' but it is not available. 
Using 'default' instead.", + device_name, + ) return ner diff --git a/app/trainers/medcat_trainer.py b/app/trainers/medcat_trainer.py index f054a57..ee13625 100644 --- a/app/trainers/medcat_trainer.py +++ b/app/trainers/medcat_trainer.py @@ -3,28 +3,32 @@ import os import re import tempfile -import ijson -import datasets from contextlib import redirect_stdout -from typing import TextIO, Dict, Optional, Set, List, Union, final +from typing import Dict, List, Optional, Set, TextIO, Union, final +import datasets +import ijson import pandas as pd from medcat import __version__ as medcat_version from medcat.cat import CAT + +from domain import DatasetSplit +from utils import get_func_params_as_dict, non_default_device_is_available + from management.log_captor import LogCaptor from management.model_manager import ModelManager from model_services.base import AbstractModelService -from trainers.base import SupervisedTrainer, UnsupervisedTrainer from processors.data_batcher import mini_batch -from processors.metrics_collector import sanity_check_model_with_trainer_export, get_stats_from_trainer_export -from utils import get_func_params_as_dict, non_default_device_is_available -from domain import DatasetSplit +from processors.metrics_collector import ( + get_stats_from_trainer_export, + sanity_check_model_with_trainer_export, +) +from trainers.base import SupervisedTrainer, UnsupervisedTrainer logger = logging.getLogger("cms") class _MedcatTrainerCommon(object): - @staticmethod def get_flattened_config(model: CAT, prefix: Optional[str] = None) -> Dict: params = {} @@ -43,16 +47,18 @@ def get_flattened_config(model: CAT, prefix: Optional[str] = None) -> Dict: params[f"{prefix}linking.{key}"] = str(val) params[f"{prefix}word_skipper"] = str(model.cdb.config.word_skipper) params[f"{prefix}punct_checker"] = str(model.cdb.config.punct_checker) - params.pop(f"{prefix}linking.filters", None) # deal with the length value in the older model + params.pop( + f"{prefix}linking.filters", None + ) # deal with the length value in the older model for key, val in params.items(): if val == "": params[key] = "" return params @staticmethod - def deploy_model(model_service: AbstractModelService, - model: CAT, - skip_save_model: bool) -> None: + def deploy_model( + model_service: AbstractModelService, model: CAT, skip_save_model: bool + ) -> None: if skip_save_model: model._versioning() del model_service.model @@ -71,23 +77,26 @@ def save_model_pack(model: CAT, model_dir: str, description: Optional[str] = Non class MedcatSupervisedTrainer(SupervisedTrainer, _MedcatTrainerCommon): - def __init__(self, model_service: AbstractModelService) -> None: SupervisedTrainer.__init__(self, model_service._config, model_service.model_name) self._model_service = model_service self._model_name = model_service.model_name self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_")) + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_") + ) self._model_manager = ModelManager(type(model_service), model_service._config) os.makedirs(self._retrained_models_dir, exist_ok=True) @staticmethod - def run(trainer: "MedcatSupervisedTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "MedcatSupervisedTrainer", + training_params: Dict, + data_file: TextIO, 
+ log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: training_params.update({"print_stats": log_frequency}) model_pack_path = None cdb_config_path = None @@ -99,25 +108,38 @@ def run(trainer: "MedcatSupervisedTrainer", if not eval_mode: try: logger.info("Loading a new model copy for training...") - copied_model_pack_path = trainer._make_model_file_copy(trainer._model_pack_path, run_id) + copied_model_pack_path = trainer._make_model_file_copy( + trainer._model_pack_path, run_id + ) if non_default_device_is_available(trainer._config.DEVICE): - model = trainer._model_service.load_model(copied_model_pack_path, - meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}) + model = trainer._model_service.load_model( + copied_model_pack_path, + meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}, + ) model.config.general["device"] = trainer._config.DEVICE else: model = trainer._model_service.load_model(copied_model_pack_path) trainer._tracker_client.log_model_config(trainer.get_flattened_config(model)) trainer._tracker_client.log_trainer_version(medcat_version) - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) - training_params.update({"extra_cui_filter": trainer._get_concept_filter(cui_counts, model)}) + training_params.update( + {"extra_cui_filter": trainer._get_concept_filter(cui_counts, model)} + ) logger.info("Performing supervised training...") train_supervised_params = get_func_params_as_dict(model.train_supervised_from_json) - train_supervised_params = {p_key: training_params[p_key] if p_key in training_params else p_val for p_key, p_val in train_supervised_params.items()} + train_supervised_params = { + p_key: training_params[p_key] if p_key in training_params else p_val + for p_key, p_val in train_supervised_params.items() + } model.config.version.description = description or model.config.version.description with redirect_stdout(LogCaptor(trainer._glean_and_log_metrics)): - fps, fns, tps, p, r, f1, cc, examples = model.train_supervised_from_json(data_file.name, **train_supervised_params) + fps, fns, tps, p, r, f1, cc, examples = model.train_supervised_from_json( + data_file.name, **train_supervised_params + ) trainer._save_examples(examples, ["tp", "tn"]) del examples gc.collect() @@ -133,32 +155,44 @@ def run(trainer: "MedcatSupervisedTrainer", fn_accumulated += fns.get(cui, 0) tp_accumulated += tps.get(cui, 0) cc_accumulated += cc.get(cui, 0) - aggregated_metrics.append({ - "per_concept_fp": fps.get(cui, 0), - "per_concept_fn": fns.get(cui, 0), - "per_concept_tp": tps.get(cui, 0), - "per_concept_counts": cc.get(cui, 0), - "per_concept_count_train": model.cdb.cui2count_train.get(cui, 0), - "per_concept_acc_fp": fp_accumulated, - "per_concept_acc_fn": fn_accumulated, - "per_concept_acc_tp": tp_accumulated, - "per_concept_acc_cc": cc_accumulated, - "per_concept_precision": p[cui], - "per_concept_recall": r[cui], - "per_concept_f1": f1_val, - }) + aggregated_metrics.append( + { + "per_concept_fp": fps.get(cui, 0), + "per_concept_fn": fns.get(cui, 0), + "per_concept_tp": tps.get(cui, 0), + "per_concept_counts": cc.get(cui, 0), + "per_concept_count_train": model.cdb.cui2count_train.get(cui, 0), + "per_concept_acc_fp": fp_accumulated, + "per_concept_acc_fn": fn_accumulated, + 
"per_concept_acc_tp": tp_accumulated, + "per_concept_acc_cc": cc_accumulated, + "per_concept_precision": p[cui], + "per_concept_recall": r[cui], + "per_concept_f1": f1_val, + } + ) cuis.append(cui) trainer._tracker_client.send_batched_model_stats(aggregated_metrics, run_id) - trainer._save_trained_concepts(cui_counts, cui_unique_counts, cui_ignorance_counts, model) + trainer._save_trained_concepts( + cui_counts, cui_unique_counts, cui_ignorance_counts, model + ) trainer._tracker_client.log_classes(cuis) - trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service.from_model(model)) + trainer._sanity_check_model_and_save_results( + data_file.name, trainer._model_service.from_model(model) + ) if not skip_save_model: - model_pack_path = trainer.save_model_pack(model, trainer._retrained_models_dir, description) + model_pack_path = trainer.save_model_pack( + model, trainer._retrained_models_dir, description + ) cdb_config_path = model_pack_path.replace(".zip", "_config.json") model.cdb.config.save(cdb_config_path) - model_uri = trainer._tracker_client.save_model(model_pack_path, trainer._model_name, trainer._model_manager) + model_uri = trainer._tracker_client.save_model( + model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info("Retrained model saved: %s", model_uri) - trainer._tracker_client.save_model_artifact(cdb_config_path, trainer._model_name) + trainer._tracker_client.save_model_artifact( + cdb_config_path, trainer._model_name + ) else: logger.info("Skipped saving on the retrained model") if redeploy: @@ -184,10 +218,13 @@ def run(trainer: "MedcatSupervisedTrainer", else: try: logger.info("Evaluating the running model...") - trainer._tracker_client.log_model_config(trainer.get_flattened_config(trainer._model_service._model)) + trainer._tracker_client.log_model_config( + trainer.get_flattened_config(trainer._model_service._model) + ) trainer._tracker_client.log_trainer_version(medcat_version) - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export( - data_file.name) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(data_file.name) + ) trainer._tracker_client.log_document_size(num_of_docs) trainer._sanity_check_model_and_save_results(data_file.name, trainer._model_service) trainer._tracker_client.end_with_success() @@ -206,8 +243,9 @@ def _get_concept_filter(training_concepts: Dict, model: CAT) -> Set[str]: return set(training_concepts.keys()).intersection(set(model.cdb.cui2names.keys())) def _glean_and_log_metrics(self, log: str) -> None: - metric_lines = re.findall(r"Epoch: (\d+), Prec: (\d+\.\d+), Rec: (\d+\.\d+), F1: (\d+\.\d+)", log, - re.IGNORECASE) + metric_lines = re.findall( + r"Epoch: (\d+), Prec: (\d+\.\d+), Rec: (\d+\.\d+), F1: (\d+\.\d+)", log, re.IGNORECASE + ) for step, metric in enumerate(metric_lines): metrics = { "precision": float(metric[1]), @@ -216,22 +254,31 @@ def _glean_and_log_metrics(self, log: str) -> None: } self._tracker_client.send_model_stats(metrics, int(metric[0])) - def _save_trained_concepts(self, - training_concepts: Dict, - training_unique_concepts: Dict, - training_ignorance_counts: Dict, - model: CAT) -> None: + def _save_trained_concepts( + self, + training_concepts: Dict, + training_unique_concepts: Dict, + training_ignorance_counts: Dict, + model: CAT, + ) -> None: if len(training_concepts.keys()) != 0: unknown_concepts = set(training_concepts.keys()) - set(model.cdb.cui2names.keys()) - unknown_concept_pct = 
round(len(unknown_concepts) / len(training_concepts.keys()) * 100, 2) - self._tracker_client.send_model_stats({ - "unknown_concept_count": len(unknown_concepts), - "unknown_concept_pct": unknown_concept_pct, - }, 0) + unknown_concept_pct = round( + len(unknown_concepts) / len(training_concepts.keys()) * 100, 2 + ) + self._tracker_client.send_model_stats( + { + "unknown_concept_count": len(unknown_concepts), + "unknown_concept_pct": unknown_concept_pct, + }, + 0, + ) if unknown_concepts: - self._tracker_client.save_dataframe_as_csv("unknown_concepts.csv", - pd.DataFrame({"concept": list(unknown_concepts)}), - self._model_name) + self._tracker_client.save_dataframe_as_csv( + "unknown_concepts.csv", + pd.DataFrame({"concept": list(unknown_concepts)}), + self._model_name, + ) train_count = [] concept_names = [] annotation_count = [] @@ -239,29 +286,38 @@ def _save_trained_concepts(self, annotation_ignorance_count = [] concepts = list(training_concepts.keys()) for c in concepts: - train_count.append(model.cdb.cui2count_train[c] if c in model.cdb.cui2count_train else 0) + train_count.append( + model.cdb.cui2count_train[c] if c in model.cdb.cui2count_train else 0 + ) concept_names.append(model.cdb.get_name(c)) annotation_count.append(training_concepts[c]) annotation_unique_count.append(training_unique_concepts[c]) annotation_ignorance_count.append(training_ignorance_counts[c]) - self._tracker_client.save_dataframe_as_csv("trained_concepts.csv", - pd.DataFrame({ - "concept": concepts, - "name": concept_names, - "train_count": train_count, - "anno_count": annotation_count, - "anno_unique_count": annotation_unique_count, - "anno_ignorance_count": annotation_ignorance_count, - }), - self._model_name) + self._tracker_client.save_dataframe_as_csv( + "trained_concepts.csv", + pd.DataFrame( + { + "concept": concepts, + "name": concept_names, + "train_count": train_count, + "anno_count": annotation_count, + "anno_unique_count": annotation_unique_count, + "anno_ignorance_count": annotation_ignorance_count, + } + ), + self._model_name, + ) - def _sanity_check_model_and_save_results(self, data_file_path: str, medcat_model: AbstractModelService) -> None: - self._tracker_client.save_dataframe_as_csv("sanity_check_result.csv", - sanity_check_model_with_trainer_export(data_file_path, - medcat_model, - return_df=True, - include_anchors=True), - self._model_name) + def _sanity_check_model_and_save_results( + self, data_file_path: str, medcat_model: AbstractModelService + ) -> None: + self._tracker_client.save_dataframe_as_csv( + "sanity_check_result.csv", + sanity_check_model_with_trainer_export( + data_file_path, medcat_model, return_df=True, include_anchors=True + ), + self._model_name, + ) def _save_examples(self, examples: Dict, excluded_example_keys: List) -> None: for e_key, e_items in examples.items(): @@ -274,30 +330,35 @@ def _save_examples(self, examples: Dict, excluded_example_keys: List) -> None: # Extract column names from the first row columns = ["concept"] + list(items[0].keys()) for item in items: - rows.append([concept] + list(item.values())[:len(columns)-1]) + rows.append([concept] + list(item.values())[: len(columns) - 1]) if rows: - self._tracker_client.save_dataframe_as_csv(f"{e_key}_examples.csv", pd.DataFrame(rows, columns=columns), self._model_name) + self._tracker_client.save_dataframe_as_csv( + f"{e_key}_examples.csv", pd.DataFrame(rows, columns=columns), self._model_name + ) @final class MedcatUnsupervisedTrainer(UnsupervisedTrainer, _MedcatTrainerCommon): - def __init__(self, 
model_service: AbstractModelService) -> None: UnsupervisedTrainer.__init__(self, model_service._config, model_service.model_name) self._model_service = model_service self._model_name = model_service.model_name self._model_pack_path = model_service._model_pack_path - self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_")) + self._retrained_models_dir = os.path.join( + model_service._model_parent_dir, "retrained", self._model_name.replace(" ", "_") + ) self._model_manager = ModelManager(type(model_service), model_service._config) os.makedirs(self._retrained_models_dir, exist_ok=True) @staticmethod - def run(trainer: "MedcatUnsupervisedTrainer", - training_params: Dict, - data_file: Union[TextIO, tempfile.TemporaryDirectory], - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "MedcatUnsupervisedTrainer", + training_params: Dict, + data_file: Union[TextIO, tempfile.TemporaryDirectory], + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: model_pack_path = None cdb_config_path = None copied_model_pack_path = None @@ -314,8 +375,10 @@ def run(trainer: "MedcatUnsupervisedTrainer", logger.info("Loading a new model copy for training...") copied_model_pack_path = trainer._make_model_file_copy(trainer._model_pack_path, run_id) if non_default_device_is_available(trainer._config.DEVICE): - model = trainer._model_service.load_model(copied_model_pack_path, - meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}) + model = trainer._model_service.load_model( + copied_model_pack_path, + meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}, + ) model.config.general["device"] = trainer._config.DEVICE else: model = trainer._model_service.load_model(copied_model_pack_path) @@ -327,7 +390,10 @@ def run(trainer: "MedcatUnsupervisedTrainer", before_cui2count_train = dict(model.cdb.cui2count_train) num_of_docs = 0 train_unsupervised_params = get_func_params_as_dict(model.train) - train_unsupervised_params = {p_key: training_params[p_key] if p_key in training_params else p_val for p_key, p_val in train_unsupervised_params.items()} + train_unsupervised_params = { + p_key: training_params[p_key] if p_key in training_params else p_val + for p_key, p_val in train_unsupervised_params.items() + } for batch in mini_batch(texts, batch_size=log_frequency): step += 1 model.train(batch, **train_unsupervised_params) @@ -335,24 +401,34 @@ def run(trainer: "MedcatUnsupervisedTrainer", trainer._tracker_client.send_model_stats(model.cdb.make_stats(), step) trainer._tracker_client.log_document_size(num_of_docs) - after_cui2count_train = {c: ct for c, ct in - sorted(model.cdb.cui2count_train.items(), key=lambda item: item[1], reverse=True)} + after_cui2count_train = { + c: ct + for c, ct in sorted( + model.cdb.cui2count_train.items(), key=lambda item: item[1], reverse=True + ) + } aggregated_metrics = [] cui_step = 0 for cui, train_count in after_cui2count_train.items(): if cui_step >= 10000: # large numbers will cause the mlflow page to hung on loading break cui_step += 1 - aggregated_metrics.append({ - "per_concept_train_count_before": before_cui2count_train.get(cui, 0), - "per_concept_train_count_after": train_count - }) + aggregated_metrics.append( + { + "per_concept_train_count_before": before_cui2count_train.get(cui, 0), + "per_concept_train_count_after": train_count, + } + ) trainer._tracker_client.send_batched_model_stats(aggregated_metrics, 
run_id) if not skip_save_model: - model_pack_path = trainer.save_model_pack(model, trainer._retrained_models_dir, description) + model_pack_path = trainer.save_model_pack( + model, trainer._retrained_models_dir, description + ) cdb_config_path = model_pack_path.replace(".zip", "_config.json") model.cdb.config.save(cdb_config_path) - model_uri = trainer._tracker_client.save_model(model_pack_path, trainer._model_name, trainer._model_manager) + model_uri = trainer._tracker_client.save_model( + model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info(f"Retrained model saved: {model_uri}") trainer._tracker_client.save_model_artifact(cdb_config_path, trainer._model_name) else: diff --git a/app/trainers/metacat_trainer.py b/app/trainers/metacat_trainer.py index 917319f..1446c61 100644 --- a/app/trainers/metacat_trainer.py +++ b/app/trainers/metacat_trainer.py @@ -1,21 +1,22 @@ -import os +import gc import logging +import os import shutil -import gc -from typing import Dict, TextIO, Optional, List +from typing import Dict, List, Optional, TextIO import pandas as pd from medcat import __version__ as medcat_version from medcat.meta_cat import MetaCAT -from trainers.medcat_trainer import MedcatSupervisedTrainer + from exception import TrainingFailedException from utils import non_default_device_is_available +from trainers.medcat_trainer import MedcatSupervisedTrainer + logger = logging.getLogger("cms") class MetacatTrainer(MedcatSupervisedTrainer): - @staticmethod def get_flattened_config(model: MetaCAT, prefix: Optional[str] = None) -> Dict: params = {} @@ -32,12 +33,14 @@ def get_flattened_config(model: MetaCAT, prefix: Optional[str] = None) -> Dict: return params @staticmethod - def run(trainer: "MetacatTrainer", - training_params: Dict, - data_file: TextIO, - log_frequency: int, - run_id: str, - description: Optional[str] = None) -> None: + def run( + trainer: "MetacatTrainer", + training_params: Dict, + data_file: TextIO, + log_frequency: int, + run_id: str, + description: Optional[str] = None, + ) -> None: model_pack_path = None cdb_config_path = None copied_model_pack_path = None @@ -48,10 +51,14 @@ def run(trainer: "MetacatTrainer", if not eval_mode: try: logger.info("Loading a new model copy for training...") - copied_model_pack_path = trainer._make_model_file_copy(trainer._model_pack_path, run_id) + copied_model_pack_path = trainer._make_model_file_copy( + trainer._model_pack_path, run_id + ) if non_default_device_is_available(trainer._config.DEVICE): - model = trainer._model_service.load_model(copied_model_pack_path, - meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}) + model = trainer._model_service.load_model( + copied_model_pack_path, + meta_cat_config_dict={"general": {"device": trainer._config.DEVICE}}, + ) model.config.general["device"] = trainer._config.DEVICE else: model = trainer._model_service.load_model(copied_model_pack_path) @@ -64,42 +71,82 @@ def run(trainer: "MetacatTrainer", if training_params.get("test_size") is not None: meta_cat.config.train.test_size = training_params["test_size"] meta_cat.config.train.nepochs = training_params["nepochs"] - trainer._tracker_client.log_model_config(trainer.get_flattened_config(meta_cat, category_name)) + trainer._tracker_client.log_model_config( + trainer.get_flattened_config(meta_cat, category_name) + ) trainer._tracker_client.log_trainer_version(medcat_version) logger.info('Performing supervised training on category "%s"...', category_name) try: - winner_report = meta_cat.train(data_file.name, 
os.path.join(copied_model_pack_path.replace(".zip", ""), f"meta_{category_name}")) + winner_report = meta_cat.train( + data_file.name, + os.path.join( + copied_model_pack_path.replace(".zip", ""), f"meta_{category_name}" + ), + ) is_retrained = True report_stats = { - f"{category_name}_macro_avg_precision": winner_report["report"]["macro avg"]["precision"], - f"{category_name}_macro_avg_recall": winner_report["report"]["macro avg"]["recall"], - f"{category_name}_macro_avg_f1": winner_report["report"]["macro avg"]["f1-score"], - f"{category_name}_macro_avg_support": winner_report["report"]["macro avg"]["support"], - f"{category_name}_weighted_avg_precision": winner_report["report"]["weighted avg"]["precision"], - f"{category_name}_weighted_avg_recall": winner_report["report"]["weighted avg"]["recall"], - f"{category_name}_weighted_avg_f1": winner_report["report"]["weighted avg"]["f1-score"], - f"{category_name}_weighted_avg_support": winner_report["report"]["weighted avg"]["support"], + f"{category_name}_macro_avg_precision": winner_report["report"][ + "macro avg" + ]["precision"], + f"{category_name}_macro_avg_recall": winner_report["report"][ + "macro avg" + ]["recall"], + f"{category_name}_macro_avg_f1": winner_report["report"]["macro avg"][ + "f1-score" + ], + f"{category_name}_macro_avg_support": winner_report["report"][ + "macro avg" + ]["support"], + f"{category_name}_weighted_avg_precision": winner_report["report"][ + "weighted avg" + ]["precision"], + f"{category_name}_weighted_avg_recall": winner_report["report"][ + "weighted avg" + ]["recall"], + f"{category_name}_weighted_avg_f1": winner_report["report"][ + "weighted avg" + ]["f1-score"], + f"{category_name}_weighted_avg_support": winner_report["report"][ + "weighted avg" + ]["support"], } - trainer._tracker_client.send_model_stats(report_stats, winner_report["epoch"]) + trainer._tracker_client.send_model_stats( + report_stats, winner_report["epoch"] + ) except Exception as e: - logger.exception("Failed on training meta model: %s. This could be benign if training data has no annotations belonging to this category.", category_name) + logger.exception( + "Failed on training meta model: %s. This could be benign if training" + " data has no annotations belonging to this category.", + category_name, + ) trainer._tracker_client.log_exceptions(e) if not is_retrained: - exception = TrainingFailedException("No metacat model has been retrained. Double-check the presence of metacat models and your annotations.") - logger.error("Error occurred while retraining the model: %s", exception, exc_info=True) + exception = TrainingFailedException( + "No metacat model has been retrained. Double-check the presence of metacat" + " models and your annotations." 
+ ) + logger.error( + "Error occurred while retraining the model: %s", exception, exc_info=True + ) trainer._tracker_client.log_exceptions(exception) trainer._tracker_client.end_with_failure() return if not skip_save_model: - model_pack_path = trainer.save_model_pack(model, trainer._retrained_models_dir, description) + model_pack_path = trainer.save_model_pack( + model, trainer._retrained_models_dir, description + ) cdb_config_path = model_pack_path.replace(".zip", "_config.json") model.cdb.config.save(cdb_config_path) - model_uri = trainer._tracker_client.save_model(model_pack_path, trainer._model_name, trainer._model_manager) + model_uri = trainer._tracker_client.save_model( + model_pack_path, trainer._model_name, trainer._model_manager + ) logger.info("Retrained model saved: %s", model_uri) - trainer._tracker_client.save_model_artifact(cdb_config_path, trainer._model_name) + trainer._tracker_client.save_model_artifact( + cdb_config_path, trainer._model_name + ) else: logger.info("Skipped saving on the retrained model") if redeploy: @@ -124,7 +171,9 @@ def run(trainer: "MetacatTrainer", os.remove(cdb_config_path) # Remove intermediate results folder on successful training - results_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "results")) + results_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "results") + ) if results_path and os.path.isdir(results_path): shutil.rmtree(results_path) else: @@ -133,20 +182,35 @@ def run(trainer: "MetacatTrainer", metrics: List[Dict] = [] for meta_cat in trainer._model_service._model._meta_cats: category_name = meta_cat.config.general["category_name"] - trainer._tracker_client.log_model_config(trainer.get_flattened_config(meta_cat, category_name)) + trainer._tracker_client.log_model_config( + trainer.get_flattened_config(meta_cat, category_name) + ) trainer._tracker_client.log_trainer_version(medcat_version) result = meta_cat.eval(data_file.name) - metrics.append({"precision": result.get("precision"), "recall": result.get("recall"), "f1": result.get("f1")}) + metrics.append( + { + "precision": result.get("precision"), + "recall": result.get("recall"), + "f1": result.get("f1"), + } + ) if metrics: - trainer._tracker_client.save_dataframe_as_csv("sanity_check_result.csv", - pd.DataFrame(metrics, columns=["category", "precision", "recall", "f1"]), - trainer._model_service._model_name) + trainer._tracker_client.save_dataframe_as_csv( + "sanity_check_result.csv", + pd.DataFrame(metrics, columns=["category", "precision", "recall", "f1"]), + trainer._model_service._model_name, + ) trainer._tracker_client.end_with_success() logger.info("Model evaluation finished") else: - exception = TrainingFailedException("No metacat model has been evaluated. Double-check the presence of metacat models and your annotations.") - logger.error("Error occurred while evaluating the model: %s", exception, exc_info=True) + exception = TrainingFailedException( + "No metacat model has been evaluated. Double-check the presence of metacat" + " models and your annotations." 
+ ) + logger.error( + "Error occurred while evaluating the model: %s", exception, exc_info=True + ) trainer._tracker_client.log_exceptions(exception) trainer._tracker_client.end_with_failure() return diff --git a/app/utils.py b/app/utils.py index 439166b..f195c0c 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,23 +1,25 @@ +import copy +import functools +import inspect import json -import socket +import os import random +import socket import struct -import inspect -import os -import copy -import functools import warnings -import torch +from functools import lru_cache +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from urllib.parse import ParseResult + import numpy as np import pandas as pd +import torch +from safetensors.torch import load_file from spacy.lang.en import English from spacy.util import filter_spans -from safetensors.torch import load_file -from urllib.parse import ParseResult -from functools import lru_cache -from typing import List, Optional, Dict, Callable, Any, Union, Tuple, Type -from domain import Annotation, Entity, CodeType, ModelType, Device + from config import Settings +from domain import Annotation, CodeType, Device, Entity, ModelType @lru_cache @@ -43,13 +45,17 @@ def annotations_to_entities(annotations: List[Annotation], model_name: str) -> L entities = [] code_base_uri = get_code_base_uri(model_name) for _, annotation in enumerate(annotations): - entities.append({ - "start": annotation["start"], - "end": annotation["end"], - "label": f"{annotation['label_name']}", - "kb_id": annotation["label_id"], - "kb_url": f"{code_base_uri}/{annotation['label_id']}" if code_base_uri is not None else "#" - }) + entities.append( + { + "start": annotation["start"], + "end": annotation["end"], + "label": f"{annotation['label_name']}", + "kb_id": annotation["label_id"], + "kb_url": f"{code_base_uri}/{annotation['label_id']}" + if code_base_uri is not None + else "#", + } + ) return entities @@ -71,19 +77,30 @@ def send_gelf_message(message: str, gelf_input_uri: ParseResult) -> None: def get_func_params_as_dict(func: Callable) -> Dict: signature = inspect.signature(func) - params = {name: param.default for name, param in signature.parameters.items() if param.default is not inspect.Parameter.empty} + params = { + name: param.default + for name, param in signature.parameters.items() + if param.default is not inspect.Parameter.empty + } return params def json_normalize_trainer_export(trainer_export: Dict) -> pd.DataFrame: - return pd.json_normalize(trainer_export, - record_path=["projects", "documents", "annotations"], - meta=[ - ["projects", "name"], ["projects", "id"], ["projects", "cuis"], ["projects", "tuis"], - ["projects", "documents", "id"], ["projects", "documents", "name"], - ["projects", "documents", "text"], ["projects", "documents", "last_modified"] - ], - sep=".") + return pd.json_normalize( + trainer_export, + record_path=["projects", "documents", "annotations"], + meta=[ + ["projects", "name"], + ["projects", "id"], + ["projects", "cuis"], + ["projects", "tuis"], + ["projects", "documents", "id"], + ["projects", "documents", "name"], + ["projects", "documents", "text"], + ["projects", "documents", "last_modified"], + ], + sep=".", + ) def json_normalize_medcat_entities(medcat_entities: Dict) -> pd.DataFrame: @@ -102,7 +119,7 @@ def json_denormalize(df: pd.DataFrame, sep: str = ".") -> List[Dict]: keys = col.split(sep) current = result_row for i, k in enumerate(keys): - if i == len(keys)-1: + if i == len(keys) - 1: current[k] = cell else: 
if k not in current: @@ -112,50 +129,82 @@ def json_denormalize(df: pd.DataFrame, sep: str = ".") -> List[Dict]: return result -def filter_by_concept_ids(trainer_export: Dict[str, Any], - model_type: Optional[ModelType] = None, - extra_excluded: Optional[List[str]] = None) -> Dict[str, Any]: +def filter_by_concept_ids( + trainer_export: Dict[str, Any], + model_type: Optional[ModelType] = None, + extra_excluded: Optional[List[str]] = None, +) -> Dict[str, Any]: concept_ids = get_settings().TRAINING_CONCEPT_ID_WHITELIST.split(",") filtered = copy.deepcopy(trainer_export) for project in filtered.get("projects", []): for document in project.get("documents", []): if concept_ids == [""]: - document["annotations"] = [anno for anno in document.get("annotations", []) if anno.get("correct", True) and not anno.get("deleted", False) and not anno.get("killed", False)] + document["annotations"] = [ + anno + for anno in document.get("annotations", []) + if anno.get("correct", True) + and not anno.get("deleted", False) + and not anno.get("killed", False) + ] else: - document["annotations"] = [anno for anno in document.get("annotations", []) if anno.get("cui") in concept_ids and anno.get("correct", True) and not anno.get("deleted", False) and not anno.get("killed", False)] + document["annotations"] = [ + anno + for anno in document.get("annotations", []) + if anno.get("cui") in concept_ids + and anno.get("correct", True) + and not anno.get("deleted", False) + and not anno.get("killed", False) + ] if extra_excluded is not None and len(extra_excluded) > 0: - document["annotations"] = [anno for anno in document.get("annotations", []) if anno.get("cui") not in extra_excluded] + document["annotations"] = [ + anno + for anno in document.get("annotations", []) + if anno.get("cui") not in extra_excluded + ] if model_type in [ModelType.TRANSFORMERS_DEID, ModelType.MEDCAT_DEID, ModelType.ANONCAT]: # special preprocessing for the DeID annotations and consider removing this. 
for project in filtered["projects"]: for document in project["documents"]: for annotation in document["annotations"]: - if annotation["cui"] == "N1100" or annotation["cui"] == "N1200": # for metric calculation + if ( + annotation["cui"] == "N1100" or annotation["cui"] == "N1200" + ): # for metric calculation annotation["cui"] = "N1000" - if annotation["cui"] == "W5000" and (model_type in [ModelType.MEDCAT_DEID, ModelType.ANONCAT]): # for compatibility + if annotation["cui"] == "W5000" and ( + model_type in [ModelType.MEDCAT_DEID, ModelType.ANONCAT] + ): # for compatibility annotation["cui"] = "C2500" return filtered -def replace_spans_of_concept(trainer_export: Dict[str, Any], concept_id: str, transform: Callable) -> Dict[str, Any]: +def replace_spans_of_concept( + trainer_export: Dict[str, Any], concept_id: str, transform: Callable +) -> Dict[str, Any]: doc_with_initials_ids = set() copied = copy.deepcopy(trainer_export) for project in copied.get("projects", []): for document in project.get("documents", []): text = document.get("text", "") offset = 0 - document["annotations"] = sorted(document.get("annotations", []), key=lambda annotation: annotation["start"]) + document["annotations"] = sorted( + document.get("annotations", []), key=lambda annotation: annotation["start"] + ) for annotation in document.get("annotations", []): annotation["start"] += offset annotation["end"] += offset - if annotation["cui"] == concept_id and annotation.get("correct", True) and not annotation.get("deleted", False) and not annotation.get("killed", False): + if ( + annotation["cui"] == concept_id + and annotation.get("correct", True) + and not annotation.get("deleted", False) + and not annotation.get("killed", False) + ): original = annotation["value"] modified = transform(original) extended = len(modified) - len(original) - text = text[:annotation["start"]] + modified + text[annotation["end"]:] + text = text[: annotation["start"]] + modified + text[annotation["end"] :] annotation["value"] = modified annotation["end"] += extended offset += extended @@ -164,68 +213,102 @@ def replace_spans_of_concept(trainer_export: Dict[str, Any], concept_id: str, tr return copied -def breakdown_annotations(trainer_export: Dict[str, Any], - target_concept_ids: List[str], - primary_delimiter: str, - secondary_delimiter: Optional[str] = None, - *, - include_delimiter: bool = True) -> Dict[str, Any]: +def breakdown_annotations( + trainer_export: Dict[str, Any], + target_concept_ids: List[str], + primary_delimiter: str, + secondary_delimiter: Optional[str] = None, + *, + include_delimiter: bool = True, +) -> Dict[str, Any]: assert isinstance(target_concept_ids, list), "The target_concept_ids is not a list" copied = copy.deepcopy(trainer_export) for project in copied["projects"]: for document in project["documents"]: new_annotations = [] for annotation in document["annotations"]: - if annotation["cui"] in target_concept_ids and primary_delimiter in annotation["value"]: + if ( + annotation["cui"] in target_concept_ids + and primary_delimiter in annotation["value"] + ): start_offset = 0 for sub_text in annotation["value"].split(primary_delimiter): if secondary_delimiter is not None and secondary_delimiter in sub_text: for sub_sub_text in sub_text.split(secondary_delimiter): - if sub_sub_text == "" or all(char.isspace() for char in sub_sub_text): + if sub_sub_text == "" or all( + char.isspace() for char in sub_sub_text + ): start_offset += len(sub_sub_text) + len(secondary_delimiter) continue sub_sub_annotation = 
copy.deepcopy(annotation) sub_sub_annotation["start"] = annotation["start"] + start_offset - sub_sub_annotation["end"] = sub_sub_annotation["start"] + len(sub_sub_text) + (len(secondary_delimiter) if include_delimiter else 0) - sub_sub_annotation["value"] = sub_sub_text + (secondary_delimiter if include_delimiter else "") + sub_sub_annotation["end"] = ( + sub_sub_annotation["start"] + + len(sub_sub_text) + + (len(secondary_delimiter) if include_delimiter else 0) + ) + sub_sub_annotation["value"] = sub_sub_text + ( + secondary_delimiter if include_delimiter else "" + ) start_offset += len(sub_sub_text) + len(secondary_delimiter) new_annotations.append(sub_sub_annotation) if include_delimiter: - new_annotations[-1]["value"] = new_annotations[-1]["value"][:-len(secondary_delimiter)] + primary_delimiter + new_annotations[-1]["value"] = ( + new_annotations[-1]["value"][: -len(secondary_delimiter)] + + primary_delimiter + ) else: if sub_text == "" or all(char.isspace() for char in sub_text): start_offset += len(sub_text) + len(primary_delimiter) continue sub_annotation = copy.deepcopy(annotation) sub_annotation["start"] = annotation["start"] + start_offset - sub_annotation["end"] = sub_annotation["start"] + len(sub_text) + (len(primary_delimiter) if include_delimiter else 0) - sub_annotation["value"] = sub_text + (primary_delimiter if include_delimiter else "") + sub_annotation["end"] = ( + sub_annotation["start"] + + len(sub_text) + + (len(primary_delimiter) if include_delimiter else 0) + ) + sub_annotation["value"] = sub_text + ( + primary_delimiter if include_delimiter else "" + ) start_offset += len(sub_text) + len(primary_delimiter) new_annotations.append(sub_annotation) if include_delimiter: new_annotations[-1]["end"] -= len(primary_delimiter) - new_annotations[-1]["value"] = new_annotations[-1]["value"][:-len(primary_delimiter)] + new_annotations[-1]["value"] = new_annotations[-1]["value"][ + : -len(primary_delimiter) + ] else: new_annotations.append(annotation) document["annotations"] = new_annotations return copied -def augment_annotations(trainer_export: Dict, cui_regexes_lists: Dict[str, List[List]], *, case_sensitive: bool = True) -> Dict: +def augment_annotations( + trainer_export: Dict, cui_regexes_lists: Dict[str, List[List]], *, case_sensitive: bool = True +) -> Dict: nlp = English() patterns = [] for cui, regexes in cui_regexes_lists.items(): - pts = [{ - "label": cui, - "pattern": [{"TEXT": {"REGEX": part if case_sensitive else r"(?i)" + part}} for part in regex] - } for regex in regexes] + pts = [ + { + "label": cui, + "pattern": [ + {"TEXT": {"REGEX": part if case_sensitive else r"(?i)" + part}} + for part in regex + ], + } + for regex in regexes + ] patterns += pts ruler = nlp.add_pipe("entity_ruler") - ruler.add_patterns(patterns) # type: ignore + ruler.add_patterns(patterns) # type: ignore copied = copy.deepcopy(trainer_export) for project in copied["projects"]: for document in project["documents"]: - document["annotations"] = sorted(document["annotations"], key=lambda anno: anno["start"]) + document["annotations"] = sorted( + document["annotations"], key=lambda anno: anno["start"] + ) gaps = [] gap_start = 0 for annotation in document["annotations"]: @@ -233,7 +316,7 @@ def augment_annotations(trainer_export: Dict, cui_regexes_lists: Dict[str, List[ gaps.append((gap_start, annotation["start"])) gap_start = annotation["end"] if gap_start < len(document["text"]): - gaps.append((gap_start, len(document["text"])+1)) + gaps.append((gap_start, len(document["text"]) + 1)) 
new_annotations = [] doc = nlp(document["text"]) spans = filter_spans(doc.ents) @@ -253,27 +336,36 @@ def augment_annotations(trainer_export: Dict, cui_regexes_lists: Dict[str, List[ new_annotations.append(annotation) break document["annotations"] += new_annotations - document["annotations"] = sorted(document["annotations"], key=lambda anno: anno["start"]) + document["annotations"] = sorted( + document["annotations"], key=lambda anno: anno["start"] + ) return copied -def safetensors_to_pytorch(safetensors_file_path: Union[str, os.PathLike], - pytorch_file_path: Union[str, os.PathLike]) -> None: +def safetensors_to_pytorch( + safetensors_file_path: Union[str, os.PathLike], pytorch_file_path: Union[str, os.PathLike] +) -> None: state_dict = load_file(safetensors_file_path) torch.save(state_dict, pytorch_file_path) def func_deprecated(message: Optional[str] = None) -> Callable: def decorator(func: Callable) -> Callable: - @functools.wraps(func) def wrapped(*args: Tuple, **kwargs: Dict[str, Any]) -> Callable: warnings.simplefilter("always", DeprecationWarning) - warnings.warn("Function {} has been deprecated.{}".format(func.__name__, " " + message if message else ""), stacklevel=2) + warnings.warn( + "Function {} has been deprecated.{}".format( + func.__name__, " " + message if message else "" + ), + stacklevel=2, + ) warnings.simplefilter("default", DeprecationWarning) return func(*args, **kwargs) + return wrapped + return decorator @@ -284,11 +376,17 @@ def decorator(cls: Type) -> Callable: @functools.wraps(decorated_init) def wrapped(self: "Type", *args: Tuple, **kwargs: Dict[str, Any]) -> Any: warnings.simplefilter("always", DeprecationWarning) - warnings.warn("Class {} has been deprecated.{}".format(cls.__name__, " " + message if message else "")) + warnings.warn( + "Class {} has been deprecated.{}".format( + cls.__name__, " " + message if message else "" + ) + ) warnings.simplefilter("default", DeprecationWarning) decorated_init(self, *args, **kwargs) + cls.__init__ = wrapped return cls + return decorator @@ -301,11 +399,13 @@ def reset_random_seed() -> None: def non_default_device_is_available(device: str) -> bool: - return any([ - device.startswith(Device.GPU.value) and torch.cuda.is_available(), - device.startswith(Device.MPS.value) and torch.backends.mps.is_available(), - device.startswith(Device.CPU.value) - ]) + return any( + [ + device.startswith(Device.GPU.value) and torch.cuda.is_available(), + device.startswith(Device.MPS.value) and torch.backends.mps.is_available(), + device.startswith(Device.CPU.value), + ] + ) def get_hf_pipeline_device_id(device: str) -> int: @@ -374,6 +474,10 @@ def get_hf_pipeline_device_id(device: str) -> int: "92873870": "special concept", "78096516": "environment / location", "72706784": "context-dependent category", - "25624495": '© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO). All rights reserved. SNOMED CT®, was originally created by The College of American Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.', - "55540447": "linkage concept" + "25624495": ( + "© 2002-2020 International Health Terminology Standards Development Organisation (IHTSDO)." + " All rights reserved. SNOMED CT®, was originally created by The College of American" + ' Pathologists. "SNOMED" and "SNOMED CT" are registered trademarks of the IHTSDO.' 
+ ), + "55540447": "linkage concept", } diff --git a/docker-compose-auth.yml b/docker-compose-auth.yml index cb27dee..b8771be 100644 --- a/docker-compose-auth.yml +++ b/docker-compose-auth.yml @@ -28,4 +28,4 @@ volumes: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/docker-compose-celery.yml b/docker-compose-celery.yml index bce8ef8..53a7c59 100644 --- a/docker-compose-celery.yml +++ b/docker-compose-celery.yml @@ -54,4 +54,4 @@ services: networks: cms: name: cogstack-model-serve_cms - driver: bridge \ No newline at end of file + driver: bridge diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 699e5ac..4e82d0b 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -146,4 +146,4 @@ volumes: networks: cms: - driver: bridge \ No newline at end of file + driver: bridge diff --git a/docker-compose-log.yml b/docker-compose-log.yml index a12c198..4faa88f 100644 --- a/docker-compose-log.yml +++ b/docker-compose-log.yml @@ -86,4 +86,4 @@ volumes: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/docker-compose-mlflow.yml b/docker-compose-mlflow.yml index af1121c..c389cd5 100644 --- a/docker-compose-mlflow.yml +++ b/docker-compose-mlflow.yml @@ -183,4 +183,4 @@ volumes: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/docker-compose-mon.yml b/docker-compose-mon.yml index 86dc8fd..a09880f 100644 --- a/docker-compose-mon.yml +++ b/docker-compose-mon.yml @@ -102,4 +102,4 @@ volumes: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/docker-compose-proxy.yml b/docker-compose-proxy.yml index 4a9bf56..bceaedd 100644 --- a/docker-compose-proxy.yml +++ b/docker-compose-proxy.yml @@ -48,4 +48,4 @@ services: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/docker-compose.yml b/docker-compose.yml index 4d17877..a148fed 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -310,4 +310,4 @@ volumes: networks: cms: name: cogstack-model-serve_cms - driver: bridge \ No newline at end of file + driver: bridge diff --git a/docker/README.md b/docker/README.md index 157763b..8403f51 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,3 +1,3 @@ # CogStack ModelServe Docker Images -A group of Docker images serves, trains or tracks models which perform NLP tasks over clinical notes. \ No newline at end of file +A group of Docker images serves, trains or tracks models which perform NLP tasks over clinical notes. diff --git a/docker/celery/.dockerignore b/docker/celery/.dockerignore index 90e2e80..8175385 100644 --- a/docker/celery/.dockerignore +++ b/docker/celery/.dockerignore @@ -1,2 +1,2 @@ app/model/* -app/mlruns/* \ No newline at end of file +app/mlruns/* diff --git a/docker/celery/.env b/docker/celery/.env index d6bd90d..7946561 100644 --- a/docker/celery/.env +++ b/docker/celery/.env @@ -1 +1 @@ -ENABLE_TRAINING_APIS=true \ No newline at end of file +ENABLE_TRAINING_APIS=true diff --git a/docker/celery/Dockerfile-Dashboard b/docker/celery/Dockerfile-Dashboard index ff52301..22a7788 100644 --- a/docker/celery/Dockerfile-Dashboard +++ b/docker/celery/Dockerfile-Dashboard @@ -19,4 +19,4 @@ COPY --chown=$CMS_UID:$CMS_GID docker/celery/requirements_dashboard.txt . 
RUN pip install --no-cache-dir -U pip &&\ pip install --no-cache-dir -r requirements_dashboard.txt -CMD ["sh", "-c", "/home/cms/.local/bin/celery --broker=redis://redis:6379/0 flower --port=5555"] \ No newline at end of file +CMD ["sh", "-c", "/home/cms/.local/bin/celery --broker=redis://redis:6379/0 flower --port=5555"] diff --git a/docker/celery/Dockerfile-Worker b/docker/celery/Dockerfile-Worker index ceaa24b..f1e4f64 100644 --- a/docker/celery/Dockerfile-Worker +++ b/docker/celery/Dockerfile-Worker @@ -36,4 +36,4 @@ RUN pip install --no-cache-dir -U pip &&\ pip install --no-cache-dir -r requirements.txt && \ python -m spacy download en_core_web_md -CMD ["sh", "-c", "/home/cms/.local/bin/celery -A tasks.tasks.celery worker --loglevel=debug --concurrency=1 --pool threads"] \ No newline at end of file +CMD ["sh", "-c", "/home/cms/.local/bin/celery -A tasks.tasks.celery worker --loglevel=debug --concurrency=1 --pool threads"] diff --git a/docker/celery/requirements-dashboard.txt b/docker/celery/requirements-dashboard.txt index 4dfae21..be1e7fd 100644 --- a/docker/celery/requirements-dashboard.txt +++ b/docker/celery/requirements-dashboard.txt @@ -2,4 +2,4 @@ celery~=5.4.0 redis~=5.0.8 flower~=2.0.1 setuptools -wheel \ No newline at end of file +wheel diff --git a/docker/celery/requirements.txt b/docker/celery/requirements.txt index d645aa7..2b10e7a 100644 --- a/docker/celery/requirements.txt +++ b/docker/celery/requirements.txt @@ -25,4 +25,4 @@ toml~=0.10.2 celery~=5.4.0 redis~=5.0.8 setuptools -wheel \ No newline at end of file +wheel diff --git a/docker/huggingface-ner/.dockerignore b/docker/huggingface-ner/.dockerignore index 90e2e80..8175385 100644 --- a/docker/huggingface-ner/.dockerignore +++ b/docker/huggingface-ner/.dockerignore @@ -1,2 +1,2 @@ app/model/* -app/mlruns/* \ No newline at end of file +app/mlruns/* diff --git a/docker/huggingface-ner/.env b/docker/huggingface-ner/.env index a281300..0a76c29 100644 --- a/docker/huggingface-ner/.env +++ b/docker/huggingface-ner/.env @@ -1,4 +1,4 @@ ENABLE_TRAINING_APIS=true ENABLE_EVALUATION_APIS=true ENABLE_PREVIEWS_APIS=true -LOG_PER_CONCEPT_ACCURACIES=true \ No newline at end of file +LOG_PER_CONCEPT_ACCURACIES=true diff --git a/docker/huggingface-ner/requirements.txt b/docker/huggingface-ner/requirements.txt index 956fb8b..01093ed 100644 --- a/docker/huggingface-ner/requirements.txt +++ b/docker/huggingface-ner/requirements.txt @@ -24,4 +24,4 @@ pynvml~=11.5.3 toml~=0.10.2 peft<0.14.0 setuptools -wheel \ No newline at end of file +wheel diff --git a/docker/medcat-deid/.dockerignore b/docker/medcat-deid/.dockerignore index 90e2e80..8175385 100644 --- a/docker/medcat-deid/.dockerignore +++ b/docker/medcat-deid/.dockerignore @@ -1,2 +1,2 @@ app/model/* -app/mlruns/* \ No newline at end of file +app/mlruns/* diff --git a/docker/medcat-deid/.env b/docker/medcat-deid/.env index 777e3ac..269517d 100644 --- a/docker/medcat-deid/.env +++ b/docker/medcat-deid/.env @@ -3,4 +3,4 @@ DISABLE_UNSUPERVISED_TRAINING=true ENABLE_EVALUATION_APIS=true ENABLE_PREVIEWS_APIS=true LOG_PER_CONCEPT_ACCURACIES=true -TRAINING_CONCEPT_ID_WHITELIST=O,X,N1000,N1100,N1200,N1300,C2110,C2120,C2200,C2300,D4000,H3100,H3200,H3300,H3400,H3500,H4100,W5000 \ No newline at end of file +TRAINING_CONCEPT_ID_WHITELIST=O,X,N1000,N1100,N1200,N1300,C2110,C2120,C2200,C2300,D4000,H3100,H3200,H3300,H3400,H3500,H4100,W5000 diff --git a/docker/medcat-deid/requirements.txt b/docker/medcat-deid/requirements.txt index 956fb8b..01093ed 100644 --- a/docker/medcat-deid/requirements.txt +++ 
b/docker/medcat-deid/requirements.txt
@@ -24,4 +24,4 @@ pynvml~=11.5.3
 toml~=0.10.2
 peft<0.14.0
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/medcat-icd10/.dockerignore b/docker/medcat-icd10/.dockerignore
index 90e2e80..8175385 100644
--- a/docker/medcat-icd10/.dockerignore
+++ b/docker/medcat-icd10/.dockerignore
@@ -1,2 +1,2 @@
 app/model/*
-app/mlruns/*
\ No newline at end of file
+app/mlruns/*
diff --git a/docker/medcat-icd10/requirements.txt b/docker/medcat-icd10/requirements.txt
index 6b88a64..fcd73a1 100644
--- a/docker/medcat-icd10/requirements.txt
+++ b/docker/medcat-icd10/requirements.txt
@@ -23,4 +23,4 @@ pynvml~=11.5.3
 toml~=0.10.2
 peft<0.14.0
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/medcat-snomed/.dockerignore b/docker/medcat-snomed/.dockerignore
index 90e2e80..8175385 100644
--- a/docker/medcat-snomed/.dockerignore
+++ b/docker/medcat-snomed/.dockerignore
@@ -1,2 +1,2 @@
 app/model/*
-app/mlruns/*
\ No newline at end of file
+app/mlruns/*
diff --git a/docker/medcat-snomed/requirements.txt b/docker/medcat-snomed/requirements.txt
index 6b88a64..fcd73a1 100644
--- a/docker/medcat-snomed/requirements.txt
+++ b/docker/medcat-snomed/requirements.txt
@@ -23,4 +23,4 @@ pynvml~=11.5.3
 toml~=0.10.2
 peft<0.14.0
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/medcat-umls/.dockerignore b/docker/medcat-umls/.dockerignore
index 90e2e80..8175385 100644
--- a/docker/medcat-umls/.dockerignore
+++ b/docker/medcat-umls/.dockerignore
@@ -1,2 +1,2 @@
 app/model/*
-app/mlruns/*
\ No newline at end of file
+app/mlruns/*
diff --git a/docker/medcat-umls/requirements.txt b/docker/medcat-umls/requirements.txt
index 6b88a64..fcd73a1 100644
--- a/docker/medcat-umls/requirements.txt
+++ b/docker/medcat-umls/requirements.txt
@@ -23,4 +23,4 @@ pynvml~=11.5.3
 toml~=0.10.2
 peft<0.14.0
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/mlflow/deployments/requirements.txt b/docker/mlflow/deployments/requirements.txt
index 648fe1a..aec5440 100644
--- a/docker/mlflow/deployments/requirements.txt
+++ b/docker/mlflow/deployments/requirements.txt
@@ -1,3 +1,3 @@
 mlflow[genai]~=2.16.2
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/mlflow/deployments/scripts/run.sh b/docker/mlflow/deployments/scripts/run.sh
index 4a9419f..ca99bd9 100644
--- a/docker/mlflow/deployments/scripts/run.sh
+++ b/docker/mlflow/deployments/scripts/run.sh
@@ -1,3 +1,3 @@
 #!/bin/sh
 
-mlflow deployments start-server --config-path /opt/config/config.yml --host 0.0.0.0 --port 7000
\ No newline at end of file
+mlflow deployments start-server --config-path /opt/config/config.yml --host 0.0.0.0 --port 7000
diff --git a/docker/mlflow/models/requirements.txt b/docker/mlflow/models/requirements.txt
index d36d973..ad82c4d 100644
--- a/docker/mlflow/models/requirements.txt
+++ b/docker/mlflow/models/requirements.txt
@@ -2,4 +2,4 @@ mlflow~=2.16.2
 psycopg2-binary~=2.9.4
 boto3~=1.28.84
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/mlflow/server/Dockerfile b/docker/mlflow/server/Dockerfile
index b4adfdf..bc497ed 100644
--- a/docker/mlflow/server/Dockerfile
+++ b/docker/mlflow/server/Dockerfile
@@ -33,4 +33,4 @@ COPY docker/mlflow/server/scripts/run.sh ${MLFLOW_HOME}/../scripts/run.sh
 RUN chmod +x ${MLFLOW_HOME}/../scripts/run.sh
 
 WORKDIR ${MLFLOW_HOME}
-ENTRYPOINT ["./../scripts/run.sh"]
\ No newline at end of file
+ENTRYPOINT ["./../scripts/run.sh"]
diff --git a/docker/mlflow/server/auth/basic_auth.ini b/docker/mlflow/server/auth/basic_auth.ini
index 82e88a9..c87f550 100644
--- a/docker/mlflow/server/auth/basic_auth.ini
+++ b/docker/mlflow/server/auth/basic_auth.ini
@@ -2,4 +2,4 @@
 default_permission = NO_PERMISSIONS
 database_uri = sqlite:///basic_auth.db
 admin_username = admin
-admin_password = CHANGE_ME
\ No newline at end of file
+admin_password = CHANGE_ME
diff --git a/docker/mlflow/server/requirements.txt b/docker/mlflow/server/requirements.txt
index d36d973..ad82c4d 100644
--- a/docker/mlflow/server/requirements.txt
+++ b/docker/mlflow/server/requirements.txt
@@ -2,4 +2,4 @@ mlflow~=2.16.2
 psycopg2-binary~=2.9.4
 boto3~=1.28.84
 setuptools
-wheel
\ No newline at end of file
+wheel
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_containers.json b/docker/monitoring/grafana/provisioning/dashboards/cms_containers.json
index 20662e4..6f2d558 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_containers.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_containers.json
@@ -815,4 +815,4 @@
   "uid": "pMEd7m0Mz",
   "version": 3,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_deid_medcat.json b/docker/monitoring/grafana/provisioning/dashboards/cms_deid_medcat.json
index 4eea61c..3633aed 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_deid_medcat.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_deid_medcat.json
@@ -3179,4 +3179,4 @@
   "uid": "4Rjky-h4k",
   "version": 7,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_deid_trf.json b/docker/monitoring/grafana/provisioning/dashboards/cms_deid_trf.json
index b28567f..64cf789 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_deid_trf.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_deid_trf.json
@@ -2620,4 +2620,4 @@
   "uid": "wOX1s-h4k",
   "version": 7,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_huggingface_ner.json b/docker/monitoring/grafana/provisioning/dashboards/cms_huggingface_ner.json
index 453aaeb..90b5f69 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_huggingface_ner.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_huggingface_ner.json
@@ -3179,4 +3179,4 @@
   "uid": "uCpEt1QzW",
   "version": 1,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_icd10_medcat.json b/docker/monitoring/grafana/provisioning/dashboards/cms_icd10_medcat.json
index 7c4fbbe..3a0cbcd 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_icd10_medcat.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_icd10_medcat.json
@@ -2834,4 +2834,4 @@
   "uid": "CsBePa2Vz",
   "version": 11,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_prometheus.json b/docker/monitoring/grafana/provisioning/dashboards/cms_prometheus.json
index a01a8ad..67b9610 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_prometheus.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_prometheus.json
@@ -3704,4 +3704,4 @@
   "uid": "thnO_Gx4k",
   "version": 4,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_snomed_medcat.json b/docker/monitoring/grafana/provisioning/dashboards/cms_snomed_medcat.json
index d4e51be..ef8b614 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_snomed_medcat.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_snomed_medcat.json
@@ -2834,4 +2834,4 @@
   "uid": "OSXfX1h4z",
   "version": 87,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/cms_umls_medcat.json b/docker/monitoring/grafana/provisioning/dashboards/cms_umls_medcat.json
index 71a6476..25720d0 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/cms_umls_medcat.json
+++ b/docker/monitoring/grafana/provisioning/dashboards/cms_umls_medcat.json
@@ -2834,4 +2834,4 @@
   "uid": "c3Lgs-hVk",
   "version": 7,
   "weekStart": ""
-}
\ No newline at end of file
+}
diff --git a/docker/monitoring/grafana/provisioning/dashboards/dashboard.yml b/docker/monitoring/grafana/provisioning/dashboards/dashboard.yml
index 41b090e..603e3b5 100644
--- a/docker/monitoring/grafana/provisioning/dashboards/dashboard.yml
+++ b/docker/monitoring/grafana/provisioning/dashboards/dashboard.yml
@@ -10,4 +10,4 @@ providers:
     editable: true
     allowUiUpdates: true
    options:
-      path: /etc/grafana/provisioning/dashboards
\ No newline at end of file
+      path: /etc/grafana/provisioning/dashboards
diff --git a/docker/monitoring/grafana/provisioning/datasources/datasource.yml b/docker/monitoring/grafana/provisioning/datasources/datasource.yml
index a5dc55b..ec7a774 100644
--- a/docker/monitoring/grafana/provisioning/datasources/datasource.yml
+++ b/docker/monitoring/grafana/provisioning/datasources/datasource.yml
@@ -8,4 +8,4 @@ datasources:
     url: http://prometheus:9090
     basicAuth: false
     isDefault: true
-    editable: true
\ No newline at end of file
+    editable: true
diff --git a/docker/monitoring/prometheus/alertmanager.yml b/docker/monitoring/prometheus/alertmanager.yml
index b327fd7..67a1b8f 100644
--- a/docker/monitoring/prometheus/alertmanager.yml
+++ b/docker/monitoring/prometheus/alertmanager.yml
@@ -36,4 +36,4 @@ receivers:
         smarthost: smtp.gmail.com:587
         auth_username: first.last@gmail.com
         auth_identity: first.last@gmail.com
-        auth_password:
\ No newline at end of file
+        auth_password:
diff --git a/docker/nginx/Dockerfile b/docker/nginx/Dockerfile
index a4d2d0f..da601a4 100644
--- a/docker/nginx/Dockerfile
+++ b/docker/nginx/Dockerfile
@@ -1,4 +1,4 @@
 FROM nginx:1.23.0
 
 RUN rm /etc/nginx/conf.d/default.conf &&\
-    rm /etc/nginx/nginx.conf
\ No newline at end of file
+    rm /etc/nginx/nginx.conf
diff --git a/docker/nginx/etc/nginx/cors.conf b/docker/nginx/etc/nginx/cors.conf
index 1267452..abc9104 100644
--- a/docker/nginx/etc/nginx/cors.conf
+++ b/docker/nginx/etc/nginx/cors.conf
@@ -52,4 +52,4 @@ if ($cors_method = 'noopt') {
     add_header 'Access-Control-Allow-Credentials' 'true' always;
     add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
     add_header 'Access-Control-Allow-Headers' 'Accept,Authorization,Cache-Control,Content-Type,DNT,If-Modified-Since,Keep-Alive,Origin,User-Agent,X-Requested-With' always;
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/nginx.conf b/docker/nginx/etc/nginx/nginx.conf
index 828597d..cb2e092 100644
--- a/docker/nginx/etc/nginx/nginx.conf
+++ b/docker/nginx/etc/nginx/nginx.conf
@@ -34,4 +34,4 @@ http {
     include sites-enabled/prometheus;
     include sites-enabled/grafana;
     include sites-enabled/graylog;
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/de-identification b/docker/nginx/etc/nginx/sites-enabled/de-identification
index dfc0a44..439571a 100644
--- a/docker/nginx/etc/nginx/sites-enabled/de-identification
+++ b/docker/nginx/etc/nginx/sites-enabled/de-identification
@@ -27,4 +27,4 @@ server {
         proxy_pass http://de-identification:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/grafana b/docker/nginx/etc/nginx/sites-enabled/grafana
index e44003e..da4ccfa 100644
--- a/docker/nginx/etc/nginx/sites-enabled/grafana
+++ b/docker/nginx/etc/nginx/sites-enabled/grafana
@@ -27,4 +27,4 @@ server {
         proxy_pass http://grafana:3000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/graylog b/docker/nginx/etc/nginx/sites-enabled/graylog
index a659f20..88a673e 100644
--- a/docker/nginx/etc/nginx/sites-enabled/graylog
+++ b/docker/nginx/etc/nginx/sites-enabled/graylog
@@ -27,4 +27,4 @@ server {
         proxy_pass http://graylog:9000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/huggingface-ner b/docker/nginx/etc/nginx/sites-enabled/huggingface-ner
index 455c574..6d8200b 100644
--- a/docker/nginx/etc/nginx/sites-enabled/huggingface-ner
+++ b/docker/nginx/etc/nginx/sites-enabled/huggingface-ner
@@ -27,4 +27,4 @@ server {
         proxy_pass http://huggingface-ner:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/medcat-deid b/docker/nginx/etc/nginx/sites-enabled/medcat-deid
index b6b1439..5620be3 100644
--- a/docker/nginx/etc/nginx/sites-enabled/medcat-deid
+++ b/docker/nginx/etc/nginx/sites-enabled/medcat-deid
@@ -27,4 +27,4 @@ server {
         proxy_pass http://medcat-deid:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/medcat-icd10 b/docker/nginx/etc/nginx/sites-enabled/medcat-icd10
index 8eb4abb..034c7b8 100644
--- a/docker/nginx/etc/nginx/sites-enabled/medcat-icd10
+++ b/docker/nginx/etc/nginx/sites-enabled/medcat-icd10
@@ -27,4 +27,4 @@ server {
         proxy_pass http://medcat-icd10:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/medcat-snomed b/docker/nginx/etc/nginx/sites-enabled/medcat-snomed
index 69d3b67..2f528b6 100644
--- a/docker/nginx/etc/nginx/sites-enabled/medcat-snomed
+++ b/docker/nginx/etc/nginx/sites-enabled/medcat-snomed
@@ -27,4 +27,4 @@ server {
         proxy_pass http://medcat-snomed:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/medcat-umls b/docker/nginx/etc/nginx/sites-enabled/medcat-umls
index e254e56..a98d329 100644
--- a/docker/nginx/etc/nginx/sites-enabled/medcat-umls
+++ b/docker/nginx/etc/nginx/sites-enabled/medcat-umls
@@ -27,4 +27,4 @@ server {
         proxy_pass http://medcat-umls:8000/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/minio b/docker/nginx/etc/nginx/sites-enabled/minio
index f9093df..b49e275 100644
--- a/docker/nginx/etc/nginx/sites-enabled/minio
+++ b/docker/nginx/etc/nginx/sites-enabled/minio
@@ -27,4 +27,4 @@ server {
         proxy_pass http://minio:9001/;
         proxy_set_header Host $host;
     }
-}
\ No newline at end of file
+}
diff --git a/docker/nginx/etc/nginx/sites-enabled/mlflow-ui b/docker/nginx/etc/nginx/sites-enabled/mlflow-ui
index 1cafb07..7fb6424 100644
--- a/docker/nginx/etc/nginx/sites-enabled/mlflow-ui
+++ b/docker/nginx/etc/nginx/sites-enabled/mlflow-ui
@@ -31,4 +31,4 @@ server {
proxy_pass http://mlflow-ui:5000/; proxy_set_header Host $host; } -} \ No newline at end of file +} diff --git a/docker/nginx/etc/nginx/sites-enabled/prometheus b/docker/nginx/etc/nginx/sites-enabled/prometheus index b66e2d7..289b4bf 100644 --- a/docker/nginx/etc/nginx/sites-enabled/prometheus +++ b/docker/nginx/etc/nginx/sites-enabled/prometheus @@ -29,4 +29,4 @@ server { proxy_pass http://prometheus:9090/; proxy_set_header Host $host; } -} \ No newline at end of file +} diff --git a/docker/trf-deid/.dockerignore b/docker/trf-deid/.dockerignore index 90e2e80..8175385 100644 --- a/docker/trf-deid/.dockerignore +++ b/docker/trf-deid/.dockerignore @@ -1,2 +1,2 @@ app/model/* -app/mlruns/* \ No newline at end of file +app/mlruns/* diff --git a/docker/trf-deid/requirements.txt b/docker/trf-deid/requirements.txt index 5acf679..0399321 100644 --- a/docker/trf-deid/requirements.txt +++ b/docker/trf-deid/requirements.txt @@ -23,4 +23,4 @@ pynvml~=11.5.3 toml~=0.10.2 peft<0.14.0 setuptools -wheel \ No newline at end of file +wheel diff --git a/pyproject.toml b/pyproject.toml index f14e8a3..f1c5845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ Documentation = "https://cogstack.github.io/CogStack-ModelServe/docs/cogstack_mo [project.optional-dependencies] dev = [ + "pre-commit~=4.0.1", "pytest~=7.1.2", "pytest-mock~=3.7.0", "pytest-timeout~=2.1.0", @@ -77,43 +78,16 @@ version = {attr = "app.__version__"} cms = "app.cli.cli:cmd_app" [tool.ruff] -include = [ - "app/model_services/*.py", - "app/management/*.py", - "app/processors/*.py", - "app/api/*.py", - "app/api/routers/*.py", - "app/api/auth/*.py", - "app/cli/*.py", - "app/trainers/*.py", - "app/*.py", - "scripts/generate_annotations.py", - "scripts/hf_model_packager.py", - "scripts/medcat_concept_diff.py", - "scripts/medcat_config_diff.py", - "scripts/remove_model_version.py", - "scripts/generate_annotations.py", - "tests/", -] - -line-length = 120 -indent-width = 4 +line-length = 100 target-version = "py310" [tool.ruff.lint] -select = ["E", "F", "W", "C90"] -ignore = ["E501", "E226", "C901"] -fixable = ["ALL"] -unfixable = [] -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +select = ["E", "F", "I", "W", "C90"] +ignore = ["C901"] -[tool.ruff.format] -quote-style = "double" -indent-style = "space" -skip-magic-trailing-comma = false -line-ending = "auto" -docstring-code-format = false -docstring-code-line-length = "dynamic" +[tool.ruff.lint.isort] +known-first-party = ["config", "domain", "exception", "helper", "registry", "utils"] +known-local-folder = ["api", "cli", "data", "management", "model_services", "processors", "trainers"] [tool.mypy] ignore_missing_imports = true @@ -136,4 +110,4 @@ init_typed = true warn_required_dynamic_aliases = true [tool.pytest.ini_options] -pythonpath = "./app" \ No newline at end of file +pythonpath = "./app" diff --git a/scripts/generate_annotations.py b/scripts/generate_annotations.py index dd3dede..70419d9 100755 --- a/scripts/generate_annotations.py +++ b/scripts/generate_annotations.py @@ -4,13 +4,16 @@ import random import sys from argparse import ArgumentParser -from typing import List, Dict +from typing import Dict, List + from medcat.cat import CAT from spacy.lang.en import English from tqdm.autonotebook import tqdm -def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2original_names: Dict) -> Dict: +def generate_annotations( + cuis: List, texts: List, minimum_words: int, cui2original_names: Dict +) -> Dict: original_names = {cui: 
cui2original_names[cui] for cui in cuis if cui in cui2original_names} new_snames = {} for cui, names in original_names.items(): @@ -30,10 +33,12 @@ def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2origin } patterns.append(pattern) ruler = nlp.add_pipe("entity_ruler", config={"phrase_matcher_attr": "LOWER"}) - ruler.add_patterns(patterns) # type: ignore + ruler.add_patterns(patterns) # type: ignore documents = [] - for doc_id, text in enumerate(tqdm(texts, desc="Evaluating projects", total=len(texts), leave=False)): + for doc_id, text in enumerate( + tqdm(texts, desc="Evaluating projects", total=len(texts), leave=False) + ): doc = nlp(text) annotations = [] for ent in doc.ents: @@ -65,7 +70,7 @@ def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2origin "--cuis", type=str, default="", - help="The path to the file containing newline-separated CUIs" + help="The path to the file containing newline-separated CUIs", ) parser.add_argument( "-t", @@ -85,13 +90,10 @@ def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2origin "--min-words", type=int, default=1, - help="The lowest number of words each generated annotation will have" + help="The lowest number of words each generated annotation will have", ) parser.add_argument( - "-m", - "--model-pack-path", - type=str, - help="The path to the first model package" + "-m", "--model-pack-path", type=str, help="The path to the first model package" ) parser.add_argument( "-o", @@ -101,17 +103,32 @@ def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2origin ) FLAGS, unparsed = parser.parse_known_args() + error_prefix = "ERROR:" if FLAGS.cuis == "": - print("ERROR: The path to the CUI file is empty. Use '-c' to pass in the file containing newline-separated CUIs.") + print( + error_prefix, + "The path to the CUI file is empty." + "Use '-c' to pass in the file containing newline-separated CUIs.", + ) sys.exit(1) if FLAGS.texts == "": - print("ERROR: The path to the text file is empty. Use '-t' to pass in the file containing texts as a JSON list.") + print( + error_prefix, + "The path to the text file is empty." + "Use '-t' to pass in the file containing texts as a JSON list.", + ) sys.exit(1) if FLAGS.model_pack_path == "": - print("ERROR: The path to the model package is empty. Use '-m' to pass in the model pack path.") + print( + error_prefix, + "The path to the model package is empty. Use '-m' to pass in the model pack path.", + ) sys.exit(1) if FLAGS.output == "": - print("ERROR: The path to the output file is empty. Use '-o' to pass in the file of annotations.") + print( + error_prefix, + "The path to the output file is empty. 
Use '-o' to pass in the file of annotations.", + ) sys.exit(1) with open(FLAGS.cuis, "r") as f: @@ -123,7 +140,9 @@ def generate_annotations(cuis: List, texts: List, minimum_words: int, cui2origin texts = random.sample(texts, FLAGS.sample_size) cat = CAT.load_model_pack(FLAGS.model_pack_path) - annotations = generate_annotations(cuis, texts, FLAGS.min_words, cat.cdb.addl_info["cui2original_names"]) + annotations = generate_annotations( + cuis, texts, FLAGS.min_words, cat.cdb.addl_info["cui2original_names"] + ) with open(FLAGS.output, "w") as f: json.dump(annotations, f, indent=4) diff --git a/scripts/medcat_concept_diff.py b/scripts/medcat_concept_diff.py index 90df249..302fbd0 100755 --- a/scripts/medcat_concept_diff.py +++ b/scripts/medcat_concept_diff.py @@ -1,10 +1,11 @@ #!/usr/bin/env python +import difflib import os import sys -import difflib -import jsonpickle from argparse import ArgumentParser + +import jsonpickle from medcat.cat import CAT jsonpickle.set_encoder_options("json", sort_keys=True, indent=4) @@ -12,11 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "-a", - "--model-pack-path", - type=str, - default="", - help="The path to the first model pack" + "-a", "--model-pack-path", type=str, default="", help="The path to the first model pack" ) parser.add_argument( "-b", @@ -28,7 +25,7 @@ "-p", "--with-preferred-name", action="store_true", - help="Print preferred names of concepts as the second column" + help="Print preferred names of concepts as the second column", ) FLAGS, unparsed = parser.parse_known_args() diff --git a/scripts/medcat_config_diff.py b/scripts/medcat_config_diff.py index bce2ef8..5a17c71 100755 --- a/scripts/medcat_config_diff.py +++ b/scripts/medcat_config_diff.py @@ -1,10 +1,11 @@ #!/usr/bin/env python +import difflib import os import sys -import difflib -import jsonpickle from argparse import ArgumentParser + +import jsonpickle from medcat.cat import CAT jsonpickle.set_encoder_options("json", sort_keys=True, indent=4) @@ -12,11 +13,7 @@ if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "-a", - "--model-pack-path", - type=str, - default="", - help="The path to the first model pack" + "-a", "--model-pack-path", type=str, default="", help="The path to the first model pack" ) parser.add_argument( "-b", @@ -38,9 +35,11 @@ cat_a = CAT.load_model_pack(model_pack_path) cat_b = CAT.load_model_pack(another_model_pack_path) json_string_a = jsonpickle.encode( - {field: getattr(cat_a.cdb.config, field) for field in cat_a.cdb.config.fields()}) + {field: getattr(cat_a.cdb.config, field) for field in cat_a.cdb.config.fields()} + ) json_string_b = jsonpickle.encode( - {field: getattr(cat_b.cdb.config, field) for field in cat_b.cdb.config.fields()}) + {field: getattr(cat_b.cdb.config, field) for field in cat_b.cdb.config.fields()} + ) print(f"--- a|{model_pack_path}") print(f"+++ b|{another_model_pack_path}") diff --git a/scripts/remove_model_version.py b/scripts/remove_model_version.py index 70850fc..e9ed85b 100755 --- a/scripts/remove_model_version.py +++ b/scripts/remove_model_version.py @@ -2,14 +2,17 @@ import sys from argparse import ArgumentParser from typing import Union -from mlflow.tracking import MlflowClient + from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository +from mlflow.tracking import MlflowClient ARTIFACTS_DESTINATION = "s3://cms-model-bucket/" DEFAULT_ARTIFACT_ROOT = "mlflow-artifacts:/" -def _remove_model_version(client: MlflowClient, model_name: str, 
model_version: Union[int, str]) -> None: +def _remove_model_version( + client: MlflowClient, model_name: str, model_version: Union[int, str] +) -> None: versions = client.search_model_versions(f"name='{model_name}'") model_version = str(model_version) if versions is None or len(versions) == 0: @@ -23,10 +26,11 @@ def _remove_model_version(client: MlflowClient, model_name: str, model_version: raise ValueError("You cannot delete models which have not been archived!") if m_version.status != "READY": raise ValueError("You cannot delete models which are not ready yet!") - confirm = input(""" -You cannot undo this action. When you delete a model, all model artifacts stored by the Model Registry -and all the metadata associated with the registered model are deleted. Do you want to proceed? (y/n) - """).lower() == ("y" or "yes") + confirm = input( + "You cannot undo this action. When you delete a model, all model artifacts stored" + " by the Model Registry and all the metadata associated with the registered model" + " are deleted. Do you want to proceed? (y/n)" + ).lower() == ("y" or "yes") if confirm: client.delete_model_version(name=model_name, version=model_version) print(f"Version '{model_version}' of model '{model_name}' was deleted") @@ -40,7 +44,9 @@ def _remove_model_version(client: MlflowClient, model_name: str, model_version: if deleted and run_id: run = client.get_run(run_id) - artifact_repo = get_artifact_repository(run.info.artifact_uri.replace(ARTIFACTS_DESTINATION, DEFAULT_ARTIFACT_ROOT)) + artifact_repo = get_artifact_repository( + run.info.artifact_uri.replace(ARTIFACTS_DESTINATION, DEFAULT_ARTIFACT_ROOT) + ) artifact_repo.delete_artifacts() print(f"Artifacts for version '{model_version}' of model '{model_name}' were deleted") @@ -48,25 +54,21 @@ def _remove_model_version(client: MlflowClient, model_name: str, model_version: if __name__ == "__main__": parser = ArgumentParser() parser.add_argument( - "-u", - "--mlflow-tracking-uri", - type=str, - default="", - help="The MLflow tracking URI" + "-u", "--mlflow-tracking-uri", type=str, default="", help="The MLflow tracking URI" ) parser.add_argument( "-n", "--mlflow-model-name", type=str, default="", - help="The name of the registered MLflow model" + help="The name of the registered MLflow model", ) parser.add_argument( "-v", "--mlflow-model-version", type=str, default="", - help="The version of the registered MLflow model" + help="The version of the registered MLflow model", ) FLAGS, unparsed = parser.parse_known_args() if FLAGS.mlflow_tracking_uri == "": diff --git a/scripts/stop_cms_containers b/scripts/stop_cms_containers index e0bc7c0..d09c51c 100755 --- a/scripts/stop_cms_containers +++ b/scripts/stop_cms_containers @@ -29,4 +29,4 @@ if [ -n "$container_ids" ]; then docker stop $container_ids else echo "No CMS containers are running" -fi \ No newline at end of file +fi diff --git a/tests/app/api/auth/test_db.py b/tests/app/api/auth/test_db.py index 3ce7fa0..efc5e05 100644 --- a/tests/app/api/auth/test_db.py +++ b/tests/app/api/auth/test_db.py @@ -1,6 +1,7 @@ import pytest from fastapi_users.db import SQLAlchemyUserDatabase -from api.auth.db import make_sure_db_and_tables, get_user_db + +from api.auth.db import get_user_db, make_sure_db_and_tables @pytest.mark.asyncio diff --git a/tests/app/api/auth/test_schema.py b/tests/app/api/auth/test_schema.py index 957c281..5bd564b 100644 --- a/tests/app/api/auth/test_schema.py +++ b/tests/app/api/auth/test_schema.py @@ -1,6 +1,7 @@ -from api.auth.schemas import UserRead, UserCreate, 
UserUpdate from fastapi_users import schemas +from api.auth.schemas import UserCreate, UserRead, UserUpdate + def test_import(): issubclass(UserRead, schemas.BaseUser) diff --git a/tests/app/api/auth/test_users.py b/tests/app/api/auth/test_users.py index 368a4b3..f0b62e5 100644 --- a/tests/app/api/auth/test_users.py +++ b/tests/app/api/auth/test_users.py @@ -1,5 +1,6 @@ -from fastapi_users.authentication.backend import AuthenticationBackend from fastapi_users import FastAPIUsers +from fastapi_users.authentication.backend import AuthenticationBackend + from api.auth.users import Props diff --git a/tests/app/api/test_api.py b/tests/app/api/test_api.py index 5ed361c..a2735f8 100644 --- a/tests/app/api/test_api.py +++ b/tests/app/api/test_api.py @@ -1,6 +1,7 @@ +from utils import get_settings + from api.api import get_model_server, get_stream_server from api.dependencies import ModelServiceDep -from utils import get_settings def test_get_model_server(): @@ -21,11 +22,23 @@ def test_get_model_server(): assert isinstance(info["summary"], str) assert isinstance(info["version"], str) assert {"name": "Metadata", "description": "Get the model card"} in tags - assert {"name": "Annotations", "description": "Retrieve NER entities by running the model"} in tags + assert { + "name": "Annotations", + "description": "Retrieve NER entities by running the model", + } in tags assert {"name": "Redaction", "description": "Redact the extracted NER entities"} in tags - assert {"name": "Rendering", "description": "Preview embeddable annotation snippet in HTML"} in tags - assert {"name": "Training", "description": "Trigger model training on input annotations"} in tags - assert {"name": "Evaluating", "description": "Evaluate the deployed model with trainer export"} in tags + assert { + "name": "Rendering", + "description": "Preview embeddable annotation snippet in HTML", + } in tags + assert { + "name": "Training", + "description": "Trigger model training on input annotations", + } in tags + assert { + "name": "Evaluating", + "description": "Evaluate the deployed model with trainer export", + } in tags assert {"name": "Authentication", "description": "Authenticate registered users"} in tags assert "/info" in paths assert "/process" in paths @@ -64,7 +77,10 @@ def test_get_stream_server(): assert isinstance(info["title"], str) assert isinstance(info["summary"], str) assert isinstance(info["version"], str) - assert {"name": "Streaming", "description": "Retrieve NER entities as a stream by running the model"} in tags + assert { + "name": "Streaming", + "description": "Retrieve NER entities as a stream by running the model", + } in tags assert "/stream/process" in paths assert "/stream/ws" in paths assert "/auth/jwt/login" in paths diff --git a/tests/app/api/test_dependencies.py b/tests/app/api/test_dependencies.py index 87be16d..0d0f02c 100644 --- a/tests/app/api/test_dependencies.py +++ b/tests/app/api/test_dependencies.py @@ -1,14 +1,15 @@ import pytest from fastapi import HTTPException -from api.dependencies import ModelServiceDep, validate_tracking_id from config import Settings + +from api.dependencies import ModelServiceDep, validate_tracking_id +from model_services.huggingface_ner_model import HuggingFaceNerModel from model_services.medcat_model import MedCATModel +from model_services.medcat_model_deid import MedCATModelDeIdentification from model_services.medcat_model_icd10 import MedCATModelIcd10 from model_services.medcat_model_umls import MedCATModelUmls -from model_services.medcat_model_deid import 
MedCATModelDeIdentification from model_services.trf_model_deid import TransformersModelDeIdentification -from model_services.huggingface_ner_model import HuggingFaceNerModel def test_medcat_snomed_dep(): diff --git a/tests/app/api/test_serving_common.py b/tests/app/api/test_serving_common.py index b0be036..7b02ea5 100644 --- a/tests/app/api/test_serving_common.py +++ b/tests/app/api/test_serving_common.py @@ -1,17 +1,19 @@ +import json import os import tempfile +from unittest.mock import create_autospec import httpx -import json import pytest -import api.globals as cms_globals from fastapi.testclient import TestClient -from api.api import get_model_server + from domain import ModelCard, ModelType from utils import get_settings -from model_services.medcat_model import MedCATModel + +import api.globals as cms_globals +from api.api import get_model_server from management.model_manager import ModelManager -from unittest.mock import create_autospec +from model_services.medcat_model import MedCATModel config = get_settings() config.ENABLE_TRAINING_APIS = "true" @@ -21,11 +23,22 @@ config.AUTH_USER_ENABLED = "true" TRACKING_ID = "123e4567-e89b-12d3-a456-426614174000" -TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json") +TRAINER_EXPORT_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json" +) NOTE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "note.txt") -ANOTHER_TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "another_trainer_export.json") -TRAINER_EXPORT_MULTI_PROJS_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json") -MULTI_TEXTS_FILE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json") +ANOTHER_TRAINER_EXPORT_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "another_trainer_export.json" +) +TRAINER_EXPORT_MULTI_PROJS_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json" +) +MULTI_TEXTS_FILE_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json" +) +PUBLIC_KEY_PEM_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "public_key.pem" +) @pytest.fixture(scope="function") @@ -43,87 +56,85 @@ def client(model_service): def test_process_invalid_jsonl(model_service, client): - annotations = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] + annotations = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] model_service.annotate.return_value = annotations model_manager = ModelManager(None, None) model_manager.model_service = model_service cms_globals.model_manager_dep = lambda: model_manager - response = client.post("/process_jsonl", - data="invalid json lines", - headers={"Content-Type": "application/x-ndjson"}) + response = client.post( + "/process_jsonl", + data="invalid json lines", + headers={"Content-Type": "application/x-ndjson"}, + ) 
assert response.status_code == 400 assert response.json() == {"message": "Invalid JSON Lines."} def test_process_unknown_jsonl_properties(model_service, client): - annotations = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] + annotations = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] model_service.annotate.return_value = annotations model_manager = ModelManager(None, None) model_manager.model_service = model_service cms_globals.model_manager_dep = lambda: model_manager - response = client.post("/process_jsonl", - data='{"unknown": "doc1", "text": "Spinal stenosis"}\n{"unknown": "doc2", "text": "Spinal stenosis"}', - headers={"Content-Type": "application/x-ndjson"}) + response = client.post( + "/process_jsonl", + data=( + '{"unknown": "doc1", "text": "Spinal stenosis"}\n' + '{"unknown": "doc2", "text": "Spinal stenosis"}' + ), + headers={"Content-Type": "application/x-ndjson"}, + ) assert response.status_code == 400 assert "Invalid JSON properties found." in response.json()["message"] def test_redact_with_white_list(model_service, client): - annotations = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] + annotations = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] concepts_to_keep = ["76107001"] url = f"/redact?concepts_to_keep={','.join(concepts_to_keep)}" - + model_service.annotate.return_value = annotations + response = client.post(url, data="Spinal stenosis", headers={"Content-Type": "text/plain"}) - response = client.post(url, - data="Spinal stenosis", - headers={"Content-Type": "text/plain"}) - assert response.text == "Spinal stenosis" @@ -131,37 +142,41 @@ def test_warning_on_no_redaction(model_service, client): annotations = [] model_service.annotate.return_value = annotations - response = client.post("/redact?warn_on_no_redaction=true", - data="Spinal stenosis", - headers={"Content-Type": "text/plain"}) + response = client.post( + "/redact?warn_on_no_redaction=true", + data="Spinal stenosis", + headers={"Content-Type": "text/plain"}, + ) assert response.text == "WARNING: No entities were detected for redaction." 
def test_redact_with_encryption(model_service, client): - annotations = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] + annotations = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] + + with open(PUBLIC_KEY_PEM_PATH, "r") as f: + public_key_pem = f.read() + body = { "text": "Spinal stenosis", - "public_key_pem": "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3ITkTP8Tm/5FygcwY2EQ7LgVsuCF0OH7psUqvlXnOPNCfX86CobHBiSFjG9o5ZeajPtTXaf1thUodgpJZVZSqpVTXwGKo8r0COMO87IcwYigkZZgG/WmZgoZART+AA0+JvjFGxflJAxSv7puGlf82E+u5Wz2psLBSDO5qrnmaDZTvPh5eX84cocahVVI7X09/kI+sZiKauM69yoy1bdx16YIIeNm0M9qqS3tTrjouQiJfZ8jUKSZ44Na/81LMVw5O46+5GvwD+OsR43kQ0TexMwgtHxQQsiXLWHCDNy2ZzkzukDYRwA3V2lwVjtQN0WjxHg24BTBDBM+v7iQ7cbweQIDAQAB\n-----END PUBLIC KEY-----" + "public_key_pem": public_key_pem, } model_service.annotate.return_value = annotations - response = client.post("/redact_with_encryption", - json=body, - headers={"Content-Type": "application/json"}) + response = client.post( + "/redact_with_encryption", json=body, headers={"Content-Type": "application/json"} + ) assert response.json()["redacted_text"] == "[REDACTED_0]" assert len(response.json()["encryptions"]) == 1 @@ -172,15 +187,19 @@ def test_redact_with_encryption(model_service, client): def test_warning_on_no_encrypted_redaction(model_service, client): annotations = [] + with open(PUBLIC_KEY_PEM_PATH, "r") as f: + public_key_pem = f.read() body = { "text": "Spinal stenosis", - "public_key_pem": "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3ITkTP8Tm/5FygcwY2EQ7LgVsuCF0OH7psUqvlXnOPNCfX86CobHBiSFjG9o5ZeajPtTXaf1thUodgpJZVZSqpVTXwGKo8r0COMO87IcwYigkZZgG/WmZgoZART+AA0+JvjFGxflJAxSv7puGlf82E+u5Wz2psLBSDO5qrnmaDZTvPh5eX84cocahVVI7X09/kI+sZiKauM69yoy1bdx16YIIeNm0M9qqS3tTrjouQiJfZ8jUKSZ44Na/81LMVw5O46+5GvwD+OsR43kQ0TexMwgtHxQQsiXLWHCDNy2ZzkzukDYRwA3V2lwVjtQN0WjxHg24BTBDBM+v7iQ7cbweQIDAQAB\n-----END PUBLIC KEY-----" + "public_key_pem": public_key_pem, } model_service.annotate.return_value = annotations - response = client.post("/redact_with_encryption?warn_on_no_redaction=true", - json=body, - headers={"Content-Type": "application/json"}) + response = client.post( + "/redact_with_encryption?warn_on_no_redaction=true", + json=body, + headers={"Content-Type": "application/json"}, + ) assert response.json()["message"] == "WARNING: No entities were detected for redaction." 
@@ -188,10 +207,13 @@ def test_warning_on_no_encrypted_redaction(model_service, client): def test_preview_trainer_export(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/preview_trainer_export", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/preview_trainer_export", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" @@ -200,10 +222,13 @@ def test_preview_trainer_export(client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post(f"/preview_trainer_export?tracking_id={TRACKING_ID}", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + f"/preview_trainer_export?tracking_id={TRACKING_ID}", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" @@ -214,7 +239,11 @@ def test_preview_trainer_export(client): def test_preview_trainer_export_str(client): with open(TRAINER_EXPORT_PATH, "r") as f: trainer_export_str = f.read() - response = client.post("/preview_trainer_export", data={"trainer_export_str": trainer_export_str}, headers={"Content-Type": "application/x-www-form-urlencoded"}) + response = client.post( + "/preview_trainer_export", + data={"trainer_export_str": trainer_export_str}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" @@ -223,7 +252,10 @@ def test_preview_trainer_export_str(client): def test_preview_trainer_export_with_project_id(client): with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post("/preview_trainer_export?project_id=14", files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}) + response = client.post( + "/preview_trainer_export?project_id=14", + files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}, + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" @@ -232,7 +264,10 @@ def test_preview_trainer_export_with_project_id(client): def test_preview_trainer_export_with_document_id(client): with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post("/preview_trainer_export?document_id=3205", files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}) + response = client.post( + "/preview_trainer_export?document_id=3205", + files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}, + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" @@ -241,7 +276,10 @@ def test_preview_trainer_export_with_document_id(client): def test_preview_trainer_export_with_project_and_document_ids(client): with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post("/preview_trainer_export?project_id=14&document_id=3205", files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}) + response = client.post( + "/preview_trainer_export?project_id=14&document_id=3205", + files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}, + ) assert response.status_code == 200 assert response.headers["Content-Type"] == 
"application/octet-stream" @@ -251,7 +289,10 @@ def test_preview_trainer_export_with_project_and_document_ids(client): @pytest.mark.parametrize("pid,did", [(14, 1), (1, 3205)]) def test_preview_trainer_export_on_missing_project_or_document(pid, did, client): with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post(f"/preview_trainer_export?project_id={pid}&document_id={did}", files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}) + response = client.post( + f"/preview_trainer_export?project_id={pid}&document_id={did}", + files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}, + ) assert response.status_code == 404 assert response.json() == {"message": "Cannot find any matching documents to preview"} @@ -268,7 +309,9 @@ def test_train_supervised(model_service, client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post(f"/train_supervised?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + response = client.post( + f"/train_supervised?tracking_id={TRACKING_ID}", files=[("trainer_export", f)] + ) model_service.train_supervised.assert_called() assert response.status_code == 202 @@ -279,7 +322,7 @@ def test_train_supervised(model_service, client): def test_train_unsupervised(model_service, client): with tempfile.TemporaryFile("r+b") as f: - f.write(str.encode("[\"Spinal stenosis\"]")) + f.write(str.encode('["Spinal stenosis"]')) response = client.post("/train_unsupervised", files=[("training_data", f)]) model_service.train_unsupervised.assert_called() @@ -288,8 +331,10 @@ def test_train_unsupervised(model_service, client): # test with provided tracking ID with tempfile.TemporaryFile("r+b") as f: - f.write(str.encode("[\"Spinal stenosis\"]")) - response = client.post(f"/train_unsupervised?tracking_id={TRACKING_ID}", files=[("training_data", f)]) + f.write(str.encode('["Spinal stenosis"]')) + response = client.post( + f"/train_unsupervised?tracking_id={TRACKING_ID}", files=[("training_data", f)] + ) model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." @@ -298,12 +343,14 @@ def test_train_unsupervised(model_service, client): def test_train_unsupervised_with_hf_hub_dataset(model_service, client): - model_card = ModelCard.parse_obj({ - "api_version": "0.0.1", - "model_description": "huggingface_ner_model_description", - "model_type": ModelType.MEDCAT_SNOMED, - "model_card": None, - }) + model_card = ModelCard.parse_obj( + { + "api_version": "0.0.1", + "model_description": "huggingface_ner_model_description", + "model_type": ModelType.MEDCAT_SNOMED, + "model_card": None, + } + ) model_service.info.return_value = model_card response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") @@ -313,7 +360,9 @@ def test_train_unsupervised_with_hf_hub_dataset(model_service, client): assert "training_id" in response.json() # test with provided tracking ID - response = client.post(f"/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb&tracking_id={TRACKING_ID}") + response = client.post( + f"/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb&tracking_id={TRACKING_ID}" + ) model_service.train_unsupervised.assert_called() assert response.json()["message"] == "Your training started successfully." 
@@ -332,7 +381,9 @@ def test_train_metacat(model_service, client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post(f"/train_metacat?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + response = client.post( + f"/train_metacat?tracking_id={TRACKING_ID}", files=[("trainer_export", f)] + ) model_service.train_metacat.assert_called() assert response.status_code == 202 @@ -351,7 +402,9 @@ def test_evaluate_with_trainer_export(client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post(f"/evaluate?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + response = client.post( + f"/evaluate?tracking_id={TRACKING_ID}", files=[("trainer_export", f)] + ) assert response.status_code == 202 assert response.json()["message"] == "Your evaluation started successfully." @@ -369,7 +422,9 @@ def test_sanity_check_with_trainer_export(client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f: - response = client.post(f"/sanity-check?tracking_id={TRACKING_ID}", files=[("trainer_export", f)]) + response = client.post( + f"/sanity-check?tracking_id={TRACKING_ID}", files=[("trainer_export", f)] + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" @@ -380,33 +435,51 @@ def test_sanity_check_with_trainer_export(client): def test_inter_annotator_agreement_scores_per_concept(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_concept", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_concept", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + assert ( + response.text.split("\n")[0] + == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + ) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post(f"/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_concept&tracking_id={TRACKING_ID}", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + f"/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_concept&tracking_id={TRACKING_ID}", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + assert ( + response.text.split("\n")[0] + == "concept,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + ) assert TRACKING_ID in response.headers["Content-Disposition"] -@pytest.mark.parametrize("pid_a,pid_b,error_message", [(0, 2, "Cannot find the project with ID: 0"), (1, 3, "Cannot find the project with ID: 3")]) +@pytest.mark.parametrize( + "pid_a,pid_b,error_message", + [(0, 2, "Cannot find the project with ID: 0"), (1, 3, "Cannot find the project with ID: 3")], +) def 
test_project_not_found_on_getting_iaa_scores(pid_a, pid_b, error_message, client): with open(TRAINER_EXPORT_MULTI_PROJS_PATH, "rb") as f: - response = client.post(f"/iaa-scores?annotator_a_project_id={pid_a}&annotator_b_project_id={pid_b}&scope=per_concept", files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}) + response = client.post( + f"/iaa-scores?annotator_a_project_id={pid_a}&annotator_b_project_id={pid_b}&scope=per_concept", + files={"trainer_export": ("trainer_export.json", f, "multipart/form-data")}, + ) assert response.status_code == 400 assert response.headers["content-type"] == "application/json" @@ -416,49 +489,67 @@ def test_project_not_found_on_getting_iaa_scores(pid_a, pid_b, error_message, cl def test_unknown_scope_on_getting_iaa_scores(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=unknown", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=unknown", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 400 assert response.headers["content-type"] == "application/json" - assert response.json() == {"message": "Unknown scope: \"unknown\""} + assert response.json() == {"message": 'Unknown scope: "unknown"'} def test_inter_annotator_agreement_scores_per_doc(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_document", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_document", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "doc_id,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + assert ( + response.text.split("\n")[0] + == "doc_id,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + ) def test_inter_annotator_agreement_scores_per_span(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_span", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/iaa-scores?annotator_a_project_id=14&annotator_b_project_id=15&scope=per_span", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "doc_id,span_start,span_end,iaa_percentage,cohens_kappa,iaa_percentage_meta,cohens_kappa_meta" + assert response.text.split("\n")[0] == ( + "doc_id,span_start,span_end,iaa_percentage,cohens_kappa," + "iaa_percentage_meta,cohens_kappa_meta" + ) def test_concat_trainer_exports(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/concat_trainer_exports", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/concat_trainer_exports", + 
files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json; charset=utf-8" @@ -467,10 +558,13 @@ def test_concat_trainer_exports(client): # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post(f"/concat_trainer_exports?tracking_id={TRACKING_ID}", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + f"/concat_trainer_exports?tracking_id={TRACKING_ID}", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json; charset=utf-8" @@ -481,45 +575,59 @@ def test_concat_trainer_exports(client): def test_get_annotation_stats(client): with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post("/annotation-stats", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + "/annotation-stats", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + assert ( + response.text.split("\n")[0] + == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + ) # test with provided tracking ID with open(TRAINER_EXPORT_PATH, "rb") as f1: with open(ANOTHER_TRAINER_EXPORT_PATH, "rb") as f2: - response = client.post(f"/annotation-stats?tracking_id={TRACKING_ID}", files=[ - ("trainer_export", f1), - ("trainer_export", f2), - ]) + response = client.post( + f"/annotation-stats?tracking_id={TRACKING_ID}", + files=[ + ("trainer_export", f1), + ("trainer_export", f2), + ], + ) assert response.status_code == 200 assert response.headers["Content-Type"] == "text/csv; charset=utf-8" - assert response.text.split("\n")[0] == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + assert ( + response.text.split("\n")[0] + == "concept,anno_count,anno_unique_counts,anno_ignorance_counts" + ) assert TRACKING_ID in response.headers["Content-Disposition"] def test_extract_entities_from_text_list_file_as_json_file(model_service, client): annotations_list = [ - [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] + [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": { + "value": "Affirmed", + "confidence": 0.9999833106994629, + "name": "Status", + } + }, + } + ] ] * 15 model_service.batch_annotate.return_value = annotations_list @@ -527,44 +635,102 @@ def test_extract_entities_from_text_list_file_as_json_file(model_service, client response = client.post("/process_bulk_file", files=[("multi_text_file", f)]) assert isinstance(response, httpx.Response) - assert json.loads(response.content) == [{ - "text": "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not 
rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", - "annotations": [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] - }] * 15 + assert ( + json.loads(response.content) + == [ + { + "text": ( + "Description: Intracerebral hemorrhage (very acute clinical changes occurred" + " immediately).\nCC: Left hand numbness on presentation; then developed" + " lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM" + " suddenly developed generalized weakness and lightheadedness, and could not" + " rise from a chair. Four hours later he experienced sudden left hand numbness" + " lasting two hours. There were no other associated symptoms except for the" + " generalized weakness and lightheadedness. He denied vertigo.\nHe had been" + " experiencing falling spells without associated LOC up to several times a" + " month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin," + " Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram" + " showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis" + " with annular calcification and regurgitation, moderate TR, Decreased LV" + " systolic function, severe LAE. MRI brain: focal areas of increased T2 signal" + " in the left cerebellum and in the brainstem probably representing" + " microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis" + " of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection" + " Fx 39%. He was subsequently placed on coumadin severe valvular heart" + " disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse," + " 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia." 
+ ), + "annotations": [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": { + "value": "Affirmed", + "confidence": 0.9999833106994629, + "name": "Status", + } + }, + } + ], + } + ] + * 15 + ) # test with provided tracking ID with open(MULTI_TEXTS_FILE_PATH, "rb") as f: - response = client.post(f"/process_bulk_file?tracking_id={TRACKING_ID}", files=[("multi_text_file", f)]) + response = client.post( + f"/process_bulk_file?tracking_id={TRACKING_ID}", files=[("multi_text_file", f)] + ) assert isinstance(response, httpx.Response) - assert json.loads(response.content) == [{ - "text": "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", - "annotations": [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] - }] * 15 + assert ( + json.loads(response.content) + == [ + { + "text": ( + "Description: Intracerebral hemorrhage (very acute clinical changes occurred" + " immediately).\nCC: Left hand numbness on presentation; then developed" + " lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM" + " suddenly developed generalized weakness and lightheadedness, and could not" + " rise from a chair. Four hours later he experienced sudden left hand numbness" + " lasting two hours. There were no other associated symptoms except for the" + " generalized weakness and lightheadedness. He denied vertigo.\nHe had been" + " experiencing falling spells without associated LOC up to several times a" + " month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin," + " Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram" + " showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis" + " with annular calcification and regurgitation, moderate TR, Decreased LV" + " systolic function, severe LAE. MRI brain: focal areas of increased T2 signal" + " in the left cerebellum and in the brainstem probably representing" + " microvascular ischemic disease. 
IVG (MUGA scan)revealed: global hypokinesis" + " of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection" + " Fx 39%. He was subsequently placed on coumadin severe valvular heart" + " disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse," + " 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia." + ), + "annotations": [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": { + "value": "Affirmed", + "confidence": 0.9999833106994629, + "name": "Status", + } + }, + } + ], + } + ] + * 15 + ) assert TRACKING_ID in response.headers["Content-Disposition"] diff --git a/tests/app/api/test_serving_hf_ner.py b/tests/app/api/test_serving_hf_ner.py index c6ac825..275fded 100644 --- a/tests/app/api/test_serving_hf_ner.py +++ b/tests/app/api/test_serving_hf_ner.py @@ -1,12 +1,15 @@ import os +from unittest.mock import create_autospec + import pytest -import api.globals as cms_globals from fastapi.testclient import TestClient -from api.api import get_model_server + +from domain import ModelCard, ModelType from utils import get_settings + +import api.globals as cms_globals +from api.api import get_model_server from model_services.huggingface_ner_model import HuggingFaceNerModel -from domain import ModelCard, ModelType -from unittest.mock import create_autospec config = get_settings() config.ENABLE_TRAINING_APIS = "true" @@ -15,11 +18,19 @@ config.ENABLE_PREVIEWS_APIS = "true" config.AUTH_USER_ENABLED = "true" -TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json") +TRAINER_EXPORT_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json" +) NOTE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "note.txt") -ANOTHER_TRAINER_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "another_trainer_export.json") -TRAINER_EXPORT_MULTI_PROJS_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json") -MULTI_TEXTS_FILE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json") +ANOTHER_TRAINER_EXPORT_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "another_trainer_export.json" +) +TRAINER_EXPORT_MULTI_PROJS_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json" +) +MULTI_TEXTS_FILE_PATH = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json" +) @pytest.fixture(scope="function") @@ -37,12 +48,14 @@ def client(model_service): def test_train_unsupervised_with_hf_hub_dataset(model_service, client): - model_card = ModelCard.parse_obj({ - "api_version": "0.0.1", - "model_description": "huggingface_ner_model_description", - "model_type": ModelType.HUGGINGFACE_NER, - "model_card": None, - }) + model_card = ModelCard.parse_obj( + { + "api_version": "0.0.1", + "model_description": "huggingface_ner_model_description", + "model_type": ModelType.HUGGINGFACE_NER, + "model_card": None, + } + ) model_service.info.return_value = model_card response = client.post("/train_unsupervised_with_hf_hub_dataset?hf_dataset_repo_id=imdb") diff --git a/tests/app/api/test_serving_stream.py b/tests/app/api/test_serving_stream.py index 594336e..a389fc0 100644 --- 
a/tests/app/api/test_serving_stream.py +++ b/tests/app/api/test_serving_stream.py @@ -1,17 +1,17 @@ -import httpx import json -import pytest - -import api.globals as cms_globals +from unittest.mock import create_autospec +import httpx +import pytest from fastapi.testclient import TestClient from starlette.websockets import WebSocketDisconnect -from api.api import get_stream_server + from utils import get_settings -from model_services.medcat_model import MedCATModel -from management.model_manager import ModelManager -from unittest.mock import create_autospec +import api.globals as cms_globals +from api.api import get_stream_server +from management.model_manager import ModelManager +from model_services.medcat_model import MedCATModel config = get_settings() config.ENABLE_TRAINING_APIS = "true" @@ -37,7 +37,9 @@ def app(model_service): @pytest.mark.asyncio async def test_stream_process_empty_stream(model_service, app): async with httpx.AsyncClient(app=app, base_url="http://test") as ac: - response = await ac.post("/stream/process", data="", headers={"Content-Type": "application/x-ndjson"}) + response = await ac.post( + "/stream/process", data="", headers={"Content-Type": "application/x-ndjson"} + ) assert response.status_code == 200 jsonlines = b"" @@ -49,9 +51,11 @@ async def test_stream_process_empty_stream(model_service, app): @pytest.mark.asyncio async def test_stream_process_invalidate_jsonl(model_service, app): async with httpx.AsyncClient(app=app, base_url="http://test") as ac: - response = await ac.post("/stream/process", - data='{"name": "doc1", "text": Spinal stenosis}\n'.encode("utf-8"), - headers={"Content-Type": "application/x-ndjson"}) + response = await ac.post( + "/stream/process", + data='{"name": "doc1", "text": Spinal stenosis}\n'.encode("utf-8"), + headers={"Content-Type": "application/x-ndjson"}, + ) assert response.status_code == 200 jsonlines = b"" @@ -63,15 +67,23 @@ async def test_stream_process_invalidate_jsonl(model_service, app): @pytest.mark.asyncio async def test_stream_process_unknown_jsonl_property(model_service, app): async with httpx.AsyncClient(app=app, base_url="http://test") as ac: - response = await ac.post("/stream/process", - data='{"unknown": "doc1", "text": "Spinal stenosis"}\n{"unknown": "doc2", "text": "Spinal stenosis"}', - headers={"Content-Type": "application/x-ndjson"}) + response = await ac.post( + "/stream/process", + data=( + '{"unknown": "doc1", "text": "Spinal stenosis"}\n' + '{"unknown": "doc2", "text": "Spinal stenosis"}' + ), + headers={"Content-Type": "application/x-ndjson"}, + ) assert response.status_code == 200 jsonlines = b"" async for chunk in response.aiter_bytes(): jsonlines += chunk - assert "Invalid JSON properties found" in json.loads(jsonlines.decode("utf-8").splitlines()[-1])["error"] + assert ( + "Invalid JSON properties found" + in json.loads(jsonlines.decode("utf-8").splitlines()[-1])["error"] + ) def test_websocket_process_on_annotation_error(model_service, app): diff --git a/tests/app/api/test_serving_trf.py b/tests/app/api/test_serving_trf.py index c802130..75a4397 100644 --- a/tests/app/api/test_serving_trf.py +++ b/tests/app/api/test_serving_trf.py @@ -1,12 +1,14 @@ +from unittest.mock import create_autospec + import pytest -import api.globals as cms_globals from fastapi.testclient import TestClient -from api.api import get_model_server -from utils import get_settings -from model_services.trf_model_deid import TransformersModelDeIdentification -from unittest.mock import create_autospec + from domain import ModelCard, 
ModelType +from utils import get_settings +import api.globals as cms_globals +from api.api import get_model_server +from model_services.trf_model_deid import TransformersModelDeIdentification config = get_settings() config.AUTH_USER_ENABLED = "true" @@ -31,24 +33,28 @@ def test_healthz(client): def test_readyz(model_service, client): - model_card = ModelCard.parse_obj({ - "api_version": "0.0.1", - "model_description": "deid_model_description", - "model_type": ModelType.TRANSFORMERS_DEID, - "model_card": None, - }) + model_card = ModelCard.parse_obj( + { + "api_version": "0.0.1", + "model_description": "deid_model_description", + "model_type": ModelType.TRANSFORMERS_DEID, + "model_card": None, + } + ) model_service.info.return_value = model_card assert client.get("/readyz").content.decode("utf-8") == ModelType.TRANSFORMERS_DEID def test_info(model_service, client): - model_card = ModelCard.parse_obj({ - "api_version": "0.0.1", - "model_description": "deid_model_description", - "model_type": ModelType.TRANSFORMERS_DEID, - "model_card": None, - }) + model_card = ModelCard.parse_obj( + { + "api_version": "0.0.1", + "model_description": "deid_model_description", + "model_type": ModelType.TRANSFORMERS_DEID, + "model_card": None, + } + ) model_service.info.return_value = model_card response = client.get("/info") @@ -57,38 +63,39 @@ def test_info(model_service, client): def test_process(model_service, client): - annotations = [{ - "label_name": "NW1 2BU", - "label_id": "C2120", - "start": 0, - "end": 6, - }] + annotations = [ + { + "label_name": "NW1 2BU", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ] model_service.annotate.return_value = annotations - response = client.post("/process", - data="NW1 2BU", - headers={"Content-Type": "text/plain"}) + response = client.post("/process", data="NW1 2BU", headers={"Content-Type": "text/plain"}) - assert response.json() == { - "text": "NW1 2BU", - "annotations": annotations - } + assert response.json() == {"text": "NW1 2BU", "annotations": annotations} def test_process_bulk(model_service, client): annotations_list = [ - [{ - "label_name": "NW1 2BU", - "label_id": "C2120", - "start": 0, - "end": 6, - }], - [{ - "label_name": "NW1 2DA", - "label_id": "C2120", - "start": 0, - "end": 6, - }] + [ + { + "label_name": "NW1 2BU", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ], + [ + { + "label_name": "NW1 2DA", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ], ] model_service.batch_annotate.return_value = annotations_list @@ -97,38 +104,42 @@ def test_process_bulk(model_service, client): assert response.json() == [ { "text": "NW1 2BU", - "annotations": [{ - "label_name": "NW1 2BU", - "label_id": "C2120", - "start": 0, - "end": 6, - }] + "annotations": [ + { + "label_name": "NW1 2BU", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ], }, { "text": "NW1 2DA", - "annotations": [{ - "label_name": "NW1 2DA", - "label_id": "C2120", - "start": 0, - "end": 6, - }] - } + "annotations": [ + { + "label_name": "NW1 2DA", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ], + }, ] def test_preview(model_service, client): - annotations = [{ - "label_name": "NW1 2BU", - "label_id": "C2120", - "start": 0, - "end": 6, - }] + annotations = [ + { + "label_name": "NW1 2BU", + "label_id": "C2120", + "start": 0, + "end": 6, + } + ] model_service.annotate.return_value = annotations model_service.model_name = "De-Identification Model" - response = client.post("/preview", - data="NW1 2BU", - headers={"Content-Type": "text/plain"}) + response = 
client.post("/preview", data="NW1 2BU", headers={"Content-Type": "text/plain"}) assert response.status_code == 200 assert response.headers["Content-Type"] == "application/octet-stream" diff --git a/tests/app/api/test_utils.py b/tests/app/api/test_utils.py index 96be7c5..622d625 100644 --- a/tests/app/api/test_utils.py +++ b/tests/app/api/test_utils.py @@ -1,11 +1,13 @@ from fastapi import FastAPI + from utils import get_settings + from api.utils import ( add_exception_handlers, add_rate_limiter, - get_rate_limiter, - encrypt, decrypt, + encrypt, + get_rate_limiter, ) @@ -88,6 +90,12 @@ def test_decrypt(): cExDsxcGU7ZcTO9WVwDhqF/9ofkXfLOFKxugLNEA5RA3gRcpCxMRLS4k6dfN9N9o 3RQZkF/usTTvyvFQR96frZb2FQ== -----END PRIVATE KEY-----""" - encrypted = "TLlMBh4GDf3BSsO/RKlqG5H7Sxv7OXGbl8qE/6YLQPm3coBbnrRRReX7pLamnjLPUU0PtIRIg2H/hWBWE/3cRtXDPT7jMtmGHMIPO/95A0DkrndIkOeQ29J6TBPBBG6YqBNRb2dyhDBwDIEDjPTiRe68sYz4KkxzSOkcz31314kSkZvdIDtQOgeRDa0/7U0VrJePL2N7SJvEiHf4Xa3vW3/20S3O8s/Yp0Azb/kS9dFa54VO1fNNhJ46OtPpdekiFDR5yvQfHwFVeSDdY+eAuYLTWa6bz/LrQkRAdRi9EW5Iz/q8WgKhZXQJfcXtiKfVuFar2N2KodY7C/45vMOfvw==" + encrypted = """ +TLlMBh4GDf3BSsO/RKlqG5H7Sxv7OXGbl8qE/6YLQPm3coBbnrRRReX7pLamnjLP +UU0PtIRIg2H/hWBWE/3cRtXDPT7jMtmGHMIPO/95A0DkrndIkOeQ29J6TBPBBG6Y +qBNRb2dyhDBwDIEDjPTiRe68sYz4KkxzSOkcz31314kSkZvdIDtQOgeRDa0/7U0V +rJePL2N7SJvEiHf4Xa3vW3/20S3O8s/Yp0Azb/kS9dFa54VO1fNNhJ46OtPpdeki +FDR5yvQfHwFVeSDdY+eAuYLTWa6bz/LrQkRAdRi9EW5Iz/q8WgKhZXQJfcXtiKfV +uFar2N2KodY7C/45vMOfvw==""" decrypted = decrypt(encrypted, fake_private_key_pem) assert decrypted == "test" diff --git a/tests/app/cli/test_cli.py b/tests/app/cli/test_cli.py index b910817..c261e36 100644 --- a/tests/app/cli/test_cli.py +++ b/tests/app/cli/test_cli.py @@ -1,9 +1,11 @@ import os -import pytest from unittest.mock import patch -from cli.cli import cmd_app + +import pytest from typer.testing import CliRunner +from cli.cli import cmd_app + MODEL_PARENT_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "model") runner = CliRunner() @@ -15,11 +17,24 @@ def test_serve_help(): assert "This serves various CogStack NLP models" in result.output -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_serve_model(): with patch("uvicorn.run", side_effect=KeyboardInterrupt): - result = runner.invoke(cmd_app, ["serve", "--model-type", "medcat_deid", "--model-name", "deid model", "--model-path", os.path.join(MODEL_PARENT_DIR, "deid_model.zip")]) + result = runner.invoke( + cmd_app, + [ + "serve", + "--model-type", + "medcat_deid", + "--model-name", + "deid model", + "--model-path", + os.path.join(MODEL_PARENT_DIR, "deid_model.zip"), + ], + ) assert result.exit_code == 1 assert "\nAborted.\n" in result.output @@ -30,10 +45,23 @@ def test_register_help(): assert "This pushes a pretrained NLP model to the CogStack ModelServe registry" in result.output -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_register_nodel(): - result = runner.invoke(cmd_app, ["register", "--model-type", 
"medcat_deid", "--model-name", "deid model", "--model-path", os.path.join(MODEL_PARENT_DIR, "deid_model.zip")]) + result = runner.invoke( + cmd_app, + [ + "register", + "--model-type", + "medcat_deid", + "--model-name", + "deid model", + "--model-path", + os.path.join(MODEL_PARENT_DIR, "deid_model.zip"), + ], + ) assert result.exit_code == 0 assert "as a new model version" in result.output diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 34a1b92..9705922 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -1,13 +1,16 @@ import os -import pytest from unittest.mock import Mock + +import pytest + from config import Settings -from model_services.medcat_model_snomed import MedCATModelSnomed + +from model_services.huggingface_ner_model import HuggingFaceNerModel +from model_services.medcat_model_deid import MedCATModelDeIdentification from model_services.medcat_model_icd10 import MedCATModelIcd10 +from model_services.medcat_model_snomed import MedCATModelSnomed from model_services.medcat_model_umls import MedCATModelUmls -from model_services.medcat_model_deid import MedCATModelDeIdentification from model_services.trf_model_deid import TransformersModelDeIdentification -from model_services.huggingface_ner_model import HuggingFaceNerModel MODEL_PARENT_DIR = os.path.join(os.path.dirname(__file__), "..", "resources", "model") diff --git a/tests/app/data/test_anno_dataset.py b/tests/app/data/test_anno_dataset.py index 47a4dee..4eaf6bf 100644 --- a/tests/app/data/test_anno_dataset.py +++ b/tests/app/data/test_anno_dataset.py @@ -1,11 +1,26 @@ import os + import datasets + from app.data import anno_dataset def test_load_dataset(): - trainer_export = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json") - dataset = datasets.load_dataset(anno_dataset.__file__, data_files={"annotations": trainer_export}, split="train", cache_dir="/tmp", trust_remote_code=True) + trainer_export = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "resources", + "fixture", + "trainer_export_multi_projs.json", + ) + dataset = datasets.load_dataset( + anno_dataset.__file__, + data_files={"annotations": trainer_export}, + split="train", + cache_dir="/tmp", + trust_remote_code=True, + ) assert dataset.features.to_dict() == { "project": {"dtype": "string", "_type": "Value"}, "name": {"dtype": "string", "_type": "Value"}, @@ -23,7 +38,13 @@ def test_load_dataset(): def test_generate_examples(): - example_gen = anno_dataset.generate_examples([os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json")]) + example_gen = anno_dataset.generate_examples( + [ + os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json" + ) + ] + ) example = next(example_gen) assert example[0] == "1" assert "project" in example[1] diff --git a/tests/app/data/test_doc_dataset.py b/tests/app/data/test_doc_dataset.py index e39e257..80542ba 100644 --- a/tests/app/data/test_doc_dataset.py +++ b/tests/app/data/test_doc_dataset.py @@ -1,18 +1,37 @@ import os + import datasets + from app.data import doc_dataset def test_load_dataset(): - sample_texts = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json") - dataset = datasets.load_dataset(doc_dataset.__file__, data_files={"documents": sample_texts}, split="train", cache_dir="/tmp", trust_remote_code=True) - assert dataset.features.to_dict() == {"name": {"dtype": "string", "_type": "Value"}, 
"text": {"dtype": "string", "_type": "Value"}} + sample_texts = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json" + ) + dataset = datasets.load_dataset( + doc_dataset.__file__, + data_files={"documents": sample_texts}, + split="train", + cache_dir="/tmp", + trust_remote_code=True, + ) + assert dataset.features.to_dict() == { + "name": {"dtype": "string", "_type": "Value"}, + "text": {"dtype": "string", "_type": "Value"}, + } assert len(dataset.to_list()) == 15 assert dataset.to_list()[0]["name"] == "1" def test_generate_examples(): - example_gen = doc_dataset.generate_examples([os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json")]) + example_gen = doc_dataset.generate_examples( + [ + os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json" + ) + ] + ) example = next(example_gen) assert example[0] == "1" assert "name" in example[1] diff --git a/tests/app/helper.py b/tests/app/helper.py index e419f07..20d4a2c 100644 --- a/tests/app/helper.py +++ b/tests/app/helper.py @@ -1,4 +1,5 @@ import time + import mlflow diff --git a/tests/app/model_services/test_huggingface_ner_model.py b/tests/app/model_services/test_huggingface_ner_model.py index 2340490..7f09288 100644 --- a/tests/app/model_services/test_huggingface_ner_model.py +++ b/tests/app/model_services/test_huggingface_ner_model.py @@ -1,10 +1,13 @@ import os import tempfile -import pytest from unittest.mock import Mock -from tests.app.conftest import MODEL_PARENT_DIR + +import pytest from transformers import PreTrainedModel, PreTrainedTokenizerBase + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.huggingface_ner_model import HuggingFaceNerModel @@ -17,30 +20,40 @@ def test_api_version(huggingface_ner_model): def test_from_model(huggingface_ner_model): - new_model_service = huggingface_ner_model.from_model(huggingface_ner_model.model, huggingface_ner_model.tokenizer) + new_model_service = huggingface_ner_model.from_model( + huggingface_ner_model.model, huggingface_ner_model.tokenizer + ) assert isinstance(new_model_service, HuggingFaceNerModel) assert new_model_service.model == huggingface_ner_model.model assert new_model_service.tokenizer == huggingface_ner_model.tokenizer -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(huggingface_ner_model): huggingface_ner_model.init_model() assert huggingface_ner_model.model is not None assert huggingface_ner_model.tokenizer is not None -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(huggingface_ner_model): - model, tokenizer = HuggingFaceNerModel.load_model(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")) + model, tokenizer = HuggingFaceNerModel.load_model( + os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip") + ) assert isinstance(model, 
PreTrainedModel) assert isinstance(tokenizer, PreTrainedTokenizerBase) -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(huggingface_ner_model): huggingface_ner_model.init_model() model_card = huggingface_ner_model.info() @@ -49,41 +62,55 @@ def test_info(huggingface_ner_model): assert model_card.model_type == ModelType.HUGGINGFACE_NER -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(huggingface_ner_model): huggingface_ner_model.init_model() annotations = huggingface_ner_model.annotate( - """The patient is a 60-year-old female, who complained of coughing during meals. """ - """ Her outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was completed approximately""" - """ 2 months ago. The patient had a history of hypertension and TIA/stroke. The patient denied history of heartburn""" - """ and/or gastroesophageal reflux disorder. A modified barium swallow study was ordered to objectively evaluate the""" - """ patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: , Modified barium swallow study""" - """ was performed in the Radiology Suite in cooperation with Dr. ABC. The patient was seated upright in a video imaging""" - """ chair throughout this assessment. To evaluate the patient's swallowing function and safety, she was administered""" - """ graduated amounts of liquid and food mixed with barium in the form of thin liquid (teaspoon x2, cup sip x2); nectar-thick""" - """ liquid (teaspoon x2, cup sip x2); puree consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT,""" - """ ORAL STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin liquid. Decreased""" - """ tongue base retraction, which contributed to vallecular pooling after the swallow.,PHARYNGEAL STAGE: , No aspiration""" - """ was observed during this evaluation. Penetration was noted with cup sips of thin liquid only. Trace residual on""" - """ the valleculae and on tongue base with nectar-thick puree and solid consistencies. The patient's hyolaryngeal""" - """ elevation and anterior movement are within functional limits. Epiglottic inversion is within functional limits.,""" - """ CERVICAL ESOPHAGEAL STAGE: ,The patient's upper esophageal sphincter opening is well coordinated with swallow and""" - """ readily accepted the bolus. Radiologist noted reduced peristaltic action of the constricted muscles in the esophagus,""" - """ which may be contributing to the patient's complaint of globus sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was""" - """ noted during this evaluation. Penetration with cup sips of thin liquid. 
The patient did cough during this evaluation,""" - """ but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on this evaluation, the prognosis""" - """ for swallowing and safety is good.,PLAN: , Based on this evaluation and following recommendations are being made:,1. """ - """ The patient to take small bite and small sips to help decrease the risk of aspiration and penetration.,2. The patient""" - """ should remain upright at a 90-degree angle for at least 45 minutes after meals to decrease the risk of aspiration and""" - """ penetration as well as to reduce her globus sensation.,3. The patient should be referred to a gastroenterologist for""" - """ further evaluation of her esophageal function.,The patient does not need any skilled speech therapy for her swallowing""" - """ abilities at this time, and she is discharged from my services.). Dr. ABC""") + "The patient is a 60-year-old female, who complained of coughing during meals. Her" + " outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was" + " completed approximately 2 months ago. The patient had a history of hypertension and" + " TIA/stroke. The patient denied history of heartburn and/or gastroesophageal reflux" + " disorder. A modified barium swallow study was ordered to objectively evaluate the" + " patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: ," + " Modified barium swallow study was performed in the Radiology Suite in cooperation with" + " Dr. ABC. The patient was seated upright in a video imaging chair throughout this" + " assessment. To evaluate the patient's swallowing function and safety, she was" + " administered graduated amounts of liquid and food mixed with barium in the form of thin" + " liquid (teaspoon x2, cup sip x2); nectar-thick liquid (teaspoon x2, cup sip x2); puree" + " consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT, ORAL" + " STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin" + " liquid. Decreased tongue base retraction, which contributed to vallecular pooling after" + " the swallow.,PHARYNGEAL STAGE: , No aspiration was observed during this evaluation. " + " Penetration was noted with cup sips of thin liquid only. Trace residual on the" + " valleculae and on tongue base with nectar-thick puree and solid consistencies. The" + " patient's hyolaryngeal elevation and anterior movement are within functional limits. " + " Epiglottic inversion is within functional limits., CERVICAL ESOPHAGEAL STAGE: ,The" + " patient's upper esophageal sphincter opening is well coordinated with swallow and readily" + " accepted the bolus. Radiologist noted reduced peristaltic action of the constricted" + " muscles in the esophagus, which may be contributing to the patient's complaint of globus" + " sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was noted during this evaluation. " + " Penetration with cup sips of thin liquid. The patient did cough during this evaluation," + " but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on" + " this evaluation, the prognosis for swallowing and safety is good.,PLAN: , Based on this" + " evaluation and following recommendations are being made:,1. The patient to take small" + " bite and small sips to help decrease the risk of aspiration and penetration.,2. 
The" + " patient should remain upright at a 90-degree angle for at least 45 minutes after meals to" + " decrease the risk of aspiration and penetration as well as to reduce her globus" + " sensation.,3. The patient should be referred to a gastroenterologist for further" + " evaluation of her esophageal function.,The patient does not need any skilled speech" + " therapy for her swallowing abilities at this time, and she is discharged from my" + " services.). Dr. ABC" + ) assert isinstance(annotations, list) -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_unsupervised(huggingface_ner_model): huggingface_ner_model.init_model() huggingface_ner_model._config.REDEPLOY_TRAINED_MODEL = "false" @@ -94,8 +121,10 @@ def test_train_unsupervised(huggingface_ner_model): huggingface_ner_model._unsupervised_trainer.train.assert_called() -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "huggingface_ner_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_supervised(huggingface_ner_model): huggingface_ner_model.init_model() huggingface_ner_model._config.REDEPLOY_TRAINED_MODEL = "false" diff --git a/tests/app/model_services/test_medcat_model_deid.py b/tests/app/model_services/test_medcat_model_deid.py index 00735b8..acfc22b 100644 --- a/tests/app/model_services/test_medcat_model_deid.py +++ b/tests/app/model_services/test_medcat_model_deid.py @@ -1,10 +1,13 @@ import os import tempfile -import pytest from unittest.mock import Mock -from tests.app.conftest import MODEL_PARENT_DIR + +import pytest from medcat.cat import CAT + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.medcat_model_deid import MedCATModelDeIdentification @@ -22,22 +25,28 @@ def test_from_model(medcat_deid_model): assert new_model_service.model == medcat_deid_model.model -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(medcat_deid_model): medcat_deid_model.init_model() assert medcat_deid_model.model is not None -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(medcat_deid_model): cat = MedCATModelDeIdentification.load_model(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")) assert type(cat) is CAT -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not 
os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(medcat_deid_model): medcat_deid_model.init_model() model_card = medcat_deid_model.info() @@ -46,36 +55,48 @@ def test_info(medcat_deid_model): assert model_card.model_type == ModelType.ANONCAT -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(medcat_deid_model): medcat_deid_model.init_model() annotations = medcat_deid_model.annotate( - """The patient is a 60-year-old female, who complained of coughing during meals. """ - """ Her outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was completed approximately""" - """ 2 months ago. The patient had a history of hypertension and TIA/stroke. The patient denied history of heartburn""" - """ and/or gastroesophageal reflux disorder. A modified barium swallow study was ordered to objectively evaluate the""" - """ patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: , Modified barium swallow study""" - """ was performed in the Radiology Suite in cooperation with Dr. ABC. The patient was seated upright in a video imaging""" - """ chair throughout this assessment. To evaluate the patient's swallowing function and safety, she was administered""" - """ graduated amounts of liquid and food mixed with barium in the form of thin liquid (teaspoon x2, cup sip x2); nectar-thick""" - """ liquid (teaspoon x2, cup sip x2); puree consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT,""" - """ ORAL STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin liquid. Decreased""" - """ tongue base retraction, which contributed to vallecular pooling after the swallow.,PHARYNGEAL STAGE: , No aspiration""" - """ was observed during this evaluation. Penetration was noted with cup sips of thin liquid only. Trace residual on""" - """ the valleculae and on tongue base with nectar-thick puree and solid consistencies. The patient's hyolaryngeal""" - """ elevation and anterior movement are within functional limits. Epiglottic inversion is within functional limits.,""" - """ CERVICAL ESOPHAGEAL STAGE: ,The patient's upper esophageal sphincter opening is well coordinated with swallow and""" - """ readily accepted the bolus. Radiologist noted reduced peristaltic action of the constricted muscles in the esophagus,""" - """ which may be contributing to the patient's complaint of globus sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was""" - """ noted during this evaluation. Penetration with cup sips of thin liquid. The patient did cough during this evaluation,""" - """ but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on this evaluation, the prognosis""" - """ for swallowing and safety is good.,PLAN: , Based on this evaluation and following recommendations are being made:,1. """ - """ The patient to take small bite and small sips to help decrease the risk of aspiration and penetration.,2. 
The patient""" - """ should remain upright at a 90-degree angle for at least 45 minutes after meals to decrease the risk of aspiration and""" - """ penetration as well as to reduce her globus sensation.,3. The patient should be referred to a gastroenterologist for""" - """ further evaluation of her esophageal function.,The patient does not need any skilled speech therapy for her swallowing""" - """ abilities at this time, and she is discharged from my services.). Dr. ABC""") + "The patient is a 60-year-old female, who complained of coughing during meals. Her" + " outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was" + " completed approximately 2 months ago. The patient had a history of hypertension and" + " TIA/stroke. The patient denied history of heartburn and/or gastroesophageal reflux" + " disorder. A modified barium swallow study was ordered to objectively evaluate the" + " patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: ," + " Modified barium swallow study was performed in the Radiology Suite in cooperation with" + " Dr. ABC. The patient was seated upright in a video imaging chair throughout this" + " assessment. To evaluate the patient's swallowing function and safety, she was" + " administered graduated amounts of liquid and food mixed with barium in the form of thin" + " liquid (teaspoon x2, cup sip x2); nectar-thick liquid (teaspoon x2, cup sip x2); puree" + " consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT, ORAL" + " STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin" + " liquid. Decreased tongue base retraction, which contributed to vallecular pooling after" + " the swallow.,PHARYNGEAL STAGE: , No aspiration was observed during this evaluation. " + " Penetration was noted with cup sips of thin liquid only. Trace residual on the" + " valleculae and on tongue base with nectar-thick puree and solid consistencies. The" + " patient's hyolaryngeal elevation and anterior movement are within functional limits. " + " Epiglottic inversion is within functional limits., CERVICAL ESOPHAGEAL STAGE: ,The" + " patient's upper esophageal sphincter opening is well coordinated with swallow and readily" + " accepted the bolus. Radiologist noted reduced peristaltic action of the constricted" + " muscles in the esophagus, which may be contributing to the patient's complaint of globus" + " sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was noted during this evaluation. " + " Penetration with cup sips of thin liquid. The patient did cough during this evaluation," + " but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on" + " this evaluation, the prognosis for swallowing and safety is good.,PLAN: , Based on this" + " evaluation and following recommendations are being made:,1. The patient to take small" + " bite and small sips to help decrease the risk of aspiration and penetration.,2. The" + " patient should remain upright at a 90-degree angle for at least 45 minutes after meals to" + " decrease the risk of aspiration and penetration as well as to reduce her globus" + " sensation.,3. The patient should be referred to a gastroenterologist for further" + " evaluation of her esophageal function.,The patient does not need any skilled speech" + " therapy for her swallowing abilities at this time, and she is discharged from my" + " services.). Dr. 
ABC" + ) assert len(annotations) == 2 assert type(annotations[0]["label_name"]) is str assert type(annotations[1]["label_name"]) is str @@ -91,36 +112,48 @@ def test_annotate(medcat_deid_model): assert annotations[1]["categories"] == ["PII"] -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate_with_local_chunking(medcat_deid_model): medcat_deid_model.init_model() annotations = medcat_deid_model.annotate_with_local_chunking( - """The patient is a 60-year-old female, who complained of coughing during meals. """ - """ Her outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was completed approximately""" - """ 2 months ago. The patient had a history of hypertension and TIA/stroke. The patient denied history of heartburn""" - """ and/or gastroesophageal reflux disorder. A modified barium swallow study was ordered to objectively evaluate the""" - """ patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: , Modified barium swallow study""" - """ was performed in the Radiology Suite in cooperation with Dr. ABC. The patient was seated upright in a video imaging""" - """ chair throughout this assessment. To evaluate the patient's swallowing function and safety, she was administered""" - """ graduated amounts of liquid and food mixed with barium in the form of thin liquid (teaspoon x2, cup sip x2); nectar-thick""" - """ liquid (teaspoon x2, cup sip x2); puree consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT,""" - """ ORAL STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin liquid. Decreased""" - """ tongue base retraction, which contributed to vallecular pooling after the swallow.,PHARYNGEAL STAGE: , No aspiration""" - """ was observed during this evaluation. Penetration was noted with cup sips of thin liquid only. Trace residual on""" - """ the valleculae and on tongue base with nectar-thick puree and solid consistencies. The patient's hyolaryngeal""" - """ elevation and anterior movement are within functional limits. Epiglottic inversion is within functional limits.,""" - """ CERVICAL ESOPHAGEAL STAGE: ,The patient's upper esophageal sphincter opening is well coordinated with swallow and""" - """ readily accepted the bolus. Radiologist noted reduced peristaltic action of the constricted muscles in the esophagus,""" - """ which may be contributing to the patient's complaint of globus sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was""" - """ noted during this evaluation. Penetration with cup sips of thin liquid. The patient did cough during this evaluation,""" - """ but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on this evaluation, the prognosis""" - """ for swallowing and safety is good.,PLAN: , Based on this evaluation and following recommendations are being made:,1. """ - """ The patient to take small bite and small sips to help decrease the risk of aspiration and penetration.,2. The patient""" - """ should remain upright at a 90-degree angle for at least 45 minutes after meals to decrease the risk of aspiration and""" - """ penetration as well as to reduce her globus sensation.,3. 
The patient should be referred to a gastroenterologist for""" - """ further evaluation of her esophageal function.,The patient does not need any skilled speech therapy for her swallowing""" - """ abilities at this time, and she is discharged from my services.). Dr. ABC""") + "The patient is a 60-year-old female, who complained of coughing during meals. Her" + " outpatient evaluation revealed a mild-to-moderate cognitive linguistic deficit, which was" + " completed approximately 2 months ago. The patient had a history of hypertension and" + " TIA/stroke. The patient denied history of heartburn and/or gastroesophageal reflux" + " disorder. A modified barium swallow study was ordered to objectively evaluate the" + " patient's swallowing function and safety and to rule out aspiration.,OBJECTIVE: ," + " Modified barium swallow study was performed in the Radiology Suite in cooperation with" + " Dr. ABC. The patient was seated upright in a video imaging chair throughout this" + " assessment. To evaluate the patient's swallowing function and safety, she was" + " administered graduated amounts of liquid and food mixed with barium in the form of thin" + " liquid (teaspoon x2, cup sip x2); nectar-thick liquid (teaspoon x2, cup sip x2); puree" + " consistency (teaspoon x2); and solid food consistency (1/4 cracker x1).,ASSESSMENT, ORAL" + " STAGE:, Premature spillage to the level of the valleculae and pyriform sinuses with thin" + " liquid. Decreased tongue base retraction, which contributed to vallecular pooling after" + " the swallow.,PHARYNGEAL STAGE: , No aspiration was observed during this evaluation. " + " Penetration was noted with cup sips of thin liquid only. Trace residual on the" + " valleculae and on tongue base with nectar-thick puree and solid consistencies. The" + " patient's hyolaryngeal elevation and anterior movement are within functional limits. " + " Epiglottic inversion is within functional limits., CERVICAL ESOPHAGEAL STAGE: ,The" + " patient's upper esophageal sphincter opening is well coordinated with swallow and readily" + " accepted the bolus. Radiologist noted reduced peristaltic action of the constricted" + " muscles in the esophagus, which may be contributing to the patient's complaint of globus" + " sensation.,DIAGNOSTIC IMPRESSION:, No aspiration was noted during this evaluation. " + " Penetration with cup sips of thin liquid. The patient did cough during this evaluation," + " but that was noted related to aspiration or penetration.,PROGNOSTIC IMPRESSION: ,Based on" + " this evaluation, the prognosis for swallowing and safety is good.,PLAN: , Based on this" + " evaluation and following recommendations are being made:,1. The patient to take small" + " bite and small sips to help decrease the risk of aspiration and penetration.,2. The" + " patient should remain upright at a 90-degree angle for at least 45 minutes after meals to" + " decrease the risk of aspiration and penetration as well as to reduce her globus" + " sensation.,3. The patient should be referred to a gastroenterologist for further" + " evaluation of her esophageal function.,The patient does not need any skilled speech" + " therapy for her swallowing abilities at this time, and she is discharged from my" + " services.). Dr. 
ABC" + ) assert len(annotations) == 2 assert type(annotations[0]["label_name"]) is str assert type(annotations[1]["label_name"]) is str @@ -136,11 +169,15 @@ def test_annotate_with_local_chunking(medcat_deid_model): assert annotations[1]["categories"] == ["PII"] -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_batch_annotate(medcat_deid_model): medcat_deid_model.init_model() - annotation_list = medcat_deid_model.batch_annotate(["This is a post code NW1 2DA", "This is a post code NW1 2DA"]) + annotation_list = medcat_deid_model.batch_annotate( + ["This is a post code NW1 2DA", "This is a post code NW1 2DA"] + ) assert len(annotation_list) == 2 assert type(annotation_list[0][0]["label_name"]) is str assert type(annotation_list[1][0]["label_name"]) is str @@ -150,8 +187,10 @@ def test_batch_annotate(medcat_deid_model): assert annotation_list[1][0]["accuracy"] > 0 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_supervised(medcat_deid_model): medcat_deid_model.init_model() medcat_deid_model._config.REDEPLOY_TRAINED_MODEL = "false" diff --git a/tests/app/model_services/test_medcat_model_icd10.py b/tests/app/model_services/test_medcat_model_icd10.py index 454a42b..c02ea4e 100644 --- a/tests/app/model_services/test_medcat_model_icd10.py +++ b/tests/app/model_services/test_medcat_model_icd10.py @@ -1,10 +1,13 @@ import os import tempfile -import pytest from unittest.mock import Mock -from tests.app.conftest import MODEL_PARENT_DIR + +import pytest from medcat.cat import CAT + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.medcat_model_icd10 import MedCATModelIcd10 @@ -23,20 +26,24 @@ def test_from_model(medcat_icd10_model): def test_get_records_from_doc(medcat_icd10_model): - records = medcat_icd10_model.get_records_from_doc({ - "entities": - { + records = medcat_icd10_model.get_records_from_doc( + { + "entities": { "0": { "pretty_name": "pretty_name", "cui": "cui", "types": ["type"], "icd10": [{"code": "code", "name": "name"}], - "athena_ids": [{"name": "name_1", "code": "code_1"}, {"name": "name_2", "code": "code_2"}], + "athena_ids": [ + {"name": "name_1", "code": "code_1"}, + {"name": "name_2", "code": "code_2"}, + ], "acc": 1.0, - "meta_anns": {} + "meta_anns": {}, } } - }) + } + ) assert len(records) == 1 assert records[0]["label_name"] == "name" assert records[0]["cui"] == "cui" @@ -47,35 +54,49 @@ def test_get_records_from_doc(medcat_icd10_model): assert records[0]["meta_anns"] == {} -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model_with_no_tui_filter(medcat_icd10_model): original = MedCATModelIcd10.load_model(os.path.join(MODEL_PARENT_DIR, 
"icd10_model.zip")) medcat_icd10_model._whitelisted_tuis = set([""]) medcat_icd10_model.init_model() assert medcat_icd10_model.model is not None - assert medcat_icd10_model.model.cdb.config.linking.filters.get("cuis") == original.cdb.config.linking.filters.get("cuis") + assert medcat_icd10_model.model.cdb.config.linking.filters.get( + "cuis" + ) == original.cdb.config.linking.filters.get("cuis") -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(medcat_icd10_model): medcat_icd10_model.init_model() target_tuis = medcat_icd10_model._config.TYPE_UNIQUE_ID_WHITELIST.split(",") - target_cuis = {cui for tui in target_tuis for cui in medcat_icd10_model.model.cdb.addl_info.get("type_id2cuis").get(tui, {})} + target_cuis = { + cui + for tui in target_tuis + for cui in medcat_icd10_model.model.cdb.addl_info.get("type_id2cuis").get(tui, {}) + } assert medcat_icd10_model.model is not None assert medcat_icd10_model.model.cdb.config.linking.filters.get("cuis") == target_cuis -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(medcat_icd10_model): cat = MedCATModelIcd10.load_model(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")) assert type(cat) is CAT -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(medcat_icd10_model): medcat_icd10_model.init_model() model_card = medcat_icd10_model.info() @@ -84,8 +105,10 @@ def test_info(medcat_icd10_model): assert model_card.model_type == ModelType.MEDCAT_ICD10 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(medcat_icd10_model): medcat_icd10_model.init_model() annotations = medcat_icd10_model.annotate("Spinal stenosis") @@ -96,8 +119,10 @@ def test_annotate(medcat_icd10_model): assert annotations[0]["accuracy"] > 0 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_supervised(medcat_icd10_model): medcat_icd10_model.init_model() medcat_icd10_model._config.REDEPLOY_TRAINED_MODEL = "false" @@ -108,8 +133,10 @@ def test_train_supervised(medcat_icd10_model): medcat_icd10_model._supervised_trainer.train.assert_called() -@pytest.mark.skipif(not 
os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "icd10_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_unsupervised(medcat_icd10_model): medcat_icd10_model.init_model() medcat_icd10_model._config.REDEPLOY_TRAINED_MODEL = "false" diff --git a/tests/app/model_services/test_medcat_model_snomed.py b/tests/app/model_services/test_medcat_model_snomed.py index 861befc..5bb7772 100644 --- a/tests/app/model_services/test_medcat_model_snomed.py +++ b/tests/app/model_services/test_medcat_model_snomed.py @@ -1,10 +1,13 @@ import os import tempfile -import pytest from unittest.mock import Mock -from tests.app.conftest import MODEL_PARENT_DIR + +import pytest from medcat.cat import CAT + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.medcat_model_snomed import MedCATModelSnomed @@ -23,17 +26,22 @@ def test_from_model(medcat_snomed_model): def test_get_records_from_doc(medcat_snomed_model): - records = medcat_snomed_model.get_records_from_doc({ - "entities": { - "0": { - "pretty_name": "pretty_name", - "cui": "cui", - "types": ["type"], - "athena_ids": [{"name": "name_1", "code": "code_1"}, {"name": "name_2", "code": "code_2"}], - "meta_anns": {} + records = medcat_snomed_model.get_records_from_doc( + { + "entities": { + "0": { + "pretty_name": "pretty_name", + "cui": "cui", + "types": ["type"], + "athena_ids": [ + {"name": "name_1", "code": "code_1"}, + {"name": "name_2", "code": "code_2"}, + ], + "meta_anns": {}, + } } } - }) + ) assert len(records) == 1 assert records[0]["label_name"] == "pretty_name" assert records[0]["label_id"] == "cui" @@ -42,35 +50,49 @@ def test_get_records_from_doc(medcat_snomed_model): assert records[0]["meta_anns"] == {} -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(medcat_snomed_model): medcat_snomed_model.init_model() target_tuis = medcat_snomed_model._config.TYPE_UNIQUE_ID_WHITELIST.split(",") - target_cuis = {cui for tui in target_tuis for cui in medcat_snomed_model.model.cdb.addl_info.get("type_id2cuis").get(tui, {})} + target_cuis = { + cui + for tui in target_tuis + for cui in medcat_snomed_model.model.cdb.addl_info.get("type_id2cuis").get(tui, {}) + } assert medcat_snomed_model.model is not None assert medcat_snomed_model.model.cdb.config.linking.filters.get("cuis") == target_cuis -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model_with_no_tui_filter(medcat_snomed_model): original = MedCATModelSnomed.load_model(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")) medcat_snomed_model._whitelisted_tuis = set([""]) medcat_snomed_model.init_model() assert medcat_snomed_model.model is not None - assert medcat_snomed_model.model.cdb.config.linking.filters.get("cuis") == 
original.cdb.config.linking.filters.get("cuis") + assert medcat_snomed_model.model.cdb.config.linking.filters.get( + "cuis" + ) == original.cdb.config.linking.filters.get("cuis") -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(medcat_snomed_model): cat = MedCATModelSnomed.load_model(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")) assert type(cat) is CAT -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(medcat_snomed_model): medcat_snomed_model.init_model() model_card = medcat_snomed_model.info() @@ -79,8 +101,10 @@ def test_info(medcat_snomed_model): assert model_card.model_type == ModelType.MEDCAT_SNOMED -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(medcat_snomed_model): medcat_snomed_model.init_model() annotations = medcat_snomed_model.annotate("Spinal stenosis") @@ -91,8 +115,10 @@ def test_annotate(medcat_snomed_model): assert annotations[0]["accuracy"] > 0 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_supervised(medcat_snomed_model): medcat_snomed_model.init_model() medcat_snomed_model._config.REDEPLOY_TRAINED_MODEL = "false" @@ -103,8 +129,10 @@ def test_train_supervised(medcat_snomed_model): medcat_snomed_model._supervised_trainer.train.assert_called() -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "snomed_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_unsupervised(medcat_snomed_model): medcat_snomed_model.init_model() medcat_snomed_model._config.REDEPLOY_TRAINED_MODEL = "false" diff --git a/tests/app/model_services/test_medcat_model_umls.py b/tests/app/model_services/test_medcat_model_umls.py index d9863e7..f9d6604 100644 --- a/tests/app/model_services/test_medcat_model_umls.py +++ b/tests/app/model_services/test_medcat_model_umls.py @@ -1,10 +1,13 @@ import os import tempfile -import pytest from unittest.mock import Mock -from tests.app.conftest import MODEL_PARENT_DIR + +import pytest from medcat.cat import CAT + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.medcat_model_umls import MedCATModelUmls @@ -22,22 +25,28 @@ def test_from_model(medcat_umls_model): 
assert new_model_service.model == medcat_umls_model.model -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(medcat_umls_model): medcat_umls_model.init_model() assert medcat_umls_model.model is not None -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(medcat_umls_model): cat = MedCATModelUmls.load_model(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")) assert type(cat) is CAT -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(medcat_umls_model): medcat_umls_model.init_model() model_card = medcat_umls_model.info() @@ -46,8 +55,10 @@ def test_info(medcat_umls_model): assert model_card.model_type == ModelType.MEDCAT_UMLS -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(medcat_umls_model): medcat_umls_model.init_model() annotations = medcat_umls_model.annotate("Spinal stenosis") @@ -58,8 +69,10 @@ def test_annotate(medcat_umls_model): assert annotations[0]["accuracy"] > 0 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_supervised(medcat_umls_model): medcat_umls_model.init_model() medcat_umls_model._config.REDEPLOY_TRAINED_MODEL = "false" @@ -70,8 +83,10 @@ def test_train_supervised(medcat_umls_model): medcat_umls_model._supervised_trainer.train.assert_called() -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "umls_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_train_unsupervised(medcat_umls_model): medcat_umls_model.init_model() medcat_umls_model._config.REDEPLOY_TRAINED_MODEL = "false" diff --git a/tests/app/model_services/test_trf_model_deid.py b/tests/app/model_services/test_trf_model_deid.py index 2a2e43d..34f790f 100644 --- a/tests/app/model_services/test_trf_model_deid.py +++ b/tests/app/model_services/test_trf_model_deid.py @@ -1,9 +1,12 @@ import os + import pytest -from tests.app.conftest import MODEL_PARENT_DIR -from transformers.models.bert.modeling_bert 
import BertForTokenClassification from medcat.tokenizers.transformers_ner import TransformersTokenizerNER +from transformers.models.bert.modeling_bert import BertForTokenClassification + from domain import ModelType +from tests.app.conftest import MODEL_PARENT_DIR + from model_services.trf_model_deid import TransformersModelDeIdentification @@ -15,23 +18,31 @@ def test_api_version(trf_model): assert trf_model.api_version == "0.0.1" -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_init_model(trf_model): trf_model.init_model() assert trf_model.model is not None -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_load_model(trf_model): - tokenizer, model = TransformersModelDeIdentification.load_model(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")) + tokenizer, model = TransformersModelDeIdentification.load_model( + os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip") + ) assert type(tokenizer) is TransformersTokenizerNER assert type(model) is BertForTokenClassification -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_info(trf_model): trf_model.init_model() model_card = trf_model.info() @@ -40,8 +51,10 @@ def test_info(trf_model): assert model_card.model_type == ModelType.TRANSFORMERS_DEID -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_annotate(trf_model): trf_model.init_model() annotations = trf_model.annotate("NW1 2DA") @@ -51,8 +64,10 @@ def test_annotate(trf_model): assert annotations[0]["end"] == 7 -@pytest.mark.skipif(not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), - reason="requires the model file to be present in the resources folder") +@pytest.mark.skipif( + not os.path.exists(os.path.join(MODEL_PARENT_DIR, "trf_deid_model.zip")), + reason="requires the model file to be present in the resources folder", +) def test_batch_annotate(trf_model): trf_model.init_model() annotation_list = trf_model.batch_annotate(["NW1 2DA", "NW1 2DA"]) diff --git a/tests/app/monitoring/test_log_captor.py b/tests/app/monitoring/test_log_captor.py index 3663743..e63e414 100644 --- a/tests/app/monitoring/test_log_captor.py +++ b/tests/app/monitoring/test_log_captor.py @@ -1,4 +1,5 @@ from contextlib import redirect_stdout + from management.log_captor import LogCaptor diff --git a/tests/app/monitoring/test_model_manager.py b/tests/app/monitoring/test_model_manager.py index 
707a391..2ceb290 100644 --- a/tests/app/monitoring/test_model_manager.py +++ b/tests/app/monitoring/test_model_manager.py @@ -1,26 +1,33 @@ -import mlflow import tempfile -import pandas as pd from typing import Generator from unittest.mock import Mock, call + +import mlflow +import pandas as pd from mlflow.pyfunc import PythonModelContext -from model_services.base import AbstractModelService -from management.model_manager import ModelManager + from config import Settings from exception import ManagedModelException +from management.model_manager import ModelManager +from model_services.base import AbstractModelService + def test_retrieve_python_model_from_uri(mlflow_fixture): config = Settings() ModelManager.retrieve_python_model_from_uri("model_uri", config) - mlflow.set_tracking_uri.assert_has_calls([call(config.MLFLOW_TRACKING_URI), call(config.MLFLOW_TRACKING_URI)]) + mlflow.set_tracking_uri.assert_has_calls( + [call(config.MLFLOW_TRACKING_URI), call(config.MLFLOW_TRACKING_URI)] + ) mlflow.pyfunc.load_model.assert_called_once_with(model_uri="model_uri") def test_retrieve_model_service_from_uri(mlflow_fixture): config = Settings() model_service = ModelManager.retrieve_model_service_from_uri("model_uri", config) - mlflow.set_tracking_uri.assert_has_calls([call(config.MLFLOW_TRACKING_URI), call(config.MLFLOW_TRACKING_URI)]) + mlflow.set_tracking_uri.assert_has_calls( + [call(config.MLFLOW_TRACKING_URI), call(config.MLFLOW_TRACKING_URI)] + ) mlflow.pyfunc.load_model.assert_called_once_with(model_uri="model_uri") assert model_service._config.BASE_MODEL_FULL_PATH == "model_uri" assert model_service._config == config @@ -30,45 +37,54 @@ def test_download_model_package(mlflow_fixture): try: ModelManager.download_model_package("mlflow_tracking_uri", "/tmp") except ManagedModelException as e: - assert "Cannot find the model .zip file inside artifacts downloaded from mlflow_tracking_uri" == str(e) + assert ( + "Cannot find the model .zip file inside artifacts downloaded from mlflow_tracking_uri" + == str(e) + ) def test_log_model_with_registration(mlflow_fixture): model_manager = ModelManager(_MockedModelService, Settings()) model_info = model_manager.log_model("model_name", "filepath", "model_name") assert model_info is not None - mlflow.pyfunc.log_model.assert_called_once_with(artifact_path="model_name", - python_model=model_manager, - signature=model_manager.model_signature, - code_path=model_manager._get_code_path_list(), - pip_requirements=model_manager._get_pip_requirements_from_file(), - artifacts={"model_path": "filepath"}, - registered_model_name="model_name") + mlflow.pyfunc.log_model.assert_called_once_with( + artifact_path="model_name", + python_model=model_manager, + signature=model_manager.model_signature, + code_path=model_manager._get_code_path_list(), + pip_requirements=model_manager._get_pip_requirements_from_file(), + artifacts={"model_path": "filepath"}, + registered_model_name="model_name", + ) def test_log_model_without_registration(mlflow_fixture): model_manager = ModelManager(_MockedModelService, Settings()) model_info = model_manager.log_model("model_name", "filepath") assert model_info is not None - mlflow.pyfunc.log_model.assert_called_once_with(artifact_path="model_name", - python_model=model_manager, - signature=model_manager.model_signature, - code_path=model_manager._get_code_path_list(), - pip_requirements=model_manager._get_pip_requirements_from_file(), - artifacts={"model_path": "filepath"}, - registered_model_name=None) + 
mlflow.pyfunc.log_model.assert_called_once_with( + artifact_path="model_name", + python_model=model_manager, + signature=model_manager.model_signature, + code_path=model_manager._get_code_path_list(), + pip_requirements=model_manager._get_pip_requirements_from_file(), + artifacts={"model_path": "filepath"}, + registered_model_name=None, + ) def test_save_model(mlflow_fixture): model_manager = ModelManager(_MockedModelService, Settings()) with tempfile.TemporaryDirectory() as local_dir: model_manager.save_model(local_dir, ".") - mlflow.pyfunc.save_model.assert_called_once_with(path=local_dir, - python_model=model_manager, - signature=model_manager.model_signature, - code_path=model_manager._get_code_path_list(), - pip_requirements=model_manager._get_pip_requirements_from_file(), - artifacts={"model_path": "."}) + mlflow.pyfunc.save_model.assert_called_once_with( + path=local_dir, + python_model=model_manager, + signature=model_manager.model_signature, + code_path=model_manager._get_code_path_list(), + pip_requirements=model_manager._get_pip_requirements_from_file(), + artifacts={"model_path": "."}, + ) def test_load_context(mlflow_fixture): @@ -81,7 +97,7 @@ def test_get_model_signature(): model_manager = ModelManager(_MockedModelService, Settings()) assert model_manager.model_signature.inputs.to_dict() == [ {"type": "string", "name": "name", "required": False}, - {"type": "string", "name": "text", "required": True} + {"type": "string", "name": "text", "required": True}, ] assert model_manager.model_signature.outputs.to_dict() == [ {"type": "string", "name": "doc_name", "required": True}, @@ -100,58 +116,88 @@ def test_predict(mlflow_fixture): model_manager = ModelManager(_MockedModelService, Settings()) model_manager._model_service = Mock() model_manager._model_service.annotate = Mock() - model_manager._model_service.annotate.return_value = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] - output = model_manager.predict(None, pd.DataFrame([{"name": "doc_1", "text": "text_1"}, {"name": "doc_2", "text": "text_2"}])) + model_manager._model_service.annotate.return_value = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] + output = model_manager.predict( + None, + pd.DataFrame([{"name": "doc_1", "text": "text_1"}, {"name": "doc_2", "text": "text_2"}]), + ) assert output.to_dict() == { "doc_name": {0: "doc_1", 1: "doc_2"}, "label_name": {0: "Spinal stenosis", 1: "Spinal stenosis"}, "label_id": {0: "76107001", 1: "76107001"}, - "start": {0: 0, 1: 0}, "end": {0: 15, 1: 15}, + "start": {0: 0, 1: 0}, + "end": {0: 15, 1: 15}, "accuracy": {0: 1.0, 1: 1.0}, - "meta_anns": {0: {"Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"}}, 1: {"Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"}}}} + "meta_anns": { + 0: { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + 1: { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + }, + } def test_predict_stream(mlflow_fixture): model_manager = ModelManager(_MockedModelService, Settings()) model_manager._model_service = Mock() model_manager._model_service.annotate 
= Mock() - model_manager._model_service.annotate.return_value = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 0, - "end": 15, - "accuracy": 1.0, - "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } - }, - }] - output = model_manager.predict_stream(None, pd.DataFrame([{"name": "doc_1", "text": "text_1"}, {"name": "doc_2", "text": "text_2"}])) + model_manager._model_service.annotate.return_value = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + } + ] + output = model_manager.predict_stream( + None, + pd.DataFrame([{"name": "doc_1", "text": "text_1"}, {"name": "doc_2", "text": "text_2"}]), + ) assert isinstance(output, Generator) assert list(output) == [ - {"doc_name": "doc_1", "label_name": "Spinal stenosis", "label_id": "76107001", "start": 0, "end": 15, "accuracy": 1.0, "meta_anns": {"Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"}}}, - {"doc_name": "doc_2", "label_name": "Spinal stenosis", "label_id": "76107001", "start": 0, "end": 15, "accuracy": 1.0, "meta_anns": {"Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"}}}, + { + "doc_name": "doc_1", + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + }, + { + "doc_name": "doc_2", + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 0, + "end": 15, + "accuracy": 1.0, + "meta_anns": { + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} + }, + }, ] class _MockedModelService(AbstractModelService): - def __init__(self, config: Settings, *args, **kwargs) -> None: self._config = config self.model_name = "Mocked Model" diff --git a/tests/app/monitoring/test_tracker_client.py b/tests/app/monitoring/test_tracker_client.py index 18b45c4..da823e5 100644 --- a/tests/app/monitoring/test_tracker_client.py +++ b/tests/app/monitoring/test_tracker_client.py @@ -1,19 +1,29 @@ import os -import mlflow +from unittest.mock import Mock, call + import datasets -import pytest +import mlflow import pandas as pd -from management.tracker_client import TrackerClient -from data import doc_dataset +import pytest + from tests.app.helper import StringContains -from unittest.mock import Mock, call + +from data import doc_dataset +from management.tracker_client import TrackerClient def test_start_new(mlflow_fixture): tracker_client = TrackerClient("") - experiment_id, run_id = tracker_client.start_tracking("model_name", "input_file_name", "base_model_origin", - "training_type", {"param": "param"}, "run_name", 10) + experiment_id, run_id = tracker_client.start_tracking( + "model_name", + "input_file_name", + "base_model_origin", + "training_type", + {"param": "param"}, + "run_name", + 10, + ) mlflow.get_experiment_by_name.assert_called_once_with("model_name_training_type") mlflow.create_experiment.assert_called_once_with(name="model_name_training_type") @@ -74,7 +84,9 @@ def test_save_model_artifact(mlflow_fixture): tracker_client.save_model_artifact("filepath", "model name") - mlflow.log_artifact.assert_called_once_with("filepath", artifact_path=os.path.join("model_name", "artifacts")) + mlflow.log_artifact.assert_called_once_with( + 
"filepath", artifact_path=os.path.join("model_name", "artifacts") + ) def test_save_raw_artifact(mlflow_fixture): @@ -82,7 +94,9 @@ def test_save_raw_artifact(mlflow_fixture): tracker_client.save_raw_artifact("filepath", "model name") - mlflow.log_artifact.assert_called_once_with("filepath", artifact_path=os.path.join("model_name", "artifacts", "raw")) + mlflow.log_artifact.assert_called_once_with( + "filepath", artifact_path=os.path.join("model_name", "artifacts", "raw") + ) def test_save_processed_artifact(mlflow_fixture): @@ -90,15 +104,21 @@ def test_save_processed_artifact(mlflow_fixture): tracker_client.save_processed_artifact("filepath", "model name") - mlflow.log_artifact.assert_called_once_with("filepath", artifact_path=os.path.join("model_name", "artifacts", "processed")) + mlflow.log_artifact.assert_called_once_with( + "filepath", artifact_path=os.path.join("model_name", "artifacts", "processed") + ) def test_save_dataframe_as_csv(mlflow_fixture): tracker_client = TrackerClient("") - tracker_client.save_dataframe_as_csv("test.csv", pd.DataFrame({"x": ["x1", "x2"], "y": ["y1", "y2"]}), "model_name") + tracker_client.save_dataframe_as_csv( + "test.csv", pd.DataFrame({"x": ["x1", "x2"], "y": ["y1", "y2"]}), "model_name" + ) - mlflow.log_artifact.assert_called_once_with(StringContains("test.csv"), artifact_path=os.path.join("model_name", "stats")) + mlflow.log_artifact.assert_called_once_with( + StringContains("test.csv"), artifact_path=os.path.join("model_name", "stats") + ) def test_save_dict_as_json(mlflow_fixture): @@ -106,7 +126,9 @@ def test_save_dict_as_json(mlflow_fixture): tracker_client.save_dict_as_json("test.json", {"key": {"value": ["v1", "v2"]}}, "model_name") - mlflow.log_artifact.assert_called_once_with(StringContains("test.json"), artifact_path=os.path.join("model_name", "stats")) + mlflow.log_artifact.assert_called_once_with( + StringContains("test.json"), artifact_path=os.path.join("model_name", "stats") + ) def test_save_plot(mlflow_fixture): @@ -114,26 +136,43 @@ def test_save_plot(mlflow_fixture): tracker_client.save_plot("test.png", "model_name") - mlflow.log_artifact.assert_called_once_with(StringContains("test.png"), artifact_path=os.path.join("model_name", "stats")) + mlflow.log_artifact.assert_called_once_with( + StringContains("test.png"), artifact_path=os.path.join("model_name", "stats") + ) def test_save_table_dict(mlflow_fixture): tracker_client = TrackerClient("") - tracker_client.save_table_dict({"col1": ["cell1", "cell2"], "col2": ["cell3", "cell4"]}, "model_name", "table.json") + tracker_client.save_table_dict( + {"col1": ["cell1", "cell2"], "col2": ["cell3", "cell4"]}, "model_name", "table.json" + ) - mlflow.log_table.assert_called_once_with(data={"col1": ["cell1", "cell2"], "col2": ["cell3", "cell4"]}, artifact_file=os.path.join("model_name", "tables", "table.json")) + mlflow.log_table.assert_called_once_with( + data={"col1": ["cell1", "cell2"], "col2": ["cell3", "cell4"]}, + artifact_file=os.path.join("model_name", "tables", "table.json"), + ) def test_save_train_dataset(mlflow_fixture): tracker_client = TrackerClient("") - sample_texts = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json") - dataset = datasets.load_dataset(doc_dataset.__file__, data_files={"documents": sample_texts}, split="train", cache_dir="/tmp", trust_remote_code=True) + sample_texts = os.path.join( + os.path.dirname(__file__), "..", "..", "resources", "fixture", "sample_texts.json" + ) + dataset = datasets.load_dataset( + 
doc_dataset.__file__, + data_files={"documents": sample_texts}, + split="train", + cache_dir="/tmp", + trust_remote_code=True, + ) tracker_client.save_train_dataset(dataset) assert mlflow.log_input.call_count == 1 - assert isinstance(mlflow.log_input.call_args[0][0], mlflow.data.huggingface_dataset.HuggingFaceDataset) + assert isinstance( + mlflow.log_input.call_args[0][0], mlflow.data.huggingface_dataset.HuggingFaceDataset + ) assert mlflow.log_input.call_args[1]["context"] == "train" @@ -151,11 +190,15 @@ def test_save_model(mlflow_fixture): mlflow_client.search_model_versions.return_value = [version] tracker_client.mlflow_client = mlflow_client - artifact_uri = tracker_client.save_model("path/to/file.zip", "model_name", model_manager, "validation_status") + artifact_uri = tracker_client.save_model( + "path/to/file.zip", "model_name", model_manager, "validation_status" + ) assert "artifacts/model_name" in artifact_uri model_manager.log_model.assert_called_once_with("model_name", "path/to/file.zip", "model_name") - mlflow_client.set_model_version_tag.assert_called_once_with(name="model_name", version="1", key="validation_status", value="validation_status") + mlflow_client.set_model_version_tag.assert_called_once_with( + name="model_name", version="1", key="validation_status", value="validation_status" + ) mlflow.set_tag.assert_called_once_with("training.output.package", "file.zip") @@ -172,19 +215,23 @@ def test_save_pretrained_model(mlflow_fixture): tracker_client = TrackerClient("") model_manager = Mock() - tracker_client.save_pretrained_model("model_name", - "model_path", - model_manager, - "training_type", - "run_name", - {"param": "value"}, - [{"p": 0.8, "r": 0.8}, {"p": 0.9, "r": 0.9}], - {"tag_name": "tag_value"}) + tracker_client.save_pretrained_model( + "model_name", + "model_path", + model_manager, + "training_type", + "run_name", + {"param": "value"}, + [{"p": 0.8, "r": 0.8}, {"p": 0.9, "r": 0.9}], + {"tag_name": "tag_value"}, + ) mlflow.get_experiment_by_name.assert_called_once_with("model_name_training_type") mlflow.start_run.assert_called_once_with(experiment_id="experiment_id") mlflow.log_params.assert_called_once_with({"param": "value"}) - mlflow.log_metrics.assert_has_calls([call({"p": 0.8, "r": 0.8}, 0), call({"p": 0.9, "r": 0.9}, 1)]) + mlflow.log_metrics.assert_has_calls( + [call({"p": 0.8, "r": 0.8}, 0), call({"p": 0.9, "r": 0.9}, 1)] + ) mlflow.set_tags.assert_called() assert mlflow.set_tags.call_args.args[0]["mlflow.runName"] == "run_name" assert mlflow.set_tags.call_args.args[0]["training.base_model.origin"] == "model_path" @@ -209,7 +256,9 @@ def test_log_multiple_exceptions(mlflow_fixture): tracker_client.log_exceptions([Exception("exception_0"), Exception("exception_1")]) - mlflow.set_tag.assert_has_calls([call("exception_0", "exception_0"), call("exception_1", "exception_1")]) + mlflow.set_tag.assert_has_calls( + [call("exception_0", "exception_0"), call("exception_1", "exception_1")] + ) def test_log_classes(mlflow_fixture): @@ -225,7 +274,9 @@ def test_log_classes_and_names(mlflow_fixture): tracker_client.log_classes_and_names({"class_1": "class_1_name", "class_2": "class_2_name"}) - mlflow.set_tag.assert_called_once_with("training.entity.class2names", "{'class_1': 'class_1_name', 'class_2': 'class_2_name'}") + mlflow.set_tag.assert_called_once_with( + "training.entity.class2names", "{'class_1': 'class_1_name', 'class_2': 'class_2_name'}" + ) def test_log_trainer_version(mlflow_fixture): @@ -266,12 +317,17 @@ def test_send_batched_model_stats(mlflow_fixture): 
tracker_client.mlflow_client = mlflow_client tracker_client.send_batched_model_stats( - [{"m1": "v1", "m2": "v1"}, {"m1": "v2", "m2": "v2"}, {"m1": "v3", "m2": "v3"}], - "run_id", 3) + [{"m1": "v1", "m2": "v1"}, {"m1": "v2", "m2": "v2"}, {"m1": "v3", "m2": "v3"}], "run_id", 3 + ) - mlflow_client.log_batch.assert_has_calls([call(run_id='run_id', metrics=[]), call(run_id='run_id', metrics=[])]) + mlflow_client.log_batch.assert_has_calls( + [call(run_id="run_id", metrics=[]), call(run_id="run_id", metrics=[])] + ) def test_get_experiment_name(): assert TrackerClient.get_experiment_name("SNOMED model") == "SNOMED_model" - assert TrackerClient.get_experiment_name("SNOMED model", "unsupervised") == "SNOMED_model_unsupervised" + assert ( + TrackerClient.get_experiment_name("SNOMED model", "unsupervised") + == "SNOMED_model_unsupervised" + ) diff --git a/tests/app/processors/test_metrics_collector.py b/tests/app/processors/test_metrics_collector.py index eda87f0..1dcec5c 100644 --- a/tests/app/processors/test_metrics_collector.py +++ b/tests/app/processors/test_metrics_collector.py @@ -1,18 +1,21 @@ +import json import os import tempfile -import json -import pytest from unittest.mock import create_autospec + +import pytest + +from exception import AnnotationException + from model_services.base import AbstractModelService from processors.metrics_collector import ( - sanity_check_model_with_trainer_export, concat_trainer_exports, - get_stats_from_trainer_export, get_iaa_scores_per_concept, get_iaa_scores_per_doc, get_iaa_scores_per_span, + get_stats_from_trainer_export, + sanity_check_model_with_trainer_export, ) -from exception import AnnotationException @pytest.fixture @@ -33,12 +36,18 @@ def test_sanity_check_model_with_trainer_export_path(model_service): "label_id": "C0020538", "start": 255, "end": 267, - } + }, ] model_service.annotate.return_value = annotations - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") - - precision, recall, f1, per_cui_prec, per_cui_rec, per_cui_f1, per_cui_name, per_cui_anchors = sanity_check_model_with_trainer_export(path, model_service) + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) + + precision, recall, f1, per_cui_prec, per_cui_rec, per_cui_f1, per_cui_name, per_cui_anchors = ( + sanity_check_model_with_trainer_export(path, model_service) + ) assert precision == 0.5 assert recall == 0.07142857142857142 assert f1 == 0.125 @@ -62,10 +71,14 @@ def test_evaluate_model_and_return_dataframe(model_service): "label_id": "C0020538", "start": 255, "end": 267, - } + }, ] model_service.annotate.return_value = annotations - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) result = sanity_check_model_with_trainer_export(path, model_service, return_df=True) @@ -90,10 +103,14 @@ def test_sanity_check_model_with_trainer_export_file(model_service): "label_id": "C0020538", "start": 255, "end": 267, - } + }, ] model_service.annotate.return_value = annotations - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) with open(path, 
"r") as file: result = sanity_check_model_with_trainer_export(file, model_service, return_df=True) @@ -119,13 +136,19 @@ def test_sanity_check_model_with_trainer_export_dict(model_service): "label_id": "C0020538", "start": 255, "end": 267, - } + }, ] model_service.annotate.return_value = annotations - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) with open(path, "r") as file: - result = sanity_check_model_with_trainer_export(json.load(file), model_service, return_df=True) + result = sanity_check_model_with_trainer_export( + json.load(file), model_service, return_df=True + ) assert set(result["concept"].to_list()) == {"C0020538", "C0017168"} assert set(result["name"].to_list()) == {"gastroesophageal reflux", "hypertension"} @@ -148,24 +171,41 @@ def test_evaluate_model_and_include_anchors(model_service): "label_id": "C0020538", "start": 255, "end": 267, - } + }, ] model_service.annotate.return_value = annotations - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) - result = sanity_check_model_with_trainer_export(path, model_service, return_df=True, include_anchors=True) + result = sanity_check_model_with_trainer_export( + path, model_service, return_df=True, include_anchors=True + ) assert set(result["concept"].to_list()) == {"C0020538", "C0017168"} assert set(result["name"].to_list()) == {"gastroesophageal reflux", "hypertension"} assert set(result["precision"].to_list()) == {0.5, 0.5} assert set(result["recall"].to_list()) == {0.25, 1.0} assert set(result["f1"].to_list()) == {0.3333333333333333, 0.6666666666666666} - assert set(result["anchors"].to_list()) == {"P14/D3204/S255/E267;P14/D3205/S255/E267", "P14/D3204/S332/E355;P14/D3205/S332/E355"} + assert set(result["anchors"].to_list()) == { + "P14/D3204/S255/E267;P14/D3205/S255/E267", + "P14/D3204/S332/E355;P14/D3205/S332/E355", + } def test_concat_trainer_exports(): - path_1 = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") - path_2 = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") + path_1 = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) + path_2 = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) with tempfile.NamedTemporaryFile() as f: concat_trainer_exports([path_1, path_2], f.name, True) new_export = json.load(f) @@ -173,23 +213,41 @@ def test_concat_trainer_exports(): def test_concat_trainer_exports_with_duplicated_project_ids(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) with pytest.raises(AnnotationException) as e: concat_trainer_exports([path, path, path]) assert "Found multiple projects share the same ID:" in str(e.value) def test_concat_trainer_exports_with_recurring_document_ids(): - path = 
os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") - another_path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) + another_path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) with pytest.raises(AnnotationException) as e: concat_trainer_exports([path, another_path], allow_recurring_doc_ids=False) assert str(e.value) == "Found multiple documents share the same ID(s): [3204, 3205]" def test_get_stats_from_trainer_export(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") - cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = get_stats_from_trainer_export(path) + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) + cui_counts, cui_unique_counts, cui_ignorance_counts, num_of_docs = ( + get_stats_from_trainer_export(path) + ) assert cui_counts == { "C0003864": 2, "C0007222": 1, @@ -207,7 +265,7 @@ def test_get_stats_from_trainer_export(): "C0042029": 4, "C0155626": 2, "C0338614": 1, - "C0878544": 1 + "C0878544": 1, } assert cui_unique_counts == { "C0017168": 1, @@ -226,37 +284,195 @@ def test_get_stats_from_trainer_export(): "C0037284": 2, "C0003864": 1, "C0011849": 1, - "C0338614": 1 + "C0338614": 1, } assert cui_ignorance_counts == {"C0012634": 1, "C0338614": 1} assert num_of_docs == 2 def test_get_stats_from_trainer_export_as_dataframe(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export.json", + ) result = get_stats_from_trainer_export(path, return_df=True) - assert result["concept"].tolist() == ["C0017168", "C0020538", "C0012634", "C0038454", "C0007787", "C0155626", "C0011860", "C0042029", "C0010068", "C0007222", "C0027051", "C0878544", "C0020473", "C0037284", "C0003864", "C0011849", "C0338614"] + assert result["concept"].tolist() == [ + "C0017168", + "C0020538", + "C0012634", + "C0038454", + "C0007787", + "C0155626", + "C0011860", + "C0042029", + "C0010068", + "C0007222", + "C0027051", + "C0878544", + "C0020473", + "C0037284", + "C0003864", + "C0011849", + "C0338614", + ] assert result["anno_count"].tolist() == [1, 4, 1, 1, 1, 2, 3, 4, 1, 1, 1, 1, 3, 2, 2, 1, 1] - assert result["anno_unique_counts"].tolist() == [1, 1, 1, 1, 1, 1, 3, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1] - assert result["anno_ignorance_counts"].tolist() == [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + assert result["anno_unique_counts"].tolist() == [ + 1, + 1, + 1, + 1, + 1, + 1, + 3, + 2, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 1, + 1, + ] + assert result["anno_ignorance_counts"].tolist() == [ + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + ] def test_get_iaa_scores_per_concept(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") - per_cui_anno_iia_pct, per_cui_anno_cohens_kappa, per_cui_metaanno_iia_pct, per_cui_metaanno_cohens_kappa = get_iaa_scores_per_concept(path, 1, 2) - assert 
set(per_cui_anno_iia_pct.keys()) == {"C0003864", "C0007222", "C0007787", "C0010068", "C0011849", "C0011860", "C0012634", "C0017168", "C0020473", "C0020538", "C0027051", "C0037284", "C0038454", "C0042029", "C0155626", "C0338614", "C0878544"} - assert set(per_cui_anno_cohens_kappa.keys()) == {"C0003864", "C0007222", "C0007787", "C0010068", "C0011849", "C0011860", "C0012634", "C0017168", "C0020473", "C0020538", "C0027051", "C0037284", "C0038454", "C0042029", "C0155626", "C0338614", "C0878544"} - assert set(per_cui_metaanno_iia_pct.keys()) == {"C0003864", "C0007222", "C0007787", "C0010068", "C0011849", "C0011860", "C0012634", "C0017168", "C0020473", "C0020538", "C0027051", "C0037284", "C0038454", "C0042029", "C0155626", "C0338614", "C0878544"} - assert set(per_cui_metaanno_cohens_kappa.keys()) == {"C0003864", "C0007222", "C0007787", "C0010068", "C0011849", "C0011860", "C0012634", "C0017168", "C0020473", "C0020538", "C0027051", "C0037284", "C0038454", "C0042029", "C0155626", "C0338614", "C0878544"} + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) + ( + per_cui_anno_iia_pct, + per_cui_anno_cohens_kappa, + per_cui_metaanno_iia_pct, + per_cui_metaanno_cohens_kappa, + ) = get_iaa_scores_per_concept(path, 1, 2) + assert set(per_cui_anno_iia_pct.keys()) == { + "C0003864", + "C0007222", + "C0007787", + "C0010068", + "C0011849", + "C0011860", + "C0012634", + "C0017168", + "C0020473", + "C0020538", + "C0027051", + "C0037284", + "C0038454", + "C0042029", + "C0155626", + "C0338614", + "C0878544", + } + assert set(per_cui_anno_cohens_kappa.keys()) == { + "C0003864", + "C0007222", + "C0007787", + "C0010068", + "C0011849", + "C0011860", + "C0012634", + "C0017168", + "C0020473", + "C0020538", + "C0027051", + "C0037284", + "C0038454", + "C0042029", + "C0155626", + "C0338614", + "C0878544", + } + assert set(per_cui_metaanno_iia_pct.keys()) == { + "C0003864", + "C0007222", + "C0007787", + "C0010068", + "C0011849", + "C0011860", + "C0012634", + "C0017168", + "C0020473", + "C0020538", + "C0027051", + "C0037284", + "C0038454", + "C0042029", + "C0155626", + "C0338614", + "C0878544", + } + assert set(per_cui_metaanno_cohens_kappa.keys()) == { + "C0003864", + "C0007222", + "C0007787", + "C0010068", + "C0011849", + "C0011860", + "C0012634", + "C0017168", + "C0020473", + "C0020538", + "C0027051", + "C0037284", + "C0038454", + "C0042029", + "C0155626", + "C0338614", + "C0878544", + } def test_get_iaa_scores_per_concept_and_return_dataframe(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", - "trainer_export_multi_projs.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) result = get_iaa_scores_per_concept(path, 1, 2, return_df=True) - assert set(result["concept"]) == {"C0003864", "C0007222", "C0007787", "C0010068", "C0011849", "C0011860", - "C0012634", "C0017168", "C0020473", "C0020538", "C0027051", "C0037284", - "C0038454", "C0042029", "C0155626", "C0338614", "C0878544"} + assert set(result["concept"]) == { + "C0003864", + "C0007222", + "C0007787", + "C0010068", + "C0011849", + "C0011860", + "C0012634", + "C0017168", + "C0020473", + "C0020538", + "C0027051", + "C0037284", + "C0038454", + "C0042029", + "C0155626", + "C0338614", + "C0878544", + } assert len(result["iaa_percentage"]) == 17 assert len(result["cohens_kappa"]) == 17 assert len(result["iaa_percentage_meta"]) == 17 @@ 
-264,8 +480,17 @@ def test_get_iaa_scores_per_concept_and_return_dataframe(): def test_get_iaa_scores_per_doc(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") - per_doc_anno_iia_pct, per_doc_anno_cohens_kappa, per_doc_metaanno_iia_pct, per_doc_metaanno_cohens_kappa = get_iaa_scores_per_doc(path, 1, 2) + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) + ( + per_doc_anno_iia_pct, + per_doc_anno_cohens_kappa, + per_doc_metaanno_iia_pct, + per_doc_metaanno_cohens_kappa, + ) = get_iaa_scores_per_doc(path, 1, 2) assert set(per_doc_anno_iia_pct.keys()) == {"3204", "3205"} assert set(per_doc_anno_cohens_kappa.keys()) == {"3204", "3205"} assert set(per_doc_metaanno_iia_pct.keys()) == {"3204", "3205"} @@ -273,7 +498,11 @@ def test_get_iaa_scores_per_doc(): def test_get_iaa_scores_per_doc_and_return_dataframe(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) result = get_iaa_scores_per_doc(path, 1, 2, return_df=True) assert len(result["doc_id"]) == 2 assert len(result["iaa_percentage"]) == 2 @@ -283,8 +512,17 @@ def test_get_iaa_scores_per_doc_and_return_dataframe(): def test_get_iaa_scores_per_span(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") - per_doc_anno_iia_pct, per_doc_anno_cohens_kappa, per_doc_metaanno_iia_pct, per_doc_metaanno_cohens_kappa = get_iaa_scores_per_span(path, 1, 2) + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) + ( + per_doc_anno_iia_pct, + per_doc_anno_cohens_kappa, + per_doc_metaanno_iia_pct, + per_doc_metaanno_cohens_kappa, + ) = get_iaa_scores_per_span(path, 1, 2) assert len(per_doc_anno_iia_pct.keys()) == 30 assert len(per_doc_anno_cohens_kappa.keys()) == 30 assert len(per_doc_metaanno_iia_pct.keys()) == 30 @@ -292,7 +530,11 @@ def test_get_iaa_scores_per_span(): def test_get_iaa_scores_per_span_and_return_dataframe(): - path = os.path.join(os.path.join(os.path.dirname(__file__), "..", "..", "resources"), "fixture", "trainer_export_multi_projs.json") + path = os.path.join( + os.path.join(os.path.dirname(__file__), "..", "..", "resources"), + "fixture", + "trainer_export_multi_projs.json", + ) result = get_iaa_scores_per_span(path, 1, 2, return_df=True) assert len(result["doc_id"]) == 30 assert len(result["span_start"]) == 30 diff --git a/tests/app/test_registry.py b/tests/app/test_registry.py index 50aed72..059a414 100644 --- a/tests/app/test_registry.py +++ b/tests/app/test_registry.py @@ -1,10 +1,11 @@ from domain import ModelType from registry import model_service_registry -from model_services.trf_model_deid import TransformersModelDeIdentification + +from model_services.medcat_model_deid import MedCATModelDeIdentification +from model_services.medcat_model_icd10 import MedCATModelIcd10 from model_services.medcat_model_snomed import MedCATModelSnomed from model_services.medcat_model_umls import MedCATModelUmls -from model_services.medcat_model_icd10 import MedCATModelIcd10 -from model_services.medcat_model_deid import MedCATModelDeIdentification +from 
model_services.trf_model_deid import TransformersModelDeIdentification def test_model_registry(): @@ -13,4 +14,7 @@ def test_model_registry(): assert model_service_registry[ModelType.MEDCAT_ICD10.value] == MedCATModelIcd10 assert model_service_registry[ModelType.MEDCAT_DEID.value] == MedCATModelDeIdentification assert model_service_registry[ModelType.ANONCAT.value] == MedCATModelDeIdentification - assert model_service_registry[ModelType.TRANSFORMERS_DEID.value] == TransformersModelDeIdentification + assert ( + model_service_registry[ModelType.TRANSFORMERS_DEID.value] + == TransformersModelDeIdentification + ) diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 9188edd..cc66d33 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -1,26 +1,27 @@ -import os import json +import os import tempfile +from urllib.parse import urlparse + import torch from safetensors.torch import save_file -from urllib.parse import urlparse from utils import ( - get_settings, - get_code_base_uri, annotations_to_entities, - send_gelf_message, + augment_annotations, + breakdown_annotations, + filter_by_concept_ids, + get_code_base_uri, get_func_params_as_dict, + get_hf_pipeline_device_id, + get_settings, + json_denormalize, json_normalize_medcat_entities, json_normalize_trainer_export, - json_denormalize, - filter_by_concept_ids, + non_default_device_is_available, replace_spans_of_concept, - breakdown_annotations, - augment_annotations, safetensors_to_pytorch, - non_default_device_is_available, - get_hf_pipeline_device_id, + send_gelf_message, ) @@ -31,19 +32,23 @@ def test_get_code_base_uri(): def test_annotations_to_entities(): - annotations = [{ - "label_name": "Spinal stenosis", - "label_id": "76107001", - "start": 1, - "end": 15, - }] - expected = [{ - "start": 1, - "end": 15, - "label": "Spinal stenosis", - "kb_id": "76107001", - "kb_url": "http://snomed.info/id/76107001", - }] + annotations = [ + { + "label_name": "Spinal stenosis", + "label_id": "76107001", + "start": 1, + "end": 15, + } + ] + expected = [ + { + "start": 1, + "end": 15, + "label": "Spinal stenosis", + "kb_id": "76107001", + "kb_url": "http://snomed.info/id/76107001", + } + ] assert annotations_to_entities(annotations, "SNOMED model") == expected @@ -62,30 +67,88 @@ def test_send_gelf_message(mocker): def test_get_func_params_as_dict(): def func(arg1, arg2=None, arg3="arg3"): pass + params = get_func_params_as_dict(func) assert params == {"arg2": None, "arg3": "arg3"} def test_json_normalize_medcat_entities(): - medcat_entities_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "medcat_entities.json") + medcat_entities_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "medcat_entities.json" + ) with open(medcat_entities_path, "r") as f: medcat_entities = json.load(f) df = json_normalize_medcat_entities(medcat_entities) assert len(df) == 25 - assert df.columns.tolist() == ["pretty_name", "cui", "type_ids", "types", "source_value", "detected_name", "acc", "context_similarity", "start", "end", "icd10", "ontologies", "snomed", "id", "meta_anns.Presence.value", "meta_anns.Presence.confidence", "meta_anns.Presence.name", "meta_anns.Subject.value", "meta_anns.Subject.confidence", "meta_anns.Subject.name", "meta_anns.Time.value", "meta_anns.Time.confidence", "meta_anns.Time.name"] + assert df.columns.tolist() == [ + "pretty_name", + "cui", + "type_ids", + "types", + "source_value", + "detected_name", + "acc", + "context_similarity", + "start", + "end", + "icd10", 
+ "ontologies", + "snomed", + "id", + "meta_anns.Presence.value", + "meta_anns.Presence.confidence", + "meta_anns.Presence.name", + "meta_anns.Subject.value", + "meta_anns.Subject.confidence", + "meta_anns.Subject.name", + "meta_anns.Time.value", + "meta_anns.Time.confidence", + "meta_anns.Time.name", + ] def test_json_normalize_trainer_export(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) df = json_normalize_trainer_export(trainer_export) assert len(df) == 30 - assert df.columns.tolist() == ["id", "user", "cui", "value", "start", "end", "validated", "correct", "deleted", "alternative", "killed", "last_modified", "manually_created", "acc", "meta_anns.Status.name", "meta_anns.Status.value", "meta_anns.Status.acc", "meta_anns.Status.validated", "projects.name", "projects.id", "projects.cuis", "projects.tuis", "projects.documents.id", "projects.documents.name", "projects.documents.text", "projects.documents.last_modified"] + assert df.columns.tolist() == [ + "id", + "user", + "cui", + "value", + "start", + "end", + "validated", + "correct", + "deleted", + "alternative", + "killed", + "last_modified", + "manually_created", + "acc", + "meta_anns.Status.name", + "meta_anns.Status.value", + "meta_anns.Status.acc", + "meta_anns.Status.validated", + "projects.name", + "projects.id", + "projects.cuis", + "projects.tuis", + "projects.documents.id", + "projects.documents.name", + "projects.documents.text", + "projects.documents.last_modified", + ] def test_json_denormalize(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) df = json_normalize_trainer_export(trainer_export) @@ -97,7 +160,9 @@ def test_filter_by_concept_ids(): config = get_settings() backup = config.TRAINING_CONCEPT_ID_WHITELIST config.TRAINING_CONCEPT_ID_WHITELIST = "C0017168, C0020538" - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) filtered = filter_by_concept_ids(trainer_export, extra_excluded=["C0020538"]) @@ -111,18 +176,27 @@ def test_filter_by_concept_ids(): def test_replace_spans_of_concept(): def transform(source: str) -> str: return source.upper()[:-7] - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) result = replace_spans_of_concept(trainer_export, "C0017168", transform) - updated = [(anno["value"], anno["start"], anno["end"]) for anno in result["projects"][0]["documents"][0]["annotations"] if anno["cui"] == "C0017168"] + updated = [ + (anno["value"], anno["start"], anno["end"]) + for anno in result["projects"][0]["documents"][0]["annotations"] + if anno["cui"] == "C0017168" + ] 
assert updated[0][0] == "GASTROESOPHAGEAL" assert updated[0][1] == 332 assert updated[0][2] == 348 def test_breakdown_annotations(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) result = breakdown_annotations(trainer_export, ["C0017168"], " ", "e") @@ -134,7 +208,9 @@ def test_breakdown_annotations(): def test_breakdown_annotations_without_including_delimiter(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) result = breakdown_annotations(trainer_export, ["C0017168"], " ", "e", include_delimiter=False) @@ -146,7 +222,9 @@ def test_breakdown_annotations_without_including_delimiter(): def test_augment_annotations(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) result = augment_annotations(trainer_export, {"00001": [["HISTORY"]], "00002": [["DISCHARGE"]]}) @@ -164,30 +242,58 @@ def test_augment_annotations(): def test_augment_annotations_case_insensitive(): - trainer_export_path = os.path.join(os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json") + trainer_export_path = os.path.join( + os.path.dirname(__file__), "..", "resources", "fixture", "trainer_export.json" + ) with open(trainer_export_path, "r") as f: trainer_export = json.load(f) - result = augment_annotations(trainer_export, { - "00001": [["HiSToRy"]], - "00002": [ - [r"^\d{1,2}\s*$", r"-", r"^\s*\d{1,2}\s*$", r"-", r"^\s*\d{2,4}$"], - [r"^\d{1,2}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{2,4}$"], - [r"^\d{2,4}\s*$", r"-", r"^\s*\d{1,2}\s*$", r"-", r"^\s*\d{1,2}$"], - [r"^\d{2,4}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{1,2}$"], - [r"^\d{1,2}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{2,4}$"], - [r"^\d{2,4}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{1,2}$"], - [r"^\d{1,2}\s*$", r"-", r"^\s*\d{4}$"], - [r"^\d{1,2}\s*[\/]\s*\d{4}$"], - [r"^\d{4}\s*$", r"-", r"^\s*\d{1,2}$"], - [r"^\d{4}\s*[\/]\s*\d{1,2}$"], - [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$"], - [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)(\s+\d{1,2})*$", r",", r"^\d{4}$"], - [r"^\d{4}\s*[-.\/]\s*(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$"], - [r"^\d{4}$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$"], - 
[r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$", r"^\d{4}$"], - [r"^(?:19\d\d|20\d\d)$"], - ] - }, case_sensitive=False) + result = augment_annotations( + trainer_export, + { + "00001": [["HiSToRy"]], + "00002": [ + [r"^\d{1,2}\s*$", r"-", r"^\s*\d{1,2}\s*$", r"-", r"^\s*\d{2,4}$"], + [r"^\d{1,2}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{2,4}$"], + [r"^\d{2,4}\s*$", r"-", r"^\s*\d{1,2}\s*$", r"-", r"^\s*\d{1,2}$"], + [r"^\d{2,4}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{1,2}$"], + [ + r"^\d{1,2}$", + r"^[-.\/]$", + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{2,4}$", + ], + [ + r"^\d{2,4}$", + r"^[-.\/]$", + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{1,2}$", + ], + [r"^\d{1,2}\s*$", r"-", r"^\s*\d{4}$"], + [r"^\d{1,2}\s*[\/]\s*\d{4}$"], + [r"^\d{4}\s*$", r"-", r"^\s*\d{1,2}$"], + [r"^\d{4}\s*[\/]\s*\d{1,2}$"], + [ + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$" + ], + [ + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)(\s+\d{1,2})*$", + r",", + r"^\d{4}$", + ], + [ + r"^\d{4}\s*[-.\/]\s*(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$" + ], + [ + r"^\d{4}$", + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$", + ], + [ + r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|June|July|August|September|October|November|December)$", + r"^\d{4}$", + ], + [r"^(?:19\d\d|20\d\d)$"], + ], + }, + case_sensitive=False, + ) match_count_00001 = 0 match_count_00002 = 0 diff --git a/tests/app/trainers/test_hf_transformer_trainer.py b/tests/app/trainers/test_hf_transformer_trainer.py index 39547e7..10e88fc 100644 --- a/tests/app/trainers/test_hf_transformer_trainer.py +++ b/tests/app/trainers/test_hf_transformer_trainer.py @@ -1,16 +1,22 @@ import os -from unittest.mock import create_autospec, patch, Mock +from unittest.mock import Mock, create_autospec, patch + from config import Settings -from model_services.huggingface_ner_model import HuggingFaceNerModel -from trainers.huggingface_ner_trainer import HuggingFaceNerUnsupervisedTrainer, HuggingFaceNerSupervisedTrainer +from model_services.huggingface_ner_model import HuggingFaceNerModel +from trainers.huggingface_ner_trainer import ( + HuggingFaceNerSupervisedTrainer, + HuggingFaceNerUnsupervisedTrainer, +) model_parent_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture") -model_service = create_autospec(HuggingFaceNerModel, - _config=Settings(), - _model_parent_dir=model_parent_dir, - _enable_trainer=True, - _model_pack_path=os.path.join(model_parent_dir, "model.zip")) +model_service = create_autospec( + HuggingFaceNerModel, + _config=Settings(), + _model_parent_dir=model_parent_dir, + _enable_trainer=True, + _model_pack_path=os.path.join(model_parent_dir, "model.zip"), +) unsupervised_trainer = HuggingFaceNerUnsupervisedTrainer(model_service) unsupervised_trainer.model_name = "unsupervised_trainer" supervised_trainer = 
HuggingFaceNerSupervisedTrainer(model_service) @@ -30,7 +36,9 @@ def test_deploy_model(): def test_huggingface_ner_unsupervised_trainer(mlflow_fixture): with patch.object(unsupervised_trainer, "run", wraps=unsupervised_trainer.run) as run: unsupervised_trainer._tracker_client = Mock() - unsupervised_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + unsupervised_trainer._tracker_client.start_tracking = Mock( + return_value=("experiment_id", "run_id") + ) with open(os.path.join(data_dir, "sample_texts.json"), "r") as f: unsupervised_trainer.train(f, 1, 1, "training_id", "input_file_name") unsupervised_trainer._tracker_client.start_tracking.assert_called_once() @@ -40,7 +48,9 @@ def test_huggingface_ner_unsupervised_trainer(mlflow_fixture): def test_huggingface_ner_supervised_trainer(mlflow_fixture): with patch.object(supervised_trainer, "run", wraps=supervised_trainer.run) as run: supervised_trainer._tracker_client = Mock() - supervised_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + supervised_trainer._tracker_client.start_tracking = Mock( + return_value=("experiment_id", "run_id") + ) with open(os.path.join(data_dir, "trainer_export.json"), "r") as f: supervised_trainer.train(f, 1, 1, "training_id", "input_file_name") supervised_trainer._tracker_client.end_with_success() @@ -50,9 +60,13 @@ def test_huggingface_ner_supervised_trainer(mlflow_fixture): def test_huggingface_ner_unsupervised_run(mlflow_fixture): with open(os.path.join(data_dir, "sample_texts.json"), "r") as data_file: - HuggingFaceNerUnsupervisedTrainer.run(unsupervised_trainer, {"nepochs": 1}, data_file, 1, "run_id") + HuggingFaceNerUnsupervisedTrainer.run( + unsupervised_trainer, {"nepochs": 1}, data_file, 1, "run_id" + ) def test_huggingface_ner_supervised_run(mlflow_fixture): with open(os.path.join(data_dir, "trainer_export.json"), "r") as data_file: - HuggingFaceNerSupervisedTrainer.run(supervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + HuggingFaceNerSupervisedTrainer.run( + supervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id" + ) diff --git a/tests/app/trainers/test_medcat_deid_trainer.py b/tests/app/trainers/test_medcat_deid_trainer.py index df89207..c82152d 100644 --- a/tests/app/trainers/test_medcat_deid_trainer.py +++ b/tests/app/trainers/test_medcat_deid_trainer.py @@ -1,18 +1,26 @@ import os +from unittest.mock import Mock, create_autospec, patch + import mlflow -from unittest.mock import create_autospec, patch, Mock -from transformers import TrainingArguments, TrainerState, TrainerControl +from transformers import TrainerControl, TrainerState, TrainingArguments + from config import Settings + from model_services.medcat_model_deid import MedCATModelDeIdentification -from trainers.medcat_deid_trainer import MedcatDeIdentificationSupervisedTrainer -from trainers.medcat_deid_trainer import MetricsCallback, LabelCountCallback +from trainers.medcat_deid_trainer import ( + LabelCountCallback, + MedcatDeIdentificationSupervisedTrainer, + MetricsCallback, +) model_parent_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources") -model_service = create_autospec(MedCATModelDeIdentification, - _config=Settings(), - _model_parent_dir=model_parent_dir, - _enable_trainer=True, - _model_pack_path=os.path.join(model_parent_dir, "model.zip")) +model_service = create_autospec( + MedCATModelDeIdentification, + _config=Settings(), + _model_parent_dir=model_parent_dir, + 
_enable_trainer=True, + _model_pack_path=os.path.join(model_parent_dir, "model.zip"), +) deid_trainer = MedcatDeIdentificationSupervisedTrainer(model_service) deid_trainer.model_name = "deid_trainer" data_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture") @@ -30,7 +38,9 @@ def test_medcat_deid_supervised_trainer(mlflow_fixture): def test_medcat_deid_supervised_run(mlflow_fixture): with open(os.path.join(data_dir, "trainer_export.json"), "r") as data_file: - MedcatDeIdentificationSupervisedTrainer.run(deid_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + MedcatDeIdentificationSupervisedTrainer.run( + deid_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id" + ) def test_trainer_callbacks(mlflow_fixture): diff --git a/tests/app/trainers/test_medcat_trainer.py b/tests/app/trainers/test_medcat_trainer.py index 0ed302b..988f072 100644 --- a/tests/app/trainers/test_medcat_trainer.py +++ b/tests/app/trainers/test_medcat_trainer.py @@ -1,17 +1,21 @@ import os -from unittest.mock import create_autospec, patch, Mock +from unittest.mock import Mock, create_autospec, patch + from medcat.config import General + from config import Settings + from model_services.medcat_model import MedCATModel from trainers.medcat_trainer import MedcatSupervisedTrainer, MedcatUnsupervisedTrainer - model_parent_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture") -model_service = create_autospec(MedCATModel, - _config=Settings(), - _model_parent_dir=model_parent_dir, - _enable_trainer=True, - _model_pack_path=os.path.join(model_parent_dir, "model.zip")) +model_service = create_autospec( + MedCATModel, + _config=Settings(), + _model_parent_dir=model_parent_dir, + _enable_trainer=True, + _model_pack_path=os.path.join(model_parent_dir, "model.zip"), +) supervised_trainer = MedcatSupervisedTrainer(model_service) supervised_trainer.model_name = "supervised_trainer" unsupervised_trainer = MedcatUnsupervisedTrainer(model_service) @@ -47,7 +51,9 @@ def test_save_model_pack(): def test_medcat_supervised_trainer(mlflow_fixture): with patch.object(supervised_trainer, "run", wraps=supervised_trainer.run) as run: supervised_trainer._tracker_client = Mock() - supervised_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + supervised_trainer._tracker_client.start_tracking = Mock( + return_value=("experiment_id", "run_id") + ) with open(os.path.join(data_dir, "trainer_export.json"), "r") as f: supervised_trainer.train(f, 1, 1, "training_id", "input_file_name") supervised_trainer._tracker_client.end_with_success() @@ -58,7 +64,9 @@ def test_medcat_supervised_trainer(mlflow_fixture): def test_medcat_unsupervised_trainer(mlflow_fixture): with patch.object(unsupervised_trainer, "run", wraps=unsupervised_trainer.run) as run: unsupervised_trainer._tracker_client = Mock() - unsupervised_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + unsupervised_trainer._tracker_client.start_tracking = Mock( + return_value=("experiment_id", "run_id") + ) with open(os.path.join(data_dir, "sample_texts.json"), "r") as f: unsupervised_trainer.train(f, 1, 1, "training_id", "input_file_name") unsupervised_trainer._tracker_client.start_tracking.assert_called_once() @@ -67,9 +75,13 @@ def test_medcat_unsupervised_trainer(mlflow_fixture): def test_medcat_supervised_run(mlflow_fixture): with open(os.path.join(data_dir, "trainer_export.json"), "r") as data_file: - 
MedcatSupervisedTrainer.run(supervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + MedcatSupervisedTrainer.run( + supervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id" + ) def test_medcat_unsupervised_run(mlflow_fixture): with open(os.path.join(data_dir, "sample_texts.json"), "r") as data_file: - MedcatUnsupervisedTrainer.run(unsupervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + MedcatUnsupervisedTrainer.run( + unsupervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id" + ) diff --git a/tests/app/trainers/test_metacat_trainer.py b/tests/app/trainers/test_metacat_trainer.py index fa645b6..6767000 100644 --- a/tests/app/trainers/test_metacat_trainer.py +++ b/tests/app/trainers/test_metacat_trainer.py @@ -1,16 +1,21 @@ import os -from unittest.mock import create_autospec, patch, Mock +from unittest.mock import Mock, create_autospec, patch + from medcat.config_meta_cat import General, Model, Train + from config import Settings + from model_services.medcat_model import MedCATModel from trainers.metacat_trainer import MetacatTrainer model_parent_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture") -model_service = create_autospec(MedCATModel, - _config=Settings(), - _model_parent_dir=model_parent_dir, - _enable_trainer=True, - _model_pack_path=os.path.join(model_parent_dir, "model.zip")) +model_service = create_autospec( + MedCATModel, + _config=Settings(), + _model_parent_dir=model_parent_dir, + _enable_trainer=True, + _model_pack_path=os.path.join(model_parent_dir, "model.zip"), +) metacat_trainer = MetacatTrainer(model_service) metacat_trainer.model_name = "metacat_trainer" @@ -45,7 +50,9 @@ def test_save_model_pack(): def test_metacat_trainer(mlflow_fixture): with patch.object(metacat_trainer, "run", wraps=metacat_trainer.run) as run: metacat_trainer._tracker_client = Mock() - metacat_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + metacat_trainer._tracker_client.start_tracking = Mock( + return_value=("experiment_id", "run_id") + ) with open(os.path.join(data_dir, "trainer_export.json"), "r") as f: metacat_trainer.train(f, 1, 1, "training_id", "input_file_name") metacat_trainer._tracker_client.start_tracking.assert_called_once() @@ -54,4 +61,6 @@ def test_metacat_trainer(mlflow_fixture): def test_metacat_supervised_run(mlflow_fixture): with open(os.path.join(data_dir, "trainer_export.json"), "r") as data_file: - MetacatTrainer.run(metacat_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + MetacatTrainer.run( + metacat_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id" + ) diff --git a/tests/integration/features/serving.feature b/tests/integration/features/serving.feature index 6b40de2..8649ab5 100644 --- a/tests/integration/features/serving.feature +++ b/tests/integration/features/serving.feature @@ -59,4 +59,4 @@ Feature: When I send a POST request with the following content | endpoint | data | content_type | | /preview | Spinal stenosis | text/plain | - Then the response should contain a preview page \ No newline at end of file + Then the response should contain a preview page diff --git a/tests/integration/features/serving_stream.feature b/tests/integration/features/serving_stream.feature index a67eac7..fbc6dda 100644 --- a/tests/integration/features/serving_stream.feature +++ b/tests/integration/features/serving_stream.feature @@ -11,4 +11,4 @@ Scenario: Stream entities extracted from free 
texts Scenario: Interactively extract entities from free texts Given CMS stream app is up and running When I send a piece of text to the WS endpoint - Then the response should contain annotated spans \ No newline at end of file + Then the response should contain annotated spans diff --git a/tests/integration/helper.py b/tests/integration/helper.py index cbb5015..4d09101 100644 --- a/tests/integration/helper.py +++ b/tests/integration/helper.py @@ -1,27 +1,21 @@ import asyncio from functools import partial, wraps + from pytest_bdd import parsers def parse_data_table(text, orient="dict"): parsed_text = [ - [x.strip() for x in line.split("|")] - for line in [x.strip("|") for x in text.splitlines()] + [x.strip() for x in line.split("|")] for line in [x.strip("|") for x in text.splitlines()] ] header, *data = parsed_text if orient == "dict": - return [ - dict(zip(header, line)) - for line in data - ] + return [dict(zip(header, line)) for line in data] else: if orient == "columns": - data = [ - [line[i] for line in data] - for i in range(len(header)) - ] + data = [[line[i] for line in data] for i in range(len(header))] return header, data diff --git a/tests/integration/test_steps.py b/tests/integration/test_steps.py index 0d1b176..6a54066 100644 --- a/tests/integration/test_steps.py +++ b/tests/integration/test_steps.py @@ -1,15 +1,18 @@ import json -import httpx -import api.globals as cms_globals -from pytest_bdd import scenarios, given, when, then, parsers from unittest.mock import create_autospec + +import httpx from fastapi.testclient import TestClient -from management.model_manager import ModelManager -from model_services.medcat_model import MedCATModel +from pytest_bdd import given, parsers, scenarios, then, when + from domain import ModelCard, ModelType -from api.api import get_model_server, get_stream_server +from helper import async_to_sync, data_table from utils import get_settings -from helper import data_table, async_to_sync + +import api.globals as cms_globals +from api.api import get_model_server, get_stream_server +from management.model_manager import ModelManager +from model_services.medcat_model import MedCATModel scenarios("features/serving.feature") scenarios("features/serving_stream.feature") @@ -27,11 +30,7 @@ def cms_is_running(): "end": 15, "accuracy": 1.0, "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} }, } model_service.annotate.return_value = [single_annotation] @@ -40,12 +39,14 @@ def cms_is_running(): [single_annotation], ] model_service.batch_annotate.return_value = annotations_list - model_card = ModelCard.parse_obj({ - "api_version": "0.0.1", - "model_description": "medcat_model_description", - "model_type": ModelType.MEDCAT_SNOMED, - "model_card": None, - }) + model_card = ModelCard.parse_obj( + { + "api_version": "0.0.1", + "model_description": "medcat_model_description", + "model_type": ModelType.MEDCAT_SNOMED, + "model_card": None, + } + ) model_service.info.return_value = model_card model_manager = ModelManager(None, None) model_manager.model_service = model_service @@ -72,11 +73,7 @@ def cms_stream_is_running(): "end": 15, "accuracy": 1.0, "meta_anns": { - "Status": { - "value": "Affirmed", - "confidence": 0.9999833106994629, - "name": "Status" - } + "Status": {"value": "Affirmed", "confidence": 0.9999833106994629, "name": "Status"} }, } model_service.async_annotate.return_value = [single_annotation] @@ -109,9 
+106,15 @@ def check_status_code(context, body, status_code): assert context["response"].status_code == status_code -@when(data_table("I send a POST request with the following content", fixture="request", orient="dict")) +@when( + data_table("I send a POST request with the following content", fixture="request", orient="dict") +) def send_post_request(context, request): - context["response"] = context["client"].post(request[0]["endpoint"], data=request[0]["data"].replace("\\n", "\n"), headers={"Content-Type": request[0]["content_type"]}) + context["response"] = context["client"].post( + request[0]["endpoint"], + data=request[0]["data"].replace("\\n", "\n"), + headers={"Content-Type": request[0]["content_type"]}, + ) @then("the response should contain json lines") @@ -125,14 +128,8 @@ def check_response_jsonl(context): @then("the response should contain bulk annotations") def check_response_bulk(context): assert context["response"].json() == [ - { - "text": "Spinal stenosis", - "annotations": [context["single_annotation"]] - }, - { - "text": "Spinal stenosis", - "annotations": [context["single_annotation"]] - }, + {"text": "Spinal stenosis", "annotations": [context["single_annotation"]]}, + {"text": "Spinal stenosis", "annotations": [context["single_annotation"]]}, ] @@ -147,13 +144,19 @@ def check_response_previewed(context): assert context["response"].headers["Content-Type"] == "application/octet-stream" -@when(data_table("I send an async POST request with the following content", fixture="request", orient="dict")) +@when( + data_table( + "I send an async POST request with the following content", fixture="request", orient="dict" + ) +) @async_to_sync async def send_async_post_request(context_stream, request): async with httpx.AsyncClient(app=context_stream["app"], base_url="http://test") as ac: - context_stream["response"] = await ac.post(request[0]["endpoint"], - data=request[0]["data"].replace("\\n", "\n").encode("utf-8"), - headers={"Content-Type": request[0]["content_type"]}) + context_stream["response"] = await ac.post( + request[0]["endpoint"], + data=request[0]["data"].replace("\\n", "\n").encode("utf-8"), + headers={"Content-Type": request[0]["content_type"]}, + ) @then("the response should contain annotation stream") @@ -163,7 +166,10 @@ async def check_response_stream(context_stream): jsonlines = b"" async for chunk in context_stream["response"].aiter_bytes(): jsonlines += chunk - assert json.loads(jsonlines.decode("utf-8").splitlines()[-1]) == {"doc_name": "doc2", **context_stream["single_annotation"]} + assert json.loads(jsonlines.decode("utf-8").splitlines()[-1]) == { + "doc_name": "doc2", + **context_stream["single_annotation"], + } @when("I send a piece of text to the WS endpoint") diff --git a/tests/load/data/sample_texts.json b/tests/load/data/sample_texts.json index 9dd48af..bf2a34f 100644 --- a/tests/load/data/sample_texts.json +++ b/tests/load/data/sample_texts.json @@ -9,4 +9,4 @@ "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. 
He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. 
He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia." -] \ No newline at end of file +] diff --git a/tests/load/data/trainer_export.json b/tests/load/data/trainer_export.json index b3933b5..e5329a5 100644 --- a/tests/load/data/trainer_export.json +++ b/tests/load/data/trainer_export.json @@ -745,4 +745,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/load/docker-compose.yml b/tests/load/docker-compose.yml index 9198471..a7b2ca9 100644 --- a/tests/load/docker-compose.yml +++ b/tests/load/docker-compose.yml @@ -44,4 +44,4 @@ services: networks: cogstack-model-serve_cms: - external: true \ No newline at end of file + external: true diff --git a/tests/load/main/locustfile.py b/tests/load/main/locustfile.py index 9f43617..0133dfc 100644 --- a/tests/load/main/locustfile.py +++ b/tests/load/main/locustfile.py @@ -1,3 +1,3 @@ -from process import Process # noqa -from process_bulk import ProcessBulk # noqa -from other import Other # noqa \ No newline at end of file +from process import Process # noqa +from process_bulk import ProcessBulk # noqa +from other import Other # noqa diff --git a/tests/load/main/other.py b/tests/load/main/other.py index b9f14cd..39726f9 100644 --- a/tests/load/main/other.py +++ b/tests/load/main/other.py @@ -1,15 +1,14 @@ import os -from locust import HttpUser, task, constant_throughput + +from locust import HttpUser, constant_throughput, task CMS_BASE_URL = os.environ["CMS_BASE_URL"] class Other(HttpUser): - wait_time = constant_throughput(1) - def on_start(self): - ... + def on_start(self): ... @task def info(self): @@ -17,10 +16,20 @@ def info(self): @task def train_unsupervised(self): - with open(os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r") as file: - self.client.post(f"{CMS_BASE_URL}/train_unsupervised?log_frequency=1000", files={"training_data": file}) + with open( + os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r" + ) as file: + self.client.post( + f"{CMS_BASE_URL}/train_unsupervised?log_frequency=1000", + files={"training_data": file}, + ) @task def train_supervised(self): - with open(os.path.join(os.path.dirname(__file__), "..", "data", "trainer_export.json"), "r") as file: - self.client.post(f"{CMS_BASE_URL}/train_supervised?epochs=1&log_frequency=1", files={"trainer_export": file}) + with open( + os.path.join(os.path.dirname(__file__), "..", "data", "trainer_export.json"), "r" + ) as file: + self.client.post( + f"{CMS_BASE_URL}/train_supervised?epochs=1&log_frequency=1", + files={"trainer_export": file}, + ) diff --git a/tests/load/main/process.py b/tests/load/main/process.py index 3891144..b6ac27f 100644 --- a/tests/load/main/process.py +++ b/tests/load/main/process.py @@ -1,23 +1,25 @@ import os + import ijson -from locust import HttpUser, task, constant_throughput +from locust import HttpUser, constant_throughput, task CMS_BASE_URL = os.environ["CMS_BASE_URL"] class Process(HttpUser): - wait_time = constant_throughput(1) - def on_start(self): - ... + def on_start(self): ... - def on_stop(self): - ... + def on_stop(self): ... 
@task def process(self): - with open(os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r") as file: + with open( + os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r" + ) as file: texts = ijson.items(file, "item") for text in texts: - self.client.post(f"{CMS_BASE_URL}/process", headers={"Content-Type": "text/plain"}, data=text) + self.client.post( + f"{CMS_BASE_URL}/process", headers={"Content-Type": "text/plain"}, data=text + ) diff --git a/tests/load/main/process_bulk.py b/tests/load/main/process_bulk.py index 6d08345..8df5edf 100644 --- a/tests/load/main/process_bulk.py +++ b/tests/load/main/process_bulk.py @@ -1,35 +1,42 @@ -import os import json +import os + import ijson -from locust import HttpUser, task, constant_throughput +from locust import HttpUser, constant_throughput, task CMS_BASE_URL = os.environ["CMS_BASE_URL"] class ProcessBulk(HttpUser): - num_of_doc_per_call = 10 - wait_time = constant_throughput(num_of_doc_per_call*1.5) + wait_time = constant_throughput(num_of_doc_per_call * 1.5) - def on_start(self): - ... + def on_start(self): ... - def on_stop(self): - ... + def on_stop(self): ... @task def process_bulk(self): - - with open(os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r") as file: + with open( + os.path.join(os.path.dirname(__file__), "..", "data", "sample_texts.json"), "r" + ) as file: batch = [] texts = ijson.items(file, "item") for text in texts: if len(batch) < ProcessBulk.num_of_doc_per_call: batch.append(text) else: - self.client.post(f"{CMS_BASE_URL}/process_bulk", headers={"Content-Type": "application/json"}, data=json.dumps(batch)) + self.client.post( + f"{CMS_BASE_URL}/process_bulk", + headers={"Content-Type": "application/json"}, + data=json.dumps(batch), + ) batch.clear() batch.append(text) if batch: - self.client.post(f"{CMS_BASE_URL}/process_bulk", headers={"Content-Type": "application/json"}, data=json.dumps(batch)) + self.client.post( + f"{CMS_BASE_URL}/process_bulk", + headers={"Content-Type": "application/json"}, + data=json.dumps(batch), + ) batch.clear() diff --git a/tests/load/requirements.txt b/tests/load/requirements.txt index d2e32f4..4852a48 100644 --- a/tests/load/requirements.txt +++ b/tests/load/requirements.txt @@ -1 +1 @@ -ijson~=3.1.4 \ No newline at end of file +ijson~=3.1.4 diff --git a/tests/resources/fixture/another_trainer_export.json b/tests/resources/fixture/another_trainer_export.json index 8e92289..20b8981 100644 --- a/tests/resources/fixture/another_trainer_export.json +++ b/tests/resources/fixture/another_trainer_export.json @@ -745,4 +745,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/resources/fixture/medcat_entities.json b/tests/resources/fixture/medcat_entities.json index 977e6c4..bad0f39 100644 --- a/tests/resources/fixture/medcat_entities.json +++ b/tests/resources/fixture/medcat_entities.json @@ -893,4 +893,4 @@ } }, "tokens": [] -} \ No newline at end of file +} diff --git a/tests/resources/fixture/public_key.pem b/tests/resources/fixture/public_key.pem new file mode 100644 index 0000000..e7ab0d6 --- /dev/null +++ b/tests/resources/fixture/public_key.pem @@ -0,0 +1,3 @@ +-----BEGIN PUBLIC KEY----- 
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3ITkTP8Tm/5FygcwY2EQ7LgVsuCF0OH7psUqvlXnOPNCfX86CobHBiSFjG9o5ZeajPtTXaf1thUodgpJZVZSqpVTXwGKo8r0COMO87IcwYigkZZgG/WmZgoZART+AA0+JvjFGxflJAxSv7puGlf82E+u5Wz2psLBSDO5qrnmaDZTvPh5eX84cocahVVI7X09/kI+sZiKauM69yoy1bdx16YIIeNm0M9qqS3tTrjouQiJfZ8jUKSZ44Na/81LMVw5O46+5GvwD+OsR43kQ0TexMwgtHxQQsiXLWHCDNy2ZzkzukDYRwA3V2lwVjtQN0WjxHg24BTBDBM+v7iQ7cbweQIDAQAB +-----END PUBLIC KEY----- diff --git a/tests/resources/fixture/sample_text.txt b/tests/resources/fixture/sample_text.txt index 80cf4af..5ec4d03 100644 --- a/tests/resources/fixture/sample_text.txt +++ b/tests/resources/fixture/sample_text.txt @@ -7,4 +7,4 @@ He had been experiencing falling spells without associated LOC up to several tim MEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin. -PMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia. \ No newline at end of file +PMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia. diff --git a/tests/resources/fixture/sample_texts.json b/tests/resources/fixture/sample_texts.json index 0430ed1..865b33c 100644 --- a/tests/resources/fixture/sample_texts.json +++ b/tests/resources/fixture/sample_texts.json @@ -14,4 +14,4 @@ "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. 
MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.", "Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).\nCC: Left hand numbness on presentation; then developed lethargy later that day.\nHX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.\nHe had been experiencing falling spells without associated LOC up to several times a month for the past year.\nMEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.\nPMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan)revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia." 
-] \ No newline at end of file +] diff --git a/tests/resources/fixture/trainer_export.json b/tests/resources/fixture/trainer_export.json index bf3fa2d..bcea1fb 100644 --- a/tests/resources/fixture/trainer_export.json +++ b/tests/resources/fixture/trainer_export.json @@ -745,4 +745,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tests/resources/fixture/trainer_export_multi_projs.json b/tests/resources/fixture/trainer_export_multi_projs.json index e5eae7e..ff1adb9 100644 --- a/tests/resources/fixture/trainer_export_multi_projs.json +++ b/tests/resources/fixture/trainer_export_multi_projs.json @@ -1487,4 +1487,4 @@ ] } ] -} \ No newline at end of file +}