diff --git a/app/api/routers/preview.py b/app/api/routers/preview.py
index 1ca298c..6ed1ffd 100644
--- a/app/api/routers/preview.py
+++ b/app/api/routers/preview.py
@@ -74,7 +74,7 @@ def get_rendered_entities_from_trainer_export(request: Request,
                 entities.append({
                     "start": annotation["start"],
                     "end": annotation["end"],
-                    "label": f"{annotation['cui']} ({'correct' if annotation['correct'] else 'incorrect'}{'; terminated' if annotation['killed'] else ''})",
+                    "label": f"{annotation['cui']} ({'correct' if annotation.get('correct', True) else 'incorrect'}{'; terminated' if annotation.get('deleted', False) and annotation.get('killed', False) else ''})",
                     "kb_id": annotation["cui"],
                     "kb_url": "#",
                 })
diff --git a/app/api/routers/supervised_training.py b/app/api/routers/supervised_training.py
index 6cf1471..021cfa3 100644
--- a/app/api/routers/supervised_training.py
+++ b/app/api/routers/supervised_training.py
@@ -5,7 +5,7 @@
 from typing import List, Union
 
 from typing_extensions import Annotated
-from fastapi import APIRouter, Depends, UploadFile, Query, Request, File
+from fastapi import APIRouter, Depends, UploadFile, Query, Request, File, Form
 from fastapi.responses import JSONResponse
 from starlette.status import HTTP_202_ACCEPTED, HTTP_503_SERVICE_UNAVAILABLE
 
@@ -30,7 +30,7 @@ async def train_supervised(request: Request,
                            epochs: Annotated[int, Query(description="The number of training epochs", ge=0)] = 1,
                            lr_override: Annotated[Union[float, None], Query(description="The override of the initial learning rate", gt=0.0)] = None,
                            log_frequency: Annotated[int, Query(description="The number of processed documents after which training metrics will be logged", ge=1)] = 1,
-                           description: Annotated[Union[str, None], Query(description="The description of the training or change logs")] = None,
+                           description: Annotated[Union[str, None], Form(description="The description of the training or change logs")] = None,
                            model_service: AbstractModelService = Depends(cms_globals.model_service_dep)) -> JSONResponse:
     files = []
     file_names = []
diff --git a/app/data/anno_dataset.py b/app/data/anno_dataset.py
index 8b28d7e..ead456c 100644
--- a/app/data/anno_dataset.py
+++ b/app/data/anno_dataset.py
@@ -1,7 +1,7 @@
 import datasets
 import json
 from pathlib import Path
-from typing import List, Iterable, Tuple, Dict
+from typing import List, Iterable, Tuple, Dict, Optional
 
 from utils import filter_by_concept_ids
 
@@ -24,7 +24,8 @@ def _info(self) -> datasets.DatasetInfo:
             description="Annotation Dataset. This is a dataset containing flattened MedCAT Trainer export",
             features=datasets.Features(
                 {
-                    "name": datasets.Value("string"),
+                    "project": datasets.Value("string"),
+                    "name": datasets.Value("string"),
                     "text": datasets.Value("string"),
                     "starts": datasets.Value("string"),  # Mlflow ColSpec schema does not support HF Dataset Sequence
                     "ends": datasets.Value("string"),  # Mlflow ColSpec schema does not support HF Dataset Sequence
@@ -57,8 +58,9 @@ def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]:
                         ends.append(str(annotation["end"]))
                         labels.append(annotation["cui"])
                     yield str(id_), {
-                        "name": document["name"],
-                        "text": document["text"],
+                        "project": project.get("name"),
+                        "name": document.get("name"),
+                        "text": document.get("text"),
                         "starts": ",".join(starts),
                         "ends": ",".join(ends),
                         "labels": ",".join(labels),
diff --git a/app/data/doc_dataset.py b/app/data/doc_dataset.py
index 173c9ef..cc48842 100644
--- a/app/data/doc_dataset.py
+++ b/app/data/doc_dataset.py
@@ -44,5 +44,5 @@ def generate_examples(filepaths: List[Path]) -> Iterable[Tuple[str, Dict]]:
         with open(str(filepath), "r") as f:
             texts = ijson.items(f, "item")
             for text in texts:
-                yield str(id_), {"name": f"doc_{str(id_)}", "text": text}
+                yield str(id_), {"name": f"{str(id_)}", "text": text}
                 id_ += 1
diff --git a/app/processors/metrics_collector.py b/app/processors/metrics_collector.py
index 645b844..4e816a8 100644
--- a/app/processors/metrics_collector.py
+++ b/app/processors/metrics_collector.py
@@ -225,7 +225,7 @@ def get_iaa_scores_per_concept(export_file: Union[str, TextIO],
     per_cui_metaanno_iia_pct = {}
     per_cui_metaanno_cohens_kappa = {}
     for cui, cui_metastate_pairs in cui_metastates.items():
-        per_cui_metaanno_iia_pct[cui] = len([1 for csp in cui_metastate_pairs if csp[0] == csp[1]]) / len(cui_metastate_pairs) * 100
+        per_cui_metaanno_iia_pct[cui] = len([1 for cmp in cui_metastate_pairs if cmp[0] == cmp[1]]) / len(cui_metastate_pairs) * 100
         per_cui_metaanno_cohens_kappa[cui] = _get_cohens_kappa_coefficient(*map(list, zip(*cui_metastate_pairs)))
 
     if return_df:
@@ -286,7 +286,7 @@ def get_iaa_scores_per_doc(export_file: Union[str, TextIO],
     per_doc_metaanno_iia_pct = {}
     per_doc_metaanno_cohens_kappa = {}
     for doc_id, doc_metastate_pairs in doc_metastates.items():
-        per_doc_metaanno_iia_pct[str(doc_id)] = len([1 for dsp in doc_metastate_pairs if dsp[0] == dsp[1]]) / len(doc_metastate_pairs) * 100
+        per_doc_metaanno_iia_pct[str(doc_id)] = len([1 for dmp in doc_metastate_pairs if dmp[0] == dmp[1]]) / len(doc_metastate_pairs) * 100
         per_doc_metaanno_cohens_kappa[str(doc_id)] = _get_cohens_kappa_coefficient(*map(list, zip(*doc_metastate_pairs)))
 
     if return_df:
diff --git a/tests/app/data/test_anno_dataset.py b/tests/app/data/test_anno_dataset.py
index 0c12dbd..25a25bf 100644
--- a/tests/app/data/test_anno_dataset.py
+++ b/tests/app/data/test_anno_dataset.py
@@ -7,6 +7,7 @@ def test_load_dataset():
     trainer_export = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export_multi_projs.json")
     dataset = datasets.load_dataset(anno_dataset.__file__, data_files={"annotations": trainer_export}, split="train", cache_dir="/tmp")
     assert dataset.features.to_dict() == {
+        "project": {"dtype": "string", "_type": "Value"},
         "name": {"dtype": "string", "_type": "Value"},
         "text": {"dtype": "string", "_type": "Value"},
         "starts": {"dtype": "string", "_type": "Value"},
@@ -14,6 +15,7 @@
         "labels": {"dtype": "string", "_type": "Value"},
     }
     assert len(dataset.to_list()) == 4
+    assert dataset.to_list()[0]["project"] == "MT Samples (Clone)"
     assert dataset.to_list()[0]["name"] == "1687"
     assert dataset.to_list()[0]["starts"] == "332,255,276,272"
     assert dataset.to_list()[0]["ends"] == "355,267,282,275"
@@ -24,6 +26,7 @@ def test_generate_examples():
     example_gen = anno_dataset.generate_examples([os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture", "trainer_export.json")])
     example = next(example_gen)
     assert example[0] == "1"
+    assert "project" in example[1]
     assert "name" in example[1]
     assert "text" in example[1]
     assert "starts" in example[1]
diff --git a/tests/app/data/test_doc_dataset.py b/tests/app/data/test_doc_dataset.py
index c446e4d..391ed47 100644
--- a/tests/app/data/test_doc_dataset.py
+++ b/tests/app/data/test_doc_dataset.py
@@ -8,7 +8,7 @@ def test_load_dataset():
     dataset = datasets.load_dataset(doc_dataset.__file__, data_files={"documents": sample_texts}, split="train", cache_dir="/tmp")
     assert dataset.features.to_dict() == {"name": {"dtype": "string", "_type": "Value"}, "text": {"dtype": "string", "_type": "Value"}}
     assert len(dataset.to_list()) == 15
-    assert dataset.to_list()[0]["name"] == "doc_1"
+    assert dataset.to_list()[0]["name"] == "1"
 
 
 def test_generate_examples():
diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py
index 4dc4dd3..8843919 100644
--- a/tests/app/test_utils.py
+++ b/tests/app/test_utils.py
@@ -170,12 +170,16 @@ def test_augment_annotations_case_insensitive():
             [r"^\d{2,4}\s*[.\/]\s*\d{1,2}\s*[.\/]\s*\d{1,2}$"],
             [r"^\d{1,2}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{2,4}$"],
             [r"^\d{2,4}$", r"^[-.\/]$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{1,2}$"],
-            [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$"],
-            [r"^\d{4}$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
             [r"^\d{1,2}\s*$", r"-", r"^\s*\d{4}$"],
             [r"^\d{1,2}\s*[\/]\s*\d{4}$"],
             [r"^\d{4}\s*$", r"-", r"^\s*\d{1,2}$"],
             [r"^\d{4}\s*[\/]\s*\d{1,2}$"],
+            [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)\s*[-.\/]\s*\d{4}$"],
+            [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)(\s+\d{1,2})*$", r",", r"^\d{4}$"],
+            [r"^\d{4}\s*[-.\/]\s*(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
+            [r"^\d{4}$", r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$"],
+            [r"^(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec|January|February|March|April|May|June|July|August|September|October|November|December)$", r"^\d{4}$"],
+            [r"^(?:19\d\d|20\d\d)$"],
         ]
     }, False)