@@ -128,47 +128,49 @@ def generate_padded_mdp(self, outside_information={}):
128
128
Return a PADDED MDP with mdp params specified in self.mdp_params
129
129
"""
130
130
mdp_gen_params = self .mdp_params_generator .generate (outside_information )
131
+
131
132
outer_shape = self .outer_shape
132
133
if "layout_name" in mdp_gen_params .keys () and mdp_gen_params ["layout_name" ] is not None :
133
134
mdp = OvercookedGridworld .from_layout_name (** mdp_gen_params )
134
135
mdp_generator_fn = lambda : self .padded_mdp (mdp )
135
136
else :
136
-
137
137
required_keys = ["inner_shape" , "prop_empty" , "prop_feats" , "display" ]
138
+ # with generate_all_orders key start_all_orders will be generated inside make_new_layout method
139
+ if not mdp_gen_params .get ("generate_all_orders" ):
140
+ required_keys .append ("start_all_orders" )
138
141
missing_keys = [k for k in required_keys if k not in mdp_gen_params .keys ()]
142
+ if len (missing_keys ) != 0 :
143
+ print ("missing keys dict" , mdp_gen_params )
139
144
assert len (missing_keys ) == 0 , "These keys were missing from the mdp_params: {}" .format (missing_keys )
140
145
inner_shape = mdp_gen_params ["inner_shape" ]
141
146
assert inner_shape [0 ] <= outer_shape [0 ] and inner_shape [1 ] <= outer_shape [1 ], \
142
147
"inner_shape cannot fit into the outershap"
143
148
layout_generator = LayoutGenerator (self .mdp_params_generator , outer_shape = self .outer_shape )
144
149
145
- if "start_all_orders" in mdp_gen_params :
146
- recipe_params = {"start_all_orders" : mdp_gen_params ["start_all_orders" ]}
147
- if "recipe_values" in mdp_gen_params :
148
- recipe_params ["recipe_values" ] = mdp_gen_params ["recipe_values" ]
149
- if "recipe_times" in mdp_gen_params :
150
- recipe_params ["recipe_times" ] = mdp_gen_params ["recipe_times" ]
151
- else :
152
- recipe_params = LayoutGenerator .add_generated_mdp_params_orders (self .mdp_params )
153
-
154
150
if "feature_types" not in mdp_gen_params :
155
151
mdp_gen_params ["feature_types" ] = DEFAULT_FEATURE_TYPES
156
152
157
- mdp_generator_fn = lambda : layout_generator .make_disjoint_sets_layout (
158
- inner_shape = mdp_gen_params ["inner_shape" ],
159
- prop_empty = mdp_gen_params ["prop_empty" ],
160
- prop_features = mdp_gen_params ["prop_feats" ],
161
- base_param = recipe_params ,
162
- feature_types = mdp_gen_params ["feature_types" ],
163
- display = mdp_gen_params ["display" ]
164
- )
165
-
153
+ mdp_generator_fn = lambda : layout_generator .make_new_layout (mdp_gen_params )
166
154
return mdp_generator_fn ()
167
-
155
+
156
+ @staticmethod
157
+ def create_base_params (mdp_gen_params ):
158
+ assert mdp_gen_params .get ("start_all_orders" ) or mdp_gen_params .get ("generate_all_orders" )
159
+ mdp_gen_params = LayoutGenerator .add_generated_mdp_params_orders (mdp_gen_params )
160
+ recipe_params = {"start_all_orders" : mdp_gen_params ["start_all_orders" ]}
161
+ if mdp_gen_params .get ("start_bonus_orders" ):
162
+ recipe_params ["start_bonus_orders" ] = mdp_gen_params ["start_bonus_orders" ]
163
+ if "recipe_values" in mdp_gen_params :
164
+ recipe_params ["recipe_values" ] = mdp_gen_params ["recipe_values" ]
165
+ if "recipe_times" in mdp_gen_params :
166
+ recipe_params ["recipe_times" ] = mdp_gen_params ["recipe_times" ]
167
+ return recipe_params
168
+
168
169
@staticmethod
169
170
def add_generated_mdp_params_orders (mdp_params ):
170
171
"""
171
- adds generated parameters (i.e. generated orders) to mdp_params
172
+ adds generated parameters (i.e. generated orders) to mdp_params,
173
+ returns onchanged copy of mdp_params when there is no "generate_all_orders" and "generate_bonus_orders" keys inside mdp_params
172
174
"""
173
175
mdp_params = copy .deepcopy (mdp_params )
174
176
if mdp_params .get ("generate_all_orders" ):
@@ -199,10 +201,18 @@ def padded_mdp(self, mdp, display=False):
199
201
200
202
start_positions = self .get_random_starting_positions (padded_grid )
201
203
mdp_grid = self .padded_grid_to_layout_grid (padded_grid , start_positions , display = display )
204
+ return OvercookedGridworld .from_grid (mdp_grid )
205
+
206
+ def make_new_layout (self , mdp_gen_params ):
207
+ return self .make_disjoint_sets_layout (
208
+ inner_shape = mdp_gen_params ["inner_shape" ],
209
+ prop_empty = mdp_gen_params ["prop_empty" ],
210
+ prop_features = mdp_gen_params ["prop_feats" ],
211
+ base_param = LayoutGenerator .create_base_params (mdp_gen_params ),
212
+ feature_types = mdp_gen_params ["feature_types" ],
213
+ display = mdp_gen_params ["display" ]
214
+ )
202
215
203
- mdp_params = LayoutGenerator .add_generated_mdp_params_orders (self .mdp_params )
204
- return OvercookedGridworld .from_grid (mdp_grid , base_layout_params = mdp_params )
205
-
206
216
def make_disjoint_sets_layout (self , inner_shape , prop_empty , prop_features , base_param , feature_types = DEFAULT_FEATURE_TYPES , display = True ):
207
217
grid = Grid (inner_shape )
208
218
self .dig_space_with_disjoint_sets (grid , prop_empty )
0 commit comments