train: Return MLflow tracking information (#4)
Extend the API to return MLflow tracking information as part of a
training or evaluation response, including the experiment and run IDs,
and update the tests accordingly. If training is already in progress,
the API returns the experiment and run IDs of the current training run.
This affects the following routes:
* POST /train_supervised
* POST /train_unsupervised
* POST /train_unsupervised_with_hf_hub_dataset
* POST /train_metacat
* POST /evaluate

Signed-off-by: Phoevos Kalemkeris <phoevos.kalemkeris@ucl.ac.uk>
phoevos authored Jan 6, 2025
1 parent 72bbe83 commit ac0a8c1
Showing 11 changed files with 122 additions and 49 deletions.
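For orientation, the two response shapes that the extended routes can now produce look roughly like the Python dict literals below; the ID and UUID values are illustrative, only the keys and status codes come from this change.

# Illustrative payloads; the ID values are made up.
accepted_202 = {
    "message": "Your training started successfully.",
    "training_id": "8e9f1c2a-0000-4000-8000-000000000000",
    "experiment_id": "3",
    "run_id": "9a1b2c3d4e5f60718293a4b5c6d7e8f9",
}
busy_503 = {
    "message": "Another training or evaluation on this model is still active. Please retry later.",
    # IDs of the training or evaluation run that is already in progress:
    "experiment_id": "3",
    "run_id": "9a1b2c3d4e5f60718293a4b5c6d7e8f9",
}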
21 changes: 18 additions & 3 deletions app/api/routers/evaluation.py
@@ -57,11 +57,26 @@ async def get_evaluation_with_trainer_export(request: Request,
         data_file.flush()
         data_file.seek(0)
         evaluation_id = tracking_id or str(uuid.uuid4())
-        evaluation_accepted = model_service.train_supervised(data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names))
+        evaluation_accepted, experiment_id, run_id = model_service.train_supervised(
+            data_file, 0, sys.maxsize, evaluation_id, ",".join(file_names)
+        )
         if evaluation_accepted:
-            return JSONResponse(content={"message": "Your evaluation started successfully.", "evaluation_id": evaluation_id}, status_code=HTTP_202_ACCEPTED)
+            return JSONResponse(
+                content={
+                    "message": "Your evaluation started successfully.",
+                    "evaluation_id": evaluation_id,
+                    "experiment_id": experiment_id,
+                    "run_id": run_id,
+                }, status_code=HTTP_202_ACCEPTED
+            )
         else:
-            return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
+            return JSONResponse(
+                content={
+                    "message": "Another training or evaluation on this model is still active. Please retry later.",
+                    "experiment_id": experiment_id,
+                    "run_id": run_id,
+                }, status_code=HTTP_503_SERVICE_UNAVAILABLE
+            )


 @router.post("/sanity-check",
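A client of the evaluation endpoint can now correlate a submitted trainer export with its MLflow run. The sketch below is an assumption-heavy illustration: the /evaluate path is taken from the commit message, while the host, port and multipart field name are guesses, not read from the diff.

import requests

# Hypothetical client call; URL, port and the multipart field name are assumptions.
with open("trainer_export.json", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/evaluate",
        files={"trainer_export": f},
    )

body = resp.json()
if resp.status_code == 202:
    print("evaluation", body["evaluation_id"],
          "tracked in experiment", body["experiment_id"], "run", body["run_id"])
else:
    # 503: another training/evaluation holds the lock; the IDs point at that run.
    print("busy, active run:", body["experiment_id"], body["run_id"])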
26 changes: 20 additions & 6 deletions app/api/routers/metacat_training.py
@@ -2,7 +2,7 @@
 import uuid
 import json
 import logging
-from typing import List, Union
+from typing import List, Tuple, Union
 from typing_extensions import Annotated

 from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
@@ -53,7 +53,7 @@ async def train_metacat(request: Request,
         data_file.seek(0)
         training_id = tracking_id or str(uuid.uuid4())
         try:
-            training_accepted = model_service.train_metacat(data_file,
+            training_response = model_service.train_metacat(data_file,
                     epochs,
                     log_frequency,
                     training_id,
@@ -65,13 +65,27 @@
         for file in files:
             file.close()

-    return _get_training_response(training_accepted, training_id)
+    return _get_training_response(training_response, training_id)


-def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
+def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
+    training_accepted, experiment_id, run_id = training_response
     if training_accepted:
         logger.debug("Training accepted with ID: %s", training_id)
-        return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
+        return JSONResponse(
+            content={
+                "message": "Your training started successfully.",
+                "training_id": training_id,
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_202_ACCEPTED
+        )
     else:
         logger.debug("Training refused due to another active training or evaluation on this model")
-        return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
+        return JSONResponse(
+            content={
+                "message": "Another training or evaluation on this model is still active. Please retry your training later.",
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_503_SERVICE_UNAVAILABLE
+        )
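A minimal sketch of how the reworked helper consumes the trainer tuple. Calling it directly like this is purely illustrative (the IDs are made up and the import path is assumed), but the status codes and keys match the code above.

import json
from api.routers.metacat_training import _get_training_response  # import path assumed

ok = _get_training_response((True, "3", "0f4a9c1b"), "metacat-train-001")
busy = _get_training_response((False, "3", "0f4a9c1b"), "metacat-train-001")

assert ok.status_code == 202 and json.loads(ok.body)["run_id"] == "0f4a9c1b"
# The refusal payload carries the IDs but no training_id.
assert busy.status_code == 503 and "training_id" not in json.loads(busy.body)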
26 changes: 20 additions & 6 deletions app/api/routers/supervised_training.py
@@ -2,7 +2,7 @@
 import uuid
 import json
 import logging
-from typing import List, Union
+from typing import List, Tuple, Union
 from typing_extensions import Annotated

 from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form
@@ -55,7 +55,7 @@ async def train_supervised(request: Request,
         data_file.seek(0)
         training_id = tracking_id or str(uuid.uuid4())
         try:
-            training_accepted = model_service.train_supervised(data_file,
+            training_response = model_service.train_supervised(data_file,
                     epochs,
                     log_frequency,
                     training_id,
@@ -69,13 +69,27 @@
         for file in files:
             file.close()

-    return _get_training_response(training_accepted, training_id)
+    return _get_training_response(training_response, training_id)


-def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
+def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
+    training_accepted, experiment_id, run_id = training_response
     if training_accepted:
         logger.debug("Training accepted with ID: %s", training_id)
-        return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
+        return JSONResponse(
+            content={
+                "message": "Your training started successfully.",
+                "training_id": training_id,
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_202_ACCEPTED
+        )
     else:
         logger.debug("Training refused due to another active training or evaluation on this model")
-        return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry your training later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
+        return JSONResponse(
+            content={
+                "message": "Another training or evaluation on this model is still active. Please retry your training later.",
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_503_SERVICE_UNAVAILABLE
+        )
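The commit message mentions updating the tests accordingly; a unit-test sketch in that spirit is shown below. The import path and test framework wiring are assumptions; only the asserted keys and the status code come from the handler above.

import json
from api.routers.supervised_training import _get_training_response  # import path assumed

def test_training_response_includes_mlflow_ids():
    response = _get_training_response((True, "experiment-1", "run-abc123"), "training-xyz")
    payload = json.loads(response.body)
    assert response.status_code == 202
    assert payload["experiment_id"] == "experiment-1"
    assert payload["run_id"] == "run-abc123"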
30 changes: 22 additions & 8 deletions app/api/routers/unsupervised_training.py
@@ -5,7 +5,7 @@
 import logging
 import datasets
 import zipfile
-from typing import List, Union
+from typing import List, Tuple, Union
 from typing_extensions import Annotated

 from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
@@ -65,7 +65,7 @@ async def train_unsupervised(request: Request,
         data_file.seek(0)
         training_id = tracking_id or str(uuid.uuid4())
         try:
-            training_accepted = model_service.train_unsupervised(data_file,
+            training_response = model_service.train_unsupervised(data_file,
                     epochs,
                     log_frequency,
                     training_id,
@@ -79,7 +79,7 @@
         for file in files:
             file.close()

-    return _get_training_response(training_accepted, training_id)
+    return _get_training_response(training_response, training_id)


 @router.post("/train_unsupervised_with_hf_hub_dataset",
@@ -133,7 +133,7 @@ async def train_unsupervised_with_hf_dataset(request: Request,
     hf_dataset.save_to_disk(data_dir.name)

     training_id = tracking_id or str(uuid.uuid4())
-    training_accepted = model_service.train_unsupervised(data_dir,
+    training_response = model_service.train_unsupervised(data_dir,
             epochs,
             log_frequency,
             training_id,
@@ -143,13 +143,27 @@
             lr_override=lr_override,
             test_size=test_size,
             description=description)
-    return _get_training_response(training_accepted, training_id)
+    return _get_training_response(training_response, training_id)


-def _get_training_response(training_accepted: bool, training_id: str) -> JSONResponse:
+def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
+    training_accepted, experiment_id, run_id = training_response
     if training_accepted:
         logger.debug("Training accepted with ID: %s", training_id)
-        return JSONResponse(content={"message": "Your training started successfully.", "training_id": training_id}, status_code=HTTP_202_ACCEPTED)
+        return JSONResponse(
+            content={
+                "message": "Your training started successfully.",
+                "training_id": training_id,
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_202_ACCEPTED
+        )
     else:
         logger.debug("Training refused due to another active training or evaluation on this model")
-        return JSONResponse(content={"message": "Another training or evaluation on this model is still active. Please retry later."}, status_code=HTTP_503_SERVICE_UNAVAILABLE)
+        return JSONResponse(
+            content={
+                "message": "Another training or evaluation on this model is still active. Please retry later.",
+                "experiment_id": experiment_id,
+                "run_id": run_id,
+            }, status_code=HTTP_503_SERVICE_UNAVAILABLE
+        )
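The "training already in progress" behaviour described in the commit message can be observed end to end roughly as follows. This is a sketch only: the URL, multipart field name and file are assumptions, and the second request is simply expected to land while the first run is still active.

import requests

def submit(path="corpus.zip"):
    # Hypothetical helper; the field name "training_data" is an assumption.
    with open(path, "rb") as f:
        return requests.post("http://localhost:8000/train_unsupervised",
                             files={"training_data": f})

first, second = submit(), submit()
if first.status_code == 202 and second.status_code == 503:
    # The refusal carries the IDs of the run started by the first request.
    assert second.json()["run_id"] == first.json()["run_id"]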
6 changes: 3 additions & 3 deletions app/model_services/base.py
@@ -56,11 +56,11 @@ def batch_annotate(self, texts: List[str]) -> List[List[Dict[str, Any]]]:
     def init_model(self) -> None:
         raise NotImplementedError

-    def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
+    def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
         raise NotImplementedError

-    def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
+    def train_unsupervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
         raise NotImplementedError

-    def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> bool:
+    def train_metacat(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
         raise NotImplementedError
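A minimal sketch of a concrete service honouring the new contract. The class and its trainer attribute are hypothetical; only the Tuple[bool, str, str] return shape is taken from the abstract methods above.

from typing import Any, Dict, Tuple

class ToyModelService:
    """Hypothetical subclass illustrating the (accepted, experiment_id, run_id) contract."""

    def __init__(self, supervised_trainer) -> None:
        self._supervised_trainer = supervised_trainer

    def train_supervised(self, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[bool, str, str]:
        # Delegate to the trainer and pass its triple straight back to the caller.
        return self._supervised_trainer.train(*args, **kwargs)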
4 changes: 2 additions & 2 deletions app/model_services/huggingface_ner_model.py
@@ -156,7 +156,7 @@ def train_supervised(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._supervised_trainer is None:
             raise ConfigurationException("The supervised trainer is not enabled")
         return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -170,7 +170,7 @@ def train_unsupervised(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._unsupervised_trainer is None:
             raise ConfigurationException("The unsupervised trainer is not enabled")
         return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
6 changes: 3 additions & 3 deletions app/model_services/medcat_model.py
@@ -121,7 +121,7 @@ def train_supervised(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._supervised_trainer is None:
             raise ConfigurationException("The supervised trainer is not enabled")
         return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -135,7 +135,7 @@ def train_unsupervised(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._unsupervised_trainer is None:
             raise ConfigurationException("The unsupervised trainer is not enabled")
         return self._unsupervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
@@ -149,7 +149,7 @@ def train_metacat(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._metacat_trainer is None:
             raise ConfigurationException("The metacat trainer is not enabled")
         return self._metacat_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
4 changes: 2 additions & 2 deletions app/model_services/medcat_model_deid.py
@@ -2,7 +2,7 @@
 import inspect
 import threading
 import torch
-from typing import Dict, List, TextIO, Optional, Any, final, Callable
+from typing import Dict, List, TextIO, Tuple, Optional, Any, final, Callable
 from functools import partial
 from transformers import pipeline
 from medcat.cat import CAT
@@ -147,7 +147,7 @@ def train_supervised(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         if self._supervised_trainer is None:
             raise ConfigurationException("Trainers are not enabled")
         return self._supervised_trainer.train(data_file, epochs, log_frequency, training_id, input_file_name, raw_data_files, description, synchronised, **hyperparams)
28 changes: 19 additions & 9 deletions app/trainers/base.py
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import TextIO, Callable, Dict, Optional, Any, List, Union, final
+from typing import TextIO, Callable, Dict, Tuple, Optional, Any, List, Union, final
 from config import Settings
 from management.tracker_client import TrackerClient
 from data import doc_dataset, anno_dataset
@@ -26,6 +26,8 @@ def __init__(self, config: Settings, model_name: str) -> None:
         self._model_name = model_name
         self._training_lock = threading.Lock()
         self._training_in_progress = False
+        self._experiment_id = None
+        self._run_id = None
         self._tracker_client = TrackerClient(self._config.MLFLOW_TRACKING_URI)
         self._executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1)

@@ -37,6 +39,14 @@ def model_name(self) -> str:
     def model_name(self, model_name: str) -> None:
         self._model_name = model_name

+    @property
+    def experiment_id(self) -> str:
+        return self._experiment_id or ""
+
+    @property
+    def run_id(self) -> str:
+        return self._run_id or ""
+
     @final
     def start_training(self,
             run: Callable,
@@ -48,13 +58,13 @@ def start_training(self,
             input_file_name: str,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
-            synchronised: bool = False) -> bool:
+            synchronised: bool = False) -> Tuple[bool, str, str]:
         with self._training_lock:
             if self._training_in_progress:
-                return False
+                return False, self.experiment_id, self.run_id
             else:
                 loop = asyncio.get_event_loop()
-                experiment_id, run_id = self._tracker_client.start_tracking(
+                self._experiment_id, self._run_id = self._tracker_client.start_tracking(
                         model_name=self._model_name,
                         input_file_name=input_file_name,
                         base_model_original=self._config.BASE_MODEL_FULL_PATH,
@@ -101,15 +111,15 @@ def start_training(self,
                 else:
                     raise ValueError(f"Unknown training type: {training_type}")

-                logger.info("Starting training job: %s with experiment ID: %s", training_id, experiment_id)
+                logger.info("Starting training job: %s with experiment ID: %s", training_id, self.experiment_id)
                 self._training_in_progress = True
                 training_task = asyncio.ensure_future(loop.run_in_executor(self._executor,
-                        partial(run, self, training_params, data_file, log_frequency, run_id, description)))
+                        partial(run, self, training_params, data_file, log_frequency, self.run_id, description)))

                 if synchronised:
                     loop.run_until_complete(training_task)

-                return True
+                return True, self.experiment_id, self.run_id

     @staticmethod
     def _make_model_file_copy(model_file_path: str, run_id: str) -> str:
@@ -161,7 +171,7 @@ def train(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         training_type = TrainingType.SUPERVISED.value
         training_params = {
             "data_path": data_file.name,
@@ -204,7 +214,7 @@ def train(self,
             raw_data_files: Optional[List[TextIO]] = None,
             description: Optional[str] = None,
             synchronised: bool = False,
-            **hyperparams: Dict[str, Any]) -> bool:
+            **hyperparams: Dict[str, Any]) -> Tuple[bool, str, str]:
         training_type = TrainingType.UNSUPERVISED.value
         training_params = {
             "nepochs": epochs,
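The key behavioural points of the trainer change can be summarised with a small stand-in class: the IDs are cached on the instance, empty strings are reported before anything has been tracked, and the same triple is returned whether training starts or is refused. This is a sketch, not the real trainer; the hard-coded IDs stand in for whatever start_tracking returns.

from typing import Optional, Tuple

class _TrainerStub:
    def __init__(self) -> None:
        self._training_in_progress = False
        self._experiment_id: Optional[str] = None
        self._run_id: Optional[str] = None

    @property
    def experiment_id(self) -> str:
        return self._experiment_id or ""

    @property
    def run_id(self) -> str:
        return self._run_id or ""

    def start_training(self) -> Tuple[bool, str, str]:
        if self._training_in_progress:
            # Refused: report the IDs of the run that is already active.
            return False, self.experiment_id, self.run_id
        # Accepted: pretend start_tracking() handed back new IDs.
        self._experiment_id, self._run_id = "3", "0f4a9c1b"
        self._training_in_progress = True
        return True, self.experiment_id, self.run_id

trainer = _TrainerStub()
assert trainer.experiment_id == ""                       # nothing tracked yet
assert trainer.start_training() == (True, "3", "0f4a9c1b")
assert trainer.start_training()[0] is False              # second attempt refused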
