Fix tests (#63)
* fix tests & stub out service calls

* tidying

* fix to match latest version, mostly

* stub out remaining errors
leilacc authored Sep 26, 2023
1 parent fa32411 commit eefda69
Showing 2 changed files with 109 additions and 24 deletions.
2 changes: 1 addition & 1 deletion cohere_sagemaker/client.py
@@ -218,10 +218,10 @@ def generate(
         'temperature': temperature,
         'k': k,
         'p': p,
+        'num_generations': num_generations,
         'stop_sequences': stop_sequences,
         'return_likelihoods': return_likelihoods,
         'truncate': truncate,
-        'num_generations': num_generations,
     }
     for key, value in list(json_params.items()):
         if value is None:
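The position of 'num_generations' matters here: the new tests below stub invoke_endpoint and compare the serialized request Body as an exact string, and json.dumps follows dict insertion order, so the client's ordering has to line up with the tests' default_request_params (where 'num_generations' follows 'p'). A minimal sketch of that interaction, using illustrative values rather than the library's exact code:

    import json

    # Illustrative only: mirrors how the client assembles its request body.
    json_params = {
        'prompt': "Hello world!",
        'max_tokens': 20,
        'temperature': 1.0,
        'k': 0,
        'p': 0.75,
        'num_generations': 1,
        'stop_sequences': None,
        'return_likelihoods': None,
        'truncate': None,
    }
    # Unset (None) parameters are dropped before serialization, as the
    # loop in client.py above does.
    for key, value in list(json_params.items()):
        if value is None:
            del json_params[key]

    # Insertion order is preserved, so this string only equals the tests'
    # expected 'Body' when 'num_generations' sits directly after 'p'.
    body = json.dumps(json_params)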
131 changes: 108 additions & 23 deletions tests/client_test.py
@@ -1,63 +1,148 @@
 import unittest
+import json
 
 from cohere_sagemaker import Client, CohereError
+from botocore.stub import Stubber
+from botocore.response import StreamingBody
+from io import BytesIO
+from typing import Dict, Optional, Any
 
 
 class TestClient(unittest.TestCase):
+    ENDPOINT_NAME = 'cohere-gpt-medium'
+    PROMPT = "Hello world!"
+    TEXT = "Mock generation #1"
+    TEXT2 = "Mock generation #2"
 
     def setUp(self):
-        self.client = Client(endpoint_name='cohere-gpt-medium', region_name='us-east-1')
+        self.client = Client(endpoint_name=self.ENDPOINT_NAME,
+                             region_name='us-west-2')
+        self.default_request_params = {"prompt": self.PROMPT,
+                                       "max_tokens": 20,
+                                       "temperature": 1.0,
+                                       "k": 0,
+                                       "p": 0.75,
+                                       "num_generations": 1}
         super().setUp()
 
     def tearDown(self):
         self.client.close()
 
+    def stub(self, expected_params, generations_text):
+        stubber = Stubber(self.client._client)
+        generations = []
+        for text in generations_text:
+            generations.append(f'{{"text": "{text}"}}')
+        generations = ', '.join(generations)
+        b = f'{{"generations": [{generations}]}}'.encode()
+        mock_response = {"Body": StreamingBody(BytesIO(b), len(b))}
+        stubber.add_response('invoke_endpoint', mock_response, expected_params)
+        stubber.activate()
+
+    def stub_err(self, service_message, http_status_code=400):
+        stubber = Stubber(self.client._client)
+        stubber.add_client_error('invoke_endpoint',
+                                 service_message=service_message,
+                                 http_status_code=http_status_code)
+        stubber.activate()
+
+    def expected_params(self,
+                        custom_request_params: Optional[Dict[str, Any]] = {},
+                        custom_http_params: Optional[Dict[str, Any]] = {}) -> Dict:
+        # Optionally override the default parameters with custom ones
+        for k, v in custom_request_params.items():
+            self.default_request_params[k] = v
+        default_http_params = {
+            'Body': f'{json.dumps(self.default_request_params)}',
+            'ContentType': 'application/json',
+            'EndpointName': self.ENDPOINT_NAME}
+        for k, v in custom_http_params.items():
+            default_http_params[k] = v
+        return default_http_params
+
     # TODO unauthorized
 
-    def test_generate(self):
-        # temperature = 0.0 makes output deterministic
-        response = self.client.generate("Hello world!", temperature=0.0)
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new")
+    def test_generate_defaults(self):
+        self.stub(self.expected_params(), [self.TEXT])
+        response = self.client.generate(self.PROMPT)
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
 
     def test_variant(self):
-        response = self.client.generate("Hello world!", temperature=0.0, variant="AllTraffic")
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new")
-
-    def test_max_tokens(self):
-        response = self.client.generate("Hello world!", temperature=0.0, max_tokens=40)
-        self.assertEqual(response.generations[0].text, " I'm a newbie to the forum and I'm looking for some help. I have a new build PC and I'm trying to get it to work with my TV. I have a Samsung UE")
+        self.stub(self.expected_params(
+            custom_http_params={"TargetVariant": "AllTraffic"}), [self.TEXT])
+        response = self.client.generate(self.PROMPT, variant="AllTraffic")
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
 
+    def test_override_defaults(self):
+        self.stub(self.expected_params(
+            custom_request_params={"temperature": 0.0,
+                                   "max_tokens": 40,
+                                   "k": 1,
+                                   "p": 0.5,
+                                   "stop_sequences": ["."],
+                                   "return_likelihoods": "likelihood",
+                                   "truncate": "LEFT"}),
+                  [self.TEXT])
+        response = self.client.generate(self.PROMPT,
+                                        temperature=0.0,
+                                        max_tokens=40,
+                                        k=1,
+                                        p=0.5,
+                                        stop_sequences=["."],
+                                        return_likelihoods="likelihood",
+                                        truncate="LEFT")
+        self.assertEqual(len(response.generations), 1)
+        self.assertEqual(response.generations[0].text, self.TEXT)
+
+    def test_two_generations(self):
+        num_generations = 2
+        self.stub(self.expected_params(custom_request_params={
+            "num_generations": num_generations}), [self.TEXT, self.TEXT2])
+        response = self.client.generate(self.PROMPT,
+                                        num_generations=num_generations)
+        self.assertEqual(len(response.generations), num_generations)
+        self.assertEqual(response.generations[0].text, self.TEXT)
+        self.assertEqual(response.generations[1].text, self.TEXT2)
+
     def test_bad_region(self):
-        client = Client(endpoint_name='cohere-gpt-medium', region_name='invalid-region')
+        expected_err = "Could not connect to the endpoint URL"
+        self.stub_err(expected_err)
         try:
-            client.generate("Hello world!")
+            self.client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Could not connect to the endpoint URL", str(e.message))
-        client.close()
+            self.assertIn(expected_err,
+                          str(e.message))
 
     def test_wrong_region(self):
-        client = Client(endpoint_name='cohere-gpt-medium', region_name='us-east-2')
+        expected_err = ("Endpoint cohere-gpt-medium of account 455073351313 "
+                        "not found.")
+        self.stub_err(expected_err)
         try:
-            client.generate("Hello world!")
+            self.client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Endpoint cohere-gpt-medium of account 455073351313 not found.", str(e.message))
-        client.close()
+            self.assertIn(expected_err, str(e.message))
 
     def test_bad_variant(self):
+        expected_err = "Variant invalid-variant not found for Request"
+        self.stub_err(expected_err)
         try:
-            self.client.generate("Hello world!", variant="invalid-variant")
+            self.client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
-            self.assertIn("Variant invalid-variant not found for Request", str(e.message))
+            self.assertIn(expected_err, str(e.message))
 
     def test_client_not_connected(self):
-        client = Client(region_name='us-east-1')
+        client = Client()
         try:
-            client.generate("Hello world!")
+            client.generate(self.PROMPT)
             self.fail("expected error")
         except CohereError as e:
             self.assertIn("No endpoint connected", str(e.message))
 
 
 if __name__ == '__main__':
     unittest.main()
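For reference, the stubbing pattern that stub() and stub_err() rely on can be exercised directly against a bare boto3 client. This is a standalone sketch, not part of the commit; the endpoint name and payload are illustrative, and it assumes boto3/botocore are installed:

    import json
    from io import BytesIO

    import boto3
    from botocore.response import StreamingBody
    from botocore.stub import Stubber

    runtime = boto3.client('sagemaker-runtime', region_name='us-west-2')
    stubber = Stubber(runtime)

    # Canned response body, shaped like the generations JSON the tests fake.
    payload = json.dumps({"generations": [{"text": "Mock generation #1"}]}).encode()
    expected_request = {
        'EndpointName': 'cohere-gpt-medium',
        'ContentType': 'application/json',
        'Body': json.dumps({"prompt": "Hello world!"}),
    }
    stubber.add_response('invoke_endpoint',
                         {'Body': StreamingBody(BytesIO(payload), len(payload))},
                         expected_request)
    stubber.activate()

    # No request leaves the machine; the stubber checks the call params
    # against expected_request and returns the canned StreamingBody.
    response = runtime.invoke_endpoint(**expected_request)
    print(json.loads(response['Body'].read()))

    # Error paths are faked the same way with add_client_error, e.g.:
    #   stubber.add_client_error('invoke_endpoint',
    #                            service_message='No such endpoint',
    #                            http_status_code=400)

The tests above apply the same pattern to the boto3 client held inside the cohere_sagemaker Client (self.client._client), so Client.generate still exercises its real request-building and response-parsing code against the canned payload.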
