Skip to content

Commit

Permalink
feat: Allow passing a tracking ID for API requests with side-effects (#2
Browse files Browse the repository at this point in the history
)

Extend API to accept a tracking ID as an optional query parameter,
allowing upstream systems to track training requests. Validate
received IDs to ensure they're alphanumeric strings of length 1-256,
following MLflow's internal run ID validation model. Extend serving
tests to check that the ID (if provided) is included in the API's response.

Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
  • Loading branch information
phoevos authored Dec 18, 2024
1 parent 6d89586 commit 5d94c50
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 18 deletions.
19 changes: 19 additions & 0 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import logging
import re
from typing import Union
from typing_extensions import Annotated

from fastapi import HTTPException, Query
from starlette.status import HTTP_400_BAD_REQUEST

from typing import Optional
from config import Settings
from registry import model_service_registry
from model_services.base import AbstractModelService
from management.model_manager import ModelManager

TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")

logger = logging.getLogger("cms")


Expand Down Expand Up @@ -45,3 +53,14 @@ def __init__(self, model_service: AbstractModelService) -> None:

def __call__(self) -> ModelManager:
return self._model_manager


def validate_tracking_id(
tracking_id: Annotated[Union[str, None], Query(description="The tracking ID of the requested task")] = None,
) -> Union[str, None]:
if tracking_id is not None and TRACKING_ID_REGEX.match(tracking_id) is None:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail=f"Invalid tracking ID '{tracking_id}', must be an alphanumeric string of length 1 to 256",
)
return tracking_id
28 changes: 19 additions & 9 deletions app/api/routers/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import uuid
import tempfile

from typing import List
from typing import List, Union
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
from typing_extensions import Annotated
from fastapi import APIRouter, Query, Depends, UploadFile, Request, File
from fastapi.responses import StreamingResponse, JSONResponse

