From 875fb7d7360bf31334189e2ad4da30076d7affd1 Mon Sep 17 00:00:00 2001 From: Alexandre Matton <69212936+alex-matton@users.noreply.github.com> Date: Tue, 22 Aug 2023 18:01:43 -0700 Subject: [PATCH] Classification finetuning v2 (#55) * adapt to new classification format * add properties * update version * address comments --- cohere_sagemaker/classification.py | 40 ++++++++++++++++++--- cohere_sagemaker/client.py | 4 +-- notebooks/Deploy classification model.ipynb | 14 ++++---- setup.py | 2 +- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/cohere_sagemaker/classification.py b/cohere_sagemaker/classification.py index eb94d6fa6..39acbbd5a 100644 --- a/cohere_sagemaker/classification.py +++ b/cohere_sagemaker/classification.py @@ -1,15 +1,45 @@ from cohere_sagemaker.response import CohereObject -from typing import Iterator, List, Union +from typing import Any, Dict, Iterator, List, Literal, Union + +Prediction = Union[str, int, List[str], List[int]] +ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any] class Classification(CohereObject): - def __init__(self, classification: Union[str, int, List[str], List[int]]) -> None: - # A classification can be either a label (int or string) for single-label classification, - # or a list of labels (int or string) for multi-label classification. + def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None: + # Prediction is the old format (version 1 of classification-finetuning) + # ClassificationDict is the new format (version 2 of classification-finetuning). + # It also contains the original text and the labels' confidence scores of the prediction self.classification = classification def is_multilabel(self) -> bool: - return not isinstance(self.classification, (int, str)) + if isinstance(self.classification, list): + return True + elif isinstance(self.classification, (int, str)): + return False + return isinstance(self.classification["prediction"], list) + + @property + def prediction(self) -> Prediction: + if isinstance(self.classification, (list, int, str)): + return self.classification + return self.classification["prediction"] + + @property + def confidence(self) -> List[float]: + if isinstance(self.classification, (list, int, str)): + raise ValueError( + "Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" + ) + return self.classification["confidence"] + + @property + def text(self) -> str: + if isinstance(self.classification, (list, int, str)): + raise ValueError( + "Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package" + ) + return self.classification["text"] class Classifications(CohereObject): diff --git a/cohere_sagemaker/client.py b/cohere_sagemaker/client.py index 36a9b514b..a6d3a80cd 100644 --- a/cohere_sagemaker/client.py +++ b/cohere_sagemaker/client.py @@ -64,7 +64,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str: str: S3 URI pointing to the `models.tar.gz` file """ - s3_models_dir = s3_models_dir + ("/" if not s3_models_dir.endswith("/") else "") + s3_models_dir = s3_models_dir.rstrip("/") + "/" # Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz s3_tar_models = [ @@ -410,7 +410,7 @@ def create_finetune( out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker. """ assert name != "model", "name cannot be 'model'" - s3_models_dir = s3_models_dir + ("/" if not s3_models_dir.endswith("/") else "") + s3_models_dir = s3_models_dir.rstrip("/") + "/" if role is None: try: diff --git a/notebooks/Deploy classification model.ipynb b/notebooks/Deploy classification model.ipynb index c88ac1cbc..bb478489a 100644 --- a/notebooks/Deploy classification model.ipynb +++ b/notebooks/Deploy classification model.ipynb @@ -81,7 +81,7 @@ "source": [ "region = boto3.Session().region_name\n", "\n", - "cohere_package = \"classification-finetuning-91717d633bb2357ba721ffc4ba2fe75c\"\n", + "cohere_package = \"classification-finetuning-v2-67831404bb66304b87a66a3a27c979ac\"\n", "algorithm_map = {\n", " \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:algorithm/{cohere_package}\",\n", " \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:algorithm/{cohere_package}\",\n", @@ -130,7 +130,7 @@ "metadata": {}, "outputs": [], "source": [ - "s3_data_dir = \"s3://...\" # Do not add a trailing slash otherwise the upload will not work" + "s3_data_dir = \"s3/...\" # where to upload the data" ] }, { @@ -148,8 +148,8 @@ "outputs": [], "source": [ "sess = sage.Session()\n", - "train_dataset1 = S3Uploader.upload(\"../examples/sample_sentiment_classification_data.jsonl\", s3_data_dir, sagemaker_session=sess)\n", - "train_dataset2 = S3Uploader.upload(\"../examples/sample_multilabel_classification_data.jsonl\", s3_data_dir, sagemaker_session=sess)" + "train_dataset1 = S3Uploader.upload(\"../examples/sample_sentiment_classification_data.jsonl\", s3_data_dir.rstrip(\"/\"), sagemaker_session=sess)\n", + "train_dataset2 = S3Uploader.upload(\"../examples/sample_multilabel_classification_data.jsonl\", s3_data_dir.rstrip(\"/\"), sagemaker_session=sess)" ] }, { @@ -168,7 +168,7 @@ "metadata": {}, "outputs": [], "source": [ - "s3_models_dir = \"s3://...\" # Do not add a trailing slash otherwise it will not work" + "s3_models_dir = \"s3/...\" # where the models will be saved " ] }, { @@ -300,8 +300,8 @@ "metadata": {}, "outputs": [], "source": [ - "co.delete_endpoint()\n", - "co.close()" + "# co.delete_endpoint()\n", + "# co.close()" ] }, { diff --git a/setup.py b/setup.py index 539ee155a..537075e94 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def has_ext_modules(foo) -> bool: setuptools.setup(name='cohere-sagemaker', - version='0.7.2', + version='0.8.0', author='Cohere', author_email='support@cohere.ai', description='A Python library for the Cohere endpoints in AWS Sagemaker',