diff --git a/cohere_sagemaker/client.py b/cohere_sagemaker/client.py
index 542702ccb..165ac2b78 100644
--- a/cohere_sagemaker/client.py
+++ b/cohere_sagemaker/client.py
@@ -218,10 +218,10 @@ def generate(
             'temperature': temperature,
             'k': k,
             'p': p,
+            'num_generations': num_generations,
             'stop_sequences': stop_sequences,
             'return_likelihoods': return_likelihoods,
             'truncate': truncate,
-            'num_generations': num_generations,
         }
         for key, value in list(json_params.items()):
             if value is None:
diff --git a/tests/client_test.py b/tests/client_test.py
index a2aaff130..1c82880f6 100644
--- a/tests/client_test.py
+++ b/tests/client_test.py
@@ -1,63 +1,164 @@
 import unittest
+import json
 
 from cohere_sagemaker import Client, CohereError
+from botocore.stub import Stubber
+from botocore.response import StreamingBody
+from io import BytesIO
+from typing import Dict, Optional, Any
+
 
 class TestClient(unittest.TestCase):
+    ENDPOINT_NAME = 'cohere-gpt-medium'
+    PROMPT = "Hello world!"
+    TEXT = "Mock generation #1"
+    TEXT2 = "Mock generation #2"
 
     def setUp(self):
-        self.client = Client(endpoint_name='cohere-gpt-medium', region_name='us-east-1')
+        self.client = Client(endpoint_name=self.ENDPOINT_NAME,
+                             region_name='us-west-2')
+        self.default_request_params = {"prompt": self.PROMPT,
+                                       "max_tokens": 20,
+                                       "temperature": 1.0,
+                                       "k": 0,
+                                       "p": 0.75,
+                                       "num_generations": 1}
         super().setUp()
 
     def tearDown(self):
         self.client.close()
 
+    def stub(self, expected_params, generations_text):
+        """Queue one successful invoke_endpoint response on the client.
+
+        The response body is JSON with one generation per entry of
+        generations_text; expected_params is handed to the Stubber as
+        the request parameters it should expect.
+        """
+        stubber = Stubber(self.client._client)
+        # json.dumps escapes quotes/backslashes a generation text may hold
+        body = json.dumps(
+            {"generations": [{"text": text} for text in generations_text]}
+        ).encode()
+        mock_response = {"Body": StreamingBody(BytesIO(body), len(body))}
+        stubber.add_response('invoke_endpoint', mock_response, expected_params)
+        stubber.activate()
+
+    def stub_err(self, service_message, http_status_code=400):
+        """Queue one invoke_endpoint client error on the client."""
+        stubber = Stubber(self.client._client)
+        stubber.add_client_error('invoke_endpoint',
+                                 service_message=service_message,
+                                 http_status_code=http_status_code)
+        stubber.activate()
+
+    def expected_params(
+            self,
+            custom_request_params: Optional[Dict[str, Any]] = None,
+            custom_http_params: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        """Build the invoke_endpoint kwargs the client is expected to send.
+
+        custom_request_params overrides or extends the default JSON
+        request body and custom_http_params does the same for the
+        top-level invoke_endpoint parameters. Neither argument nor
+        self.default_request_params is modified.
+        """
+        # Merging into a fresh dict keeps existing keys in place and appends
+        # new ones, so the serialized body has the same key order as before.
+        request_params = {**self.default_request_params,
+                          **(custom_request_params or {})}
+        http_params = {
+            'Body': json.dumps(request_params),
+            'ContentType': 'application/json',
+            'EndpointName': self.ENDPOINT_NAME,
+        }
+        http_params.update(custom_http_params or {})
+        return http_params
+
     # TODO unauthorized
 
-    def test_generate(self):
-        # temperature = 0.0 makes output deterministic
-        response = self.client.generate("Hello world!", temperature=0.0)
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new")
+    def test_generate_defaults(self):
+        self.stub(self.expected_params(), [self.TEXT])
+        response = self.client.generate(self.PROMPT)
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
 
     def test_variant(self):
-        response = self.client.generate("Hello world!", temperature=0.0, variant="AllTraffic")
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new")
-
-    def test_max_tokens(self):
-        response = self.client.generate("Hello world!", temperature=0.0, max_tokens=40)
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new build PC and I'm trying to get it to work with my TV. I have a Samsung UE")
+        self.stub(self.expected_params(
+            custom_http_params={"TargetVariant": "AllTraffic"}), [self.TEXT])
+        response = self.client.generate(self.PROMPT, variant="AllTraffic")
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
+
+    def test_override_defaults(self):
+        self.stub(self.expected_params(
+            custom_request_params={"temperature": 0.0,
+                                   "max_tokens": 40,
+                                   "k": 1,
+                                   "p": 0.5,
+                                   "stop_sequences": ["."],
+                                   "return_likelihoods": "likelihood",
+                                   "truncate": "LEFT"}),
+                  [self.TEXT])
+        response = self.client.generate(self.PROMPT,
+                                        temperature=0.0,
+                                        max_tokens=40,
+                                        k=1,
+                                        p=0.5,
+                                        stop_sequences=["."],
+                                        return_likelihoods="likelihood",
+                                        truncate="LEFT")
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
+
+    def test_two_generations(self):
+        num_generations = 2
+        self.stub(self.expected_params(custom_request_params={
+            "num_generations": num_generations}), [self.TEXT, self.TEXT2])
+        response = self.client.generate(self.PROMPT,
+                                        num_generations=num_generations)
+        self.assertEqual(len(response.generations), num_generations)
+        self.assertEqual(response.generations[0].text, self.TEXT)
+        self.assertEqual(response.generations[1].text, self.TEXT2)
 
     def test_bad_region(self):
-        client = Client(endpoint_name='cohere-gpt-medium', region_name='invalid-region')
+        expected_err = "Could not connect to the endpoint URL"
+        self.stub_err(expected_err)
         try:
-            client.generate("Hello world!")
+            self.client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Could not connect to the endpoint URL", str(e.message))
-        client.close()
+            self.assertIn(expected_err,
+                          str(e.message))
 
     def test_wrong_region(self):
-        client = Client(endpoint_name='cohere-gpt-medium', region_name='us-east-2')
+        expected_err = ("Endpoint cohere-gpt-medium of account 455073351313 "
+                        "not found.")
+        self.stub_err(expected_err)
         try:
-            client.generate("Hello world!")
+            self.client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Endpoint cohere-gpt-medium of account 455073351313 not found.", str(e.message))
-        client.close()
+            self.assertIn(expected_err, str(e.message))
 
     def test_bad_variant(self):
+        expected_err = "Variant invalid-variant not found for Request"
+        self.stub_err(expected_err)
         try:
-            self.client.generate("Hello world!", variant="invalid-variant")
+            self.client.generate(self.PROMPT, variant="invalid-variant")
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Variant invalid-variant not found for Request", str(e.message))
+            self.assertIn(expected_err, str(e.message))
 
     def test_client_not_connected(self):
-        client = Client(region_name='us-east-1')
+        client = Client()
         try:
-            client.generate("Hello world!")
+            client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
             self.assertIn("No endpoint connected", str(e.message))
+
 
 if __name__ == '__main__':
     unittest.main()