import api.globals as cms_globals
from api.dependencies import validate_tracking_id
from domain import Tags, Scope
from model_services.base import AbstractModelService
from processors.metrics_collector import (
Expand All @@ -34,6 +35,7 @@
description="Evaluate the model being served with a trainer export")
async def get_evaluation_with_trainer_export(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
files = []
file_names = []
Expand All @@ -54,7 +56,7 @@ async def get_evaluation_with_trainer_export(request: Request,
json.dump(concatenated, data_file)
data_file.flush()
data_file.seek(0)
evaluation_id = str(uuid.uuid4())
evaluation_id = tracking_id or str(uuid.uuid4())
evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names))
if evaluation_accepted:
return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED)
Expand All @@ -69,6 +71,7 @@ async def get_evaluation_with_trainer_export(request: Request,
description="Sanity check the model being served with a trainer export")
def get_sanity_check_with_trainer_export(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
files = []
file_names = []
Expand All @@ -88,8 +91,9 @@ def get_sanity_check_with_trainer_export(request: Request,
metrics = sanity_check_model_with_trainer_export(concatenated, model_service, return_df=True, include_anchors=False)
stream = io.StringIO()
metrics.to_csv(stream, index=False)
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{str(uuid.uuid4())}.csv"'
response.headers["Content-Disposition"] = f'attachment ; filename="sanity_check_{tracking_id}.csv"'
return response


Expand All @@ -102,7 +106,8 @@ def get_inter_annotator_agreement_scores(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")],
annotator_a_project_id: Annotated[int, Query(description="The project ID from one annotator")],
annotator_b_project_id: Annotated[int, Query(description="The project ID from another annotator")],
scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")]) -> StreamingResponse:
scope: Annotated[str, Query(enum=[s.value for s in Scope], description="The scope for which the score will be calculated, e.g., per_concept, per_document or per_span")],
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
files = []
for te in trainer_export:
temp_te = tempfile.NamedTemporaryFile()
Expand All @@ -126,8 +131,9 @@ def get_inter_annotator_agreement_scores(request: Request,
raise AnnotationException(f'Unknown scope: "{scope}"')
stream = io.StringIO()
iaa_scores.to_csv(stream, index=False)
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{str(uuid.uuid4())}.csv"'
response.headers["Content-Disposition"] = f'attachment ; filename="iaa_{tracking_id}.csv"'
return response


Expand All @@ -137,7 +143,8 @@ def get_inter_annotator_agreement_scores(request: Request,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Concatenate multiple trainer export files into a single file for download")
def get_concatenated_trainer_exports(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")]) -> JSONResponse:
trainer_export: Annotated[List[UploadFile], File(description="A list of trainer export files to be uploaded")],
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> JSONResponse:
files = []
for te in trainer_export:
temp_te = tempfile.NamedTemporaryFile()
Expand All @@ -148,8 +155,9 @@ def get_concatenated_trainer_exports(request: Request,
concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
for file in files:
file.close()
tracking_id = tracking_id or str(uuid.uuid4())
response = JSONResponse(concatenated, media_type="application/json; charset=utf-8")
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"'
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"'
return response


Expand All @@ -159,7 +167,8 @@ def get_concatenated_trainer_exports(request: Request,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Get annotation stats of trainer export files")
def get_annotation_stats(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")]) -> StreamingResponse:
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")],
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> StreamingResponse:
files = []
file_names = []
for te in trainer_export:
Expand All @@ -177,6 +186,7 @@ def get_annotation_stats(request: Request,
stats = get_stats_from_trainer_export(concatenated, return_df=True)
stream = io.StringIO()
stats.to_csv(stream, index=False)
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(iter([stream.getvalue()]), media_type="text/csv")
response.headers["Content-Disposition"] = f'attachment ; filename="stats_{str(uuid.uuid4())}.csv"'
response.headers["Content-Disposition"] = f'attachment ; filename="stats_{tracking_id}.csv"'
return response
5 changes: 4 additions & 1 deletion app/api/routers/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from domain import TextWithAnnotations, TextWithPublicKey, TextStreamItem, ModelCard, Tags
from model_services.base import AbstractModelService
from utils import get_settings
from api.dependencies import validate_tracking_id
from api.utils import get_rate_limiter, encrypt
from management.prometheus_metrics import (
cms_doc_annotations,
Expand Down Expand Up @@ -132,6 +133,7 @@ def get_entities_from_multiple_texts(request: Request,
description="Upload a file containing a list of plain text and extract the NER entities in JSON")
def extract_entities_from_multi_text_file(request: Request,
multi_text_file: Annotated[UploadFile, File(description="A file containing a list of plain texts, in the format of [\"text_1\", \"text_2\", ..., \"text_n\"]")],
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
with tempfile.NamedTemporaryFile() as data_file:
for line in multi_text_file.file:
Expand Down Expand Up @@ -160,8 +162,9 @@ def extract_entities_from_multi_text_file(request: Request,
output = json.dumps(body)
logger.debug(output)
json_file = BytesIO(output.encode())
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(json_file, media_type="application/json")
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{str(uuid.uuid4())}.json"'
response.headers["Content-Disposition"] = f'attachment ; filename="concatenated_{tracking_id}.json"'
return response


Expand Down
4 changes: 3 additions & 1 deletion app/api/routers/metacat_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE

import api.globals as cms_globals
from api.dependencies import validate_tracking_id
from domain import Tags
from model_services.base import AbstractModelService
from processors.metrics_collector import concat_trainer_exports
Expand All @@ -29,6 +30,7 @@ async def train_metacat(request: Request,
epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
description: Annotated[Union[str, None], Query(description="The description on the training or change logs")] = None,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
files = []
file_names = []
Expand All @@ -49,7 +51,7 @@ async def train_metacat(request: Request,
json.dump(concatenated, data_file)
data_file.flush()
data_file.seek(0)
training_id = str(uuid.uuid4())
training_id = tracking_id or str(uuid.uuid4())
try:
training_accepted = model_service.train_metacat(data_file,
epochs,
Expand Down
11 changes: 8 additions & 3 deletions app/api/routers/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from starlette.status import HTTP_404_NOT_FOUND

import api.globals as cms_globals
from api.dependencies import validate_tracking_id
from domain import Doc, Tags
from model_services.base import AbstractModelService
from processors.metrics_collector import concat_trainer_exports
Expand All @@ -27,14 +28,16 @@
description="Extract the NER entities in HTML for preview")
async def get_rendered_entities_from_text(request: Request,
text: Annotated[str, Body(description="The text to be sent to the model for NER", media_type="text/plain")],
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> StreamingResponse:
annotations = model_service.annotate(text)
entities = annotations_to_entities(annotations, model_service.model_name)
logger.debug("Entities extracted for previewing %s", entities)
ent_input = Doc(text=text, ents=entities)
data = displacy.render(ent_input.dict(), style="ent", manual=True)
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(BytesIO(data.encode()), media_type="application/octet-stream")
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"'
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"'
return response


Expand All @@ -47,7 +50,8 @@ def get_rendered_entities_from_trainer_export(request: Request,
trainer_export: Annotated[List[UploadFile], File(description="One or more trainer export files to be uploaded")] = [],
trainer_export_str: Annotated[str, Form(description="The trainer export raw JSON string")] = "{\"projects\": []}",
project_id: Annotated[Union[int, None], Query(description="The target project ID, and if not provided, all projects will be included")] = None,
document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None) -> Response:
document_id: Annotated[Union[int, None], Query(description="The target document ID, and if not provided, all documents of the target project(s) will be included")] = None,
tracking_id: Union[str, None] = Depends(validate_tracking_id)) -> Response:
data: Dict = {"projects": []}
if trainer_export is not None:
files = []
Expand Down Expand Up @@ -88,8 +92,9 @@ def get_rendered_entities_from_trainer_export(request: Request,
doc = Doc(text=document["text"], ents=entities, title=f"P{project['id']}/D{document['id']}")
htmls.append(displacy.render(doc.dict(), style="ent", manual=True))
if htmls:
tracking_id = tracking_id or str(uuid.uuid4())
response = StreamingResponse(BytesIO("<br/>".join(htmls).encode()), media_type="application/octet-stream")
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{str(uuid.uuid4())}.html"'
response.headers["Content-Disposition"] = f'attachment ; filename="preview_{tracking_id}.html"'
else:
logger.debug("Cannot find any matching documents to preview")
return JSONResponse(content={"message": "Cannot find any matching documents to preview"}, status_code=HTTP_404_NOT_FOUND)
Expand Down
4 changes: 3 additions & 1 deletion app/api/routers/supervised_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE

import api.globals as cms_globals
from api.dependencies import validate_tracking_id
from domain import Tags
from model_services.base import AbstractModelService
from processors.metrics_collector import concat_trainer_exports
Expand All @@ -32,6 +33,7 @@ async def train_supervised(request: Request,
test_size: Annotated[Union[float, None], Query(description="The override of the test size in percentage. (For a 'huggingface-ner' model, a negative value can be used to apply the train-validation-test split if implicitly defined in trainer export: 'projects[0]' is used for training, 'projects[1]' for validation, and 'projects[2]' for testing)")] = 0.2,
log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
files = []
file_names = []
Expand All @@ -51,7 +53,7 @@ async def train_supervised(request: Request,
json.dump(concatenated, data_file)
data_file.flush()
data_file.seek(0)
training_id = str(uuid.uuid4())
training_id = tracking_id or str(uuid.uuid4())
try:
training_accepted = model_service.train_supervised(data_file,
epochs,
Expand Down
Loading

0 comments on commit 5d94c50

Please sign in to comment.