From f1a8c8a2a43fda88163908f98a0453a70817a505 Mon Sep 17 00:00:00 2001 From: rickard Date: Tue, 10 Sep 2024 21:28:24 +0200 Subject: [PATCH] add some checks for empty responses and tests --- tale/llm/responses/WorldCreaturesResponse.py | 2 +- tale/llm/responses/WorldItemsResponse.py | 2 +- tale/llm/world_building.py | 4 +++ tests/test_llm_utils.py | 29 ++++++++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/tale/llm/responses/WorldCreaturesResponse.py b/tale/llm/responses/WorldCreaturesResponse.py index e45cc092..838b7e78 100644 --- a/tale/llm/responses/WorldCreaturesResponse.py +++ b/tale/llm/responses/WorldCreaturesResponse.py @@ -3,5 +3,5 @@ class WorldCreaturesResponse(): def __init__(self, response: dict = {}): - self.creatures = response["creatures"] + self.creatures = response.get("creatures", []) self.valid = len(self.creatures) > 0 \ No newline at end of file diff --git a/tale/llm/responses/WorldItemsResponse.py b/tale/llm/responses/WorldItemsResponse.py index 2ad04735..c32a4e9e 100644 --- a/tale/llm/responses/WorldItemsResponse.py +++ b/tale/llm/responses/WorldItemsResponse.py @@ -3,5 +3,5 @@ class WorldItemsResponse(): def __init__(self, response: dict = {}): - self.items = response["items"] + self.items = response.get("items", []) self.valid = len(self.items) > 0 \ No newline at end of file diff --git a/tale/llm/world_building.py b/tale/llm/world_building.py index 5efcef95..002ff728 100644 --- a/tale/llm/world_building.py +++ b/tale/llm/world_building.py @@ -186,8 +186,12 @@ def generate_start_location(self, location: Location, zone_info: dict, story_typ if self.json_grammar_key: request_body[self.json_grammar_key] = self.json_grammar result = self.io_util.synchronous_request(request_body, prompt=prompt) + if not result: + return LocationResponse.empty() try: json_result = json.loads(parse_utils.sanitize_json(result)) + if not json_result.get('name', None): + return LocationResponse.empty() location.name=json_result['name'] return LocationResponse(json_result=json_result, location=location, exit_location_name='', item_types=self.item_types) except Exception as exc: diff --git a/tests/test_llm_utils.py b/tests/test_llm_utils.py index 16b36c46..6eac792b 100644 --- a/tests/test_llm_utils.py +++ b/tests/test_llm_utils.py @@ -251,6 +251,12 @@ def test_generate_world_items(self): shield = result.items[1] assert(shield['name'] == 'shield') + self.llm_util._world_building.io_util.response = '' + result = self.llm_util._world_building.generate_world_items(world_generation_context=WorldGenerationContext(story_context='',story_type='',world_info='',world_mood=0)) + assert not result.valid + assert(len(result.items) == 0) + + def test_generate_world_creatures(self): # mostly for coverage self.llm_util._world_building.io_util.response = '{"creatures":[{"name": "dragon", "body": "Creature", "unarmed_attack": "BITE", "hp":100, "level":10}]}' @@ -263,6 +269,10 @@ def test_generate_world_creatures(self): assert(dragon["level"] == 10) assert(dragon["unarmed_attack"] == UnarmedAttack.BITE.name) + self.llm_util._world_building.io_util.response = '' + result = self.llm_util._world_building.generate_world_creatures(world_generation_context=WorldGenerationContext(story_context='',story_type='',world_info='',world_mood=0)) + assert not result.valid + assert(len(result.creatures) == 0) def test_get_neighbor_or_generate_zone(self): """ Tests the get_neighbor_or_generate_zone method of llm_utils. @@ -472,6 +482,25 @@ def test_issue_overwriting_exits(self): assert((len(rocky_cliffs.exits) == 6)) assert((len(location_response.new_locations) == 2)) + def test_generate_location_empty_response(self): + self.llm_util._world_building.io_util.response='' + location = Location(name='Outside') + result = self.llm_util.generate_start_location(location, + story_type='', + story_context='', + zone_info={}, + world_info='') + assert result.empty() + + self.llm_util._world_building.io_util.response='{}' + location = Location(name='Outside') + result = self.llm_util.generate_start_location(location, + story_type='', + story_context='', + zone_info={}, + world_info='') + assert result.empty() + def test_generate_note_lore(self): self.llm_util._quest_building.io_util.response = 'A long lost tale of a hero who saved the world from a great evil.' world_generation_context = WorldGenerationContext(story_context=self.story.config.context, story_type=self.story.config.type, world_info='', world_mood=0)