Skip to content

Commit

Permalink
Classification finetuning v2 (#55)
Browse files Browse the repository at this point in the history
* adapt to new classification format

* add properties

* update version

* address comments
  • Loading branch information
alex-matton authored Aug 23, 2023
1 parent e68feb1 commit 875fb7d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 15 deletions.
40 changes: 35 additions & 5 deletions cohere_sagemaker/classification.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,45 @@
from cohere_sagemaker.response import CohereObject
from typing import Iterator, List, Union
from typing import Any, Dict, Iterator, List, Literal, Union

Prediction = Union[str, int, List[str], List[int]]
ClassificationDict = Dict[Literal["prediction", "confidence", "text"], Any]


class Classification(CohereObject):
def __init__(self, classification: Union[str, int, List[str], List[int]]) -> None:
# A classification can be either a label (int or string) for single-label classification,
# or a list of labels (int or string) for multi-label classification.
def __init__(self, classification: Union[Prediction, ClassificationDict]) -> None:
# Prediction is the old format (version 1 of classification-finetuning)
# ClassificationDict is the new format (version 2 of classification-finetuning).
# It also contains the original text and the labels' confidence scores of the prediction
self.classification = classification

def is_multilabel(self) -> bool:
return not isinstance(self.classification, (int, str))
if isinstance(self.classification, list):
return True
elif isinstance(self.classification, (int, str)):
return False
return isinstance(self.classification["prediction"], list)

@property
def prediction(self) -> Prediction:
if isinstance(self.classification, (list, int, str)):
return self.classification
return self.classification["prediction"]

@property
def confidence(self) -> List[float]:
if isinstance(self.classification, (list, int, str)):
raise ValueError(
"Confidence scores are not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
)
return self.classification["confidence"]

@property
def text(self) -> str:
if isinstance(self.classification, (list, int, str)):
raise ValueError(
"Original text is not available for version prior to 2.0 of Cohere Classification Finetuning AWS package"
)
return self.classification["text"]


class Classifications(CohereObject):
Expand Down
4 changes: 2 additions & 2 deletions cohere_sagemaker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _s3_models_dir_to_tarfile(self, s3_models_dir: str) -> str:
str: S3 URI pointing to the `models.tar.gz` file
"""

s3_models_dir = s3_models_dir + ("/" if not s3_models_dir.endswith("/") else "")
s3_models_dir = s3_models_dir.rstrip("/") + "/"

# Links of all fine-tuned models in s3_models_dir. Their format should be .tar.gz
s3_tar_models = [
Expand Down Expand Up @@ -410,7 +410,7 @@ def create_finetune(
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
"""
assert name != "model", "name cannot be 'model'"
s3_models_dir = s3_models_dir + ("/" if not s3_models_dir.endswith("/") else "")
s3_models_dir = s3_models_dir.rstrip("/") + "/"

if role is None:
try:
Expand Down
14 changes: 7 additions & 7 deletions notebooks/Deploy classification model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"source": [
"region = boto3.Session().region_name\n",
"\n",
"cohere_package = \"classification-finetuning-91717d633bb2357ba721ffc4ba2fe75c\"\n",
"cohere_package = \"classification-finetuning-v2-67831404bb66304b87a66a3a27c979ac\"\n",
"algorithm_map = {\n",
" \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:algorithm/{cohere_package}\",\n",
" \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:algorithm/{cohere_package}\",\n",
Expand Down Expand Up @@ -130,7 +130,7 @@
"metadata": {},
"outputs": [],
"source": [
"s3_data_dir = \"s3://...\" # Do not add a trailing slash otherwise the upload will not work"
"s3_data_dir = \"s3/...\" # where to upload the data"
]
},
{
Expand All @@ -148,8 +148,8 @@
"outputs": [],
"source": [
"sess = sage.Session()\n",
"train_dataset1 = S3Uploader.upload(\"../examples/sample_sentiment_classification_data.jsonl\", s3_data_dir, sagemaker_session=sess)\n",
"train_dataset2 = S3Uploader.upload(\"../examples/sample_multilabel_classification_data.jsonl\", s3_data_dir, sagemaker_session=sess)"
"train_dataset1 = S3Uploader.upload(\"../examples/sample_sentiment_classification_data.jsonl\", s3_data_dir.rstrip(\"/\"), sagemaker_session=sess)\n",
"train_dataset2 = S3Uploader.upload(\"../examples/sample_multilabel_classification_data.jsonl\", s3_data_dir.rstrip(\"/\"), sagemaker_session=sess)"
]
},
{
Expand All @@ -168,7 +168,7 @@
"metadata": {},
"outputs": [],
"source": [
"s3_models_dir = \"s3://...\" # Do not add a trailing slash otherwise it will not work"
"s3_models_dir = \"s3/...\" # where the models will be saved "
]
},
{
Expand Down Expand Up @@ -300,8 +300,8 @@
"metadata": {},
"outputs": [],
"source": [
"co.delete_endpoint()\n",
"co.close()"
"# co.delete_endpoint()\n",
"# co.close()"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def has_ext_modules(foo) -> bool:


setuptools.setup(name='cohere-sagemaker',
version='0.7.2',
version='0.8.0',
author='Cohere',
author_email='support@cohere.ai',
description='A Python library for the Cohere endpoints in AWS Sagemaker',
Expand Down

0 comments on commit 875fb7d

Please sign in to comment.