6
6
import tensorflow as tf
7
7
from ray .rllib .policy import Policy as RllibPolicy
8
8
from tensorflow import keras
9
- from tensorflow .compat .v1 .keras .backend import get_session , set_session
9
+ from tensorflow .compat .v1 .keras .backend import get_session
10
10
11
11
from human_aware_rl .data_dir import DATA_DIR
12
- from human_aware_rl .human .process_dataframes import (
13
- get_human_human_trajectories ,
14
- get_trajs_from_data ,
15
- )
16
- from human_aware_rl .rllib .rllib import (
17
- RlLibAgent ,
18
- evaluate ,
19
- get_base_ae ,
20
- softmax ,
21
- )
12
+ from human_aware_rl .human .process_dataframes import get_human_human_trajectories
13
+ from human_aware_rl .rllib .rllib import evaluate , get_base_ae , softmax
22
14
from human_aware_rl .static import CLEAN_2019_HUMAN_DATA_TRAIN
23
15
from human_aware_rl .utils import get_flattened_keys , recursive_dict_update
24
16
from overcooked_ai_py .mdp .actions import Action
@@ -119,9 +111,7 @@ def get_bc_params(**args_to_override):
119
111
120
112
all_keys = get_flattened_keys (params )
121
113
if len (all_keys ) != len (set (all_keys )):
122
- raise ValueError (
123
- "Every key at every level must be distict for BC params!"
124
- )
114
+ raise ValueError ("Every key at every level must be distict for BC params!" )
125
115
126
116
return params
127
117
@@ -198,9 +188,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
198
188
class_weights = None
199
189
200
190
# Retrieve un-initialized keras model
201
- model = build_bc_model (
202
- ** bc_params , max_seq_len = np .max (seq_lens ), verbose = verbose
203
- )
191
+ model = build_bc_model (** bc_params , max_seq_len = np .max (seq_lens ), verbose = verbose )
204
192
205
193
# Initialize the model
206
194
# Note: have to use lists for multi-output model support and not dicts because of a TensorFlow 2.0.0 bug
@@ -225,9 +213,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
225
213
# Early terminate training if loss doesn't improve for "patience" epochs
226
214
keras .callbacks .EarlyStopping (monitor = "loss" , patience = 20 ),
227
215
# Reduce lr by "factor" after "patience" epochs of no improvement in loss
228
- keras .callbacks .ReduceLROnPlateau (
229
- monitor = "loss" , patience = 3 , factor = 0.1
230
- ),
216
+ keras .callbacks .ReduceLROnPlateau (monitor = "loss" , patience = 3 , factor = 0.1 ),
231
217
# Log all metrics model was compiled with to tensorboard every epoch
232
218
keras .callbacks .TensorBoard (
233
219
log_dir = os .path .join (model_dir , "logs" ), write_graph = False
@@ -329,12 +315,8 @@ def featurize_fn(state):
329
315
return base_env .featurize_state_mdp (state )
330
316
331
317
# Wrap Keras models in rllib policies
332
- agent_0_policy = BehaviorCloningPolicy .from_model (
333
- model , bc_params , stochastic = True
334
- )
335
- agent_1_policy = BehaviorCloningPolicy .from_model (
336
- model , bc_params , stochastic = True
337
- )
318
+ agent_0_policy = BehaviorCloningPolicy .from_model (model , bc_params , stochastic = True )
319
+ agent_1_policy = BehaviorCloningPolicy .from_model (model , bc_params , stochastic = True )
338
320
339
321
# Compute the results of the rollout(s)
340
322
results = evaluate (
@@ -355,21 +337,17 @@ def featurize_fn(state):
355
337
356
338
def _build_model (observation_shape , action_shape , mlp_params , ** kwargs ):
357
339
## Inputs
358
- inputs = keras .Input (
359
- shape = observation_shape , name = "Overcooked_observation"
360
- )
340
+ inputs = keras .Input (shape = observation_shape , name = "Overcooked_observation" )
361
341
x = inputs
362
342
363
343
## Build fully connected layers
364
- assert (
365
- len ( mlp_params [ "net_arch" ]) == mlp_params [ "num_layers" ]
366
- ), "Invalid Fully Connected params"
344
+ assert len ( mlp_params [ "net_arch" ]) == mlp_params [ "num_layers" ], (
345
+ "Invalid Fully Connected params"
346
+ )
367
347
368
348
for i in range (mlp_params ["num_layers" ]):
369
349
units = mlp_params ["net_arch" ][i ]
370
- x = keras .layers .Dense (
371
- units , activation = "relu" , name = "fc_{0}" .format (i )
372
- )(x )
350
+ x = keras .layers .Dense (units , activation = "relu" , name = "fc_{0}" .format (i ))(x )
373
351
374
352
## output layer
375
353
logits = keras .layers .Dense (action_shape [0 ], name = "logits" )(x )
@@ -378,12 +356,7 @@ def _build_model(observation_shape, action_shape, mlp_params, **kwargs):
378
356
379
357
380
358
def _build_lstm_model (
381
- observation_shape ,
382
- action_shape ,
383
- mlp_params ,
384
- cell_size ,
385
- max_seq_len = 20 ,
386
- ** kwargs
359
+ observation_shape , action_shape , mlp_params , cell_size , max_seq_len = 20 , ** kwargs
387
360
):
388
361
## Inputs
389
362
obs_in = keras .Input (
@@ -395,21 +368,19 @@ def _build_lstm_model(
395
368
x = obs_in
396
369
397
370
## Build fully connected layers
398
- assert (
399
- len ( mlp_params [ "net_arch" ]) == mlp_params [ "num_layers" ]
400
- ), "Invalid Fully Connected params"
371
+ assert len ( mlp_params [ "net_arch" ]) == mlp_params [ "num_layers" ], (
372
+ "Invalid Fully Connected params"
373
+ )
401
374
402
375
for i in range (mlp_params ["num_layers" ]):
403
376
units = mlp_params ["net_arch" ][i ]
404
377
x = keras .layers .TimeDistributed (
405
- keras .layers .Dense (
406
- units , activation = "relu" , name = "fc_{0}" .format (i )
407
- )
378
+ keras .layers .Dense (units , activation = "relu" , name = "fc_{0}" .format (i ))
408
379
)(x )
409
380
410
- mask = keras .layers .Lambda (
411
- lambda x : tf . sequence_mask ( x , maxlen = max_seq_len )
412
- )( seq_in )
381
+ mask = keras .layers .Lambda (lambda x : tf . sequence_mask ( x , maxlen = max_seq_len ))(
382
+ seq_in
383
+ )
413
384
414
385
## LSTM layer
415
386
lstm_out , h_out , c_out = keras .layers .LSTM (
@@ -488,17 +459,15 @@ def __init__(self, observation_space, action_space, config):
488
459
)
489
460
490
461
if "bc_model" in config and config ["bc_model" ]:
491
- assert (
492
- "bc_params" in config
493
- ), "must specify params in addition to model"
494
- assert issubclass (
495
- type (config ["bc_model" ]), keras .Model
496
- ), "model must be of type keras.Model"
462
+ assert "bc_params" in config , "must specify params in addition to model"
463
+ assert issubclass (type (config ["bc_model" ]), keras .Model ), (
464
+ "model must be of type keras.Model"
465
+ )
497
466
model , bc_params = config ["bc_model" ], config ["bc_params" ]
498
467
else :
499
- assert (
500
- "model_dir" in config
501
- ), "must specify model directory if model not specified"
468
+ assert "model_dir" in config , (
469
+ "must specify model directory if model not specified"
470
+ )
502
471
model , bc_params = load_bc_model (config ["model_dir" ])
503
472
504
473
# Save the session that the model was loaded into so it is available at inference time if necessary
@@ -513,9 +482,7 @@ def __init__(self, observation_space, action_space, config):
513
482
self .stochastic = config ["stochastic" ]
514
483
self .use_lstm = bc_params ["use_lstm" ]
515
484
self .cell_size = bc_params ["cell_size" ]
516
- self .eager = (
517
- config ["eager" ] if "eager" in config else bc_params ["eager" ]
518
- )
485
+ self .eager = config ["eager" ] if "eager" in config else bc_params ["eager" ]
519
486
self .context = self ._create_execution_context ()
520
487
521
488
def _setup_shapes (self ):
@@ -540,9 +507,7 @@ def from_model_dir(cls, model_dir, stochastic=True):
540
507
"bc_params" : bc_params ,
541
508
"stochastic" : stochastic ,
542
509
}
543
- return cls (
544
- bc_params ["observation_shape" ], bc_params ["action_shape" ], config
545
- )
510
+ return cls (bc_params ["observation_shape" ], bc_params ["action_shape" ], config )
546
511
547
512
@classmethod
548
513
def from_model (cls , model , bc_params , stochastic = True ):
@@ -551,9 +516,7 @@ def from_model(cls, model, bc_params, stochastic=True):
551
516
"bc_params" : bc_params ,
552
517
"stochastic" : stochastic ,
553
518
}
554
- return cls (
555
- bc_params ["observation_shape" ], bc_params ["action_shape" ], config
556
- )
519
+ return cls (bc_params ["observation_shape" ], bc_params ["action_shape" ], config )
557
520
558
521
def compute_actions (
559
522
self ,
@@ -563,7 +526,7 @@ def compute_actions(
563
526
prev_reward_batch = None ,
564
527
info_batch = None ,
565
528
episodes = None ,
566
- ** kwargs
529
+ ** kwargs ,
567
530
):
568
531
"""
569
532
Computes sampled actions for each of the corresponding OvercookedEnv states in obs_batch
@@ -641,9 +604,7 @@ def _forward(self, obs_batch, state_batches):
641
604
if self .use_lstm :
642
605
obs_batch = np .expand_dims (obs_batch , 1 )
643
606
seq_lens = np .ones (len (obs_batch ))
644
- model_out = self .model .predict (
645
- [obs_batch , seq_lens ] + state_batches
646
- )
607
+ model_out = self .model .predict ([obs_batch , seq_lens ] + state_batches )
647
608
logits , states = model_out [0 ], model_out [1 :]
648
609
logits = logits .reshape ((logits .shape [0 ], - 1 ))
649
610
return logits , states
@@ -663,8 +624,6 @@ def _create_execution_context(self):
663
624
664
625
if __name__ == "__main__" :
665
626
params = get_bc_params ()
666
- model = train_bc_model (
667
- os .path .join (BC_SAVE_DIR , "default" ), params , verbose = True
668
- )
627
+ model = train_bc_model (os .path .join (BC_SAVE_DIR , "default" ), params , verbose = True )
669
628
# Evaluate our model's performance in a rollout
670
629
evaluate_bc_model (model , params )
0 commit comments