Skip to content

Commit af85350

Browse files
committed
fix tests and layout_generator after merge
1 parent 2d130e9 commit af85350

File tree

2 files changed

+67
-39
lines changed

2 files changed

+67
-39
lines changed

src/overcooked_ai_py/mdp/layout_generator.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -128,47 +128,49 @@ def generate_padded_mdp(self, outside_information={}):
128128
Return a PADDED MDP with mdp params specified in self.mdp_params
129129
"""
130130
mdp_gen_params = self.mdp_params_generator.generate(outside_information)
131+
131132
outer_shape = self.outer_shape
132133
if "layout_name" in mdp_gen_params.keys() and mdp_gen_params["layout_name"] is not None:
133134
mdp = OvercookedGridworld.from_layout_name(**mdp_gen_params)
134135
mdp_generator_fn = lambda: self.padded_mdp(mdp)
135136
else:
136-
137137
required_keys = ["inner_shape", "prop_empty", "prop_feats", "display"]
138+
# with generate_all_orders key start_all_orders will be generated inside make_new_layout method
139+
if not mdp_gen_params.get("generate_all_orders"):
140+
required_keys.append("start_all_orders")
138141
missing_keys = [k for k in required_keys if k not in mdp_gen_params.keys()]
142+
if len(missing_keys) != 0:
143+
print("missing keys dict", mdp_gen_params)
139144
assert len(missing_keys) == 0, "These keys were missing from the mdp_params: {}".format(missing_keys)
140145
inner_shape = mdp_gen_params["inner_shape"]
141146
assert inner_shape[0] <= outer_shape[0] and inner_shape[1] <= outer_shape[1], \
142147
"inner_shape cannot fit into the outershap"
143148
layout_generator = LayoutGenerator(self.mdp_params_generator, outer_shape=self.outer_shape)
144149

145-
if "start_all_orders" in mdp_gen_params:
146-
recipe_params = {"start_all_orders": mdp_gen_params["start_all_orders"]}
147-
if "recipe_values" in mdp_gen_params:
148-
recipe_params["recipe_values"] = mdp_gen_params["recipe_values"]
149-
if "recipe_times" in mdp_gen_params:
150-
recipe_params["recipe_times"] = mdp_gen_params["recipe_times"]
151-
else:
152-
recipe_params = LayoutGenerator.add_generated_mdp_params_orders(self.mdp_params)
153-
154150
if "feature_types" not in mdp_gen_params:
155151
mdp_gen_params["feature_types"] = DEFAULT_FEATURE_TYPES
156152

157-
mdp_generator_fn = lambda: layout_generator.make_disjoint_sets_layout(
158-
inner_shape=mdp_gen_params["inner_shape"],
159-
prop_empty=mdp_gen_params["prop_empty"],
160-
prop_features=mdp_gen_params["prop_feats"],
161-
base_param=recipe_params,
162-
feature_types=mdp_gen_params["feature_types"],
163-
display=mdp_gen_params["display"]
164-
)
165-
153+
mdp_generator_fn = lambda: layout_generator.make_new_layout(mdp_gen_params)
166154
return mdp_generator_fn()
167-
155+
156+
@staticmethod
157+
def create_base_params(mdp_gen_params):
158+
assert mdp_gen_params.get("start_all_orders") or mdp_gen_params.get("generate_all_orders")
159+
mdp_gen_params = LayoutGenerator.add_generated_mdp_params_orders(mdp_gen_params)
160+
recipe_params = {"start_all_orders": mdp_gen_params["start_all_orders"]}
161+
if mdp_gen_params.get("start_bonus_orders"):
162+
recipe_params["start_bonus_orders"] = mdp_gen_params["start_bonus_orders"]
163+
if "recipe_values" in mdp_gen_params:
164+
recipe_params["recipe_values"] = mdp_gen_params["recipe_values"]
165+
if "recipe_times" in mdp_gen_params:
166+
recipe_params["recipe_times"] = mdp_gen_params["recipe_times"]
167+
return recipe_params
168+
168169
@staticmethod
169170
def add_generated_mdp_params_orders(mdp_params):
170171
"""
171-
adds generated parameters (i.e. generated orders) to mdp_params
172+
adds generated parameters (i.e. generated orders) to mdp_params,
173+
returns onchanged copy of mdp_params when there is no "generate_all_orders" and "generate_bonus_orders" keys inside mdp_params
172174
"""
173175
mdp_params = copy.deepcopy(mdp_params)
174176
if mdp_params.get("generate_all_orders"):
@@ -199,10 +201,18 @@ def padded_mdp(self, mdp, display=False):
199201

200202
start_positions = self.get_random_starting_positions(padded_grid)
201203
mdp_grid = self.padded_grid_to_layout_grid(padded_grid, start_positions, display=display)
204+
return OvercookedGridworld.from_grid(mdp_grid)
205+
206+
def make_new_layout(self, mdp_gen_params):
207+
return self.make_disjoint_sets_layout(
208+
inner_shape=mdp_gen_params["inner_shape"],
209+
prop_empty=mdp_gen_params["prop_empty"],
210+
prop_features=mdp_gen_params["prop_feats"],
211+
base_param=LayoutGenerator.create_base_params(mdp_gen_params),
212+
feature_types=mdp_gen_params["feature_types"],
213+
display=mdp_gen_params["display"]
214+
)
202215

203-
mdp_params = LayoutGenerator.add_generated_mdp_params_orders(self.mdp_params)
204-
return OvercookedGridworld.from_grid(mdp_grid, base_layout_params=mdp_params)
205-
206216
def make_disjoint_sets_layout(self, inner_shape, prop_empty, prop_features, base_param, feature_types=DEFAULT_FEATURE_TYPES, display=True):
207217
grid = Grid(inner_shape)
208218
self.dig_space_with_disjoint_sets(grid, prop_empty)

testing/overcooked_test.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -899,9 +899,15 @@ def test_random_layout_feature_types(self):
899899
for optional_features_combo in optional_features_combinations:
900900
left_out_optional_features = optional_features - optional_features_combo
901901
used_features = list(optional_features_combo | mandatory_features)
902-
mdp_gen_params = {"prop_feats": (1, 1),
903-
"feature_types": used_features}
904-
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
902+
mdp_gen_params = {"prop_feats": 0.9,
903+
"feature_types": used_features,
904+
"prop_empty": 0.1,
905+
"inner_shape": (6, 5),
906+
"display": False,
907+
"start_all_orders" : [
908+
{ "ingredients" : ["onion", "onion", "onion"]}
909+
]}
910+
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6, 5))
905911
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
906912
for _ in range(10):
907913
env.reset()
@@ -916,31 +922,43 @@ def test_random_layout_generated_recipes(self):
916922
only_onions_dict_recipes = [r.to_dict() for r in only_onions_recipes]
917923

918924
# checking if recipes are generated from mdp_params
919-
mdp_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3}}
920-
mdp_gen_params = {"mdp_params": mdp_params}
921-
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
925+
mdp_gen_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
926+
"prop_feats": 0.9,
927+
"prop_empty": 0.1,
928+
"inner_shape": (6, 5),
929+
"display": False}
930+
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6, 5))
922931
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
923932
for _ in range(10):
924933
env.reset()
925934
self.assertCountEqual(env.mdp.start_all_orders, only_onions_dict_recipes)
926-
self.assertTrue(len(env.mdp.start_bonus_orders) == 0)
935+
self.assertEqual(len(env.mdp.start_bonus_orders), 0)
927936

928937
# checking if bonus_orders is subset of all_orders even if not specified
929-
mdp_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
930-
"generate_bonus_orders": {"n":1, "min_size":2, "max_size":3}}
931-
mdp_gen_params = {"mdp_params": mdp_params}
932-
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
938+
939+
mdp_gen_params = {"generate_all_orders": {"n":2, "ingredients": ["onion"], "min_size":2, "max_size":3},
940+
"generate_bonus_orders": {"n":1, "min_size":2, "max_size":3},
941+
"prop_feats": 0.9,
942+
"prop_empty": 0.1,
943+
"inner_shape": (6, 5),
944+
"display": False}
945+
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6,5))
933946
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
934947
for _ in range(10):
935948
env.reset()
936949
self.assertCountEqual(env.mdp.start_all_orders, only_onions_dict_recipes)
937-
self.assertTrue(len(env.mdp.start_bonus_orders) == 1)
950+
self.assertEqual(len(env.mdp.start_bonus_orders), 1)
938951
self.assertTrue(env.mdp.start_bonus_orders[0] in only_onions_dict_recipes)
939952

940953
# checking if after reset there are new recipes generated
941-
mdp_params = {"generate_all_orders": {"n":3, "min_size":2, "max_size":3}}
942-
mdp_gen_params = {"mdp_params": mdp_params}
943-
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(**mdp_gen_params)
954+
mdp_gen_params = {"generate_all_orders": {"n":3, "min_size":2, "max_size":3},
955+
"prop_feats": 0.9,
956+
"prop_empty": 0.1,
957+
"inner_shape": (6, 5),
958+
"display": False,
959+
"feature_types": [POT, DISH_DISPENSER, SERVING_LOC, ONION_DISPENSER, TOMATO_DISPENSER]
960+
}
961+
mdp_fn = LayoutGenerator.mdp_gen_fn_from_dict(mdp_gen_params, outer_shape=(6,5))
944962
env = OvercookedEnv(mdp_fn, **DEFAULT_ENV_PARAMS)
945963
generated_recipes_strings = set()
946964
for _ in range(20):

0 commit comments

Comments
 (0)