6
6
import tensorflow as tf
7
7
from ray .rllib .policy import Policy as RllibPolicy
8
8
from tensorflow import keras
9
- from tensorflow .compat .v1 .keras .backend import get_session
10
9
11
10
from human_aware_rl .data_dir import DATA_DIR
12
11
from human_aware_rl .human .process_dataframes import get_human_human_trajectories
@@ -176,6 +175,9 @@ def build_bc_model(use_lstm=True, eager=False, **kwargs):
176
175
def train_bc_model (model_dir , bc_params , verbose = False ):
177
176
inputs , seq_lens , targets = load_data (bc_params , verbose )
178
177
178
+ # Ensure targets are int32 for SparseCategoricalCrossentropy
179
+ targets = tf .cast (targets , tf .int32 )
180
+
179
181
training_params = bc_params ["training_params" ]
180
182
181
183
if training_params ["use_class_weights" ]:
@@ -220,7 +222,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
220
222
),
221
223
# Save checkpoints of the models at the end of every epoch (saving only the best one so far)
222
224
keras .callbacks .ModelCheckpoint (
223
- filepath = os .path .join (model_dir , "checkpoints" ),
225
+ filepath = os .path .join (model_dir , "checkpoints" , "model.keras" ),
224
226
monitor = "loss" ,
225
227
save_best_only = True ,
226
228
),
@@ -235,15 +237,25 @@ def train_bc_model(model_dir, bc_params, verbose=False):
235
237
236
238
# Inputs unique to lstm model
237
239
if bc_params ["use_lstm" ]:
238
- inputs ["seq_in" ] = seq_lens
239
- inputs ["hidden_in" ] = np .zeros ((N , bc_params ["cell_size" ]))
240
- inputs ["memory_in" ] = np .zeros ((N , bc_params ["cell_size" ]))
240
+ inputs ["seq_in" ] = tf . cast ( seq_lens , tf . int32 )
241
+ inputs ["hidden_in" ] = tf .zeros ((N , bc_params ["cell_size" ]), dtype = tf . float32 )
242
+ inputs ["memory_in" ] = tf .zeros ((N , bc_params ["cell_size" ]), dtype = tf . float32 )
241
243
242
244
# Batch size doesn't include time dimension (seq_len) so it should be smaller for rnn model
243
245
batch_size = 1 if bc_params ["use_lstm" ] else training_params ["batch_size" ]
246
+ model_inputs = (
247
+ inputs
248
+ if not bc_params ["use_lstm" ]
249
+ else {
250
+ "Overcooked_observation" : inputs ,
251
+ "seq_in" : seq_lens ,
252
+ "hidden_in" : np .zeros ((N , bc_params ["cell_size" ])),
253
+ "memory_in" : np .zeros ((N , bc_params ["cell_size" ])),
254
+ }
255
+ )
244
256
model .fit (
245
- inputs ,
246
- targets ,
257
+ model_inputs ,
258
+ targets [ "logits" ] ,
247
259
callbacks = callbacks ,
248
260
batch_size = batch_size ,
249
261
epochs = training_params ["epochs" ],
@@ -260,18 +272,20 @@ def train_bc_model(model_dir, bc_params, verbose=False):
260
272
261
273
def save_bc_model(model_dir, model, bc_params, verbose=False):
    """
    Saves the specified model under the directory model_dir. This creates a .keras file
    containing the model's architecture, weights, and optimizer state.

    Additionally, saves a pickled dictionary containing all the parameters used to construct this model
    at model_dir/metadata.pickle

    Arguments:
        model_dir (str): directory in which to place model.keras and metadata.pickle
        model: a keras model (anything exposing .save(path)) to serialize
        bc_params (dict): the parameter dict used to build the model; pickled alongside it
        verbose (bool): if True, print the save location
    """
    if verbose:
        print("Saving bc model at ", model_dir)

    # Ensure the target directory exists; both model.save and open() below
    # fail if model_dir is missing
    os.makedirs(model_dir, exist_ok=True)

    # Save model with .keras extension (native Keras v3 single-file format)
    model_path = os.path.join(model_dir, "model.keras")
    model.save(model_path)

    # Save metadata needed to reconstruct the model's configuration
    with open(os.path.join(model_dir, "metadata.pickle"), "wb") as f:
        pickle.dump(bc_params, f)
277
291
@@ -283,7 +297,12 @@ def load_bc_model(model_dir, verbose=False):
283
297
"""
284
298
if verbose :
285
299
print ("Loading bc model from " , model_dir )
286
- model = keras .models .load_model (model_dir , custom_objects = {"tf" : tf })
300
+
301
+ # Load model from .keras file
302
+ model_path = os .path .join (model_dir , "model.keras" )
303
+ model = keras .models .load_model (model_path , custom_objects = {"tf" : tf })
304
+
305
+ # Load metadata
287
306
with open (os .path .join (model_dir , "metadata.pickle" ), "rb" ) as f :
288
307
bc_params = pickle .load (f )
289
308
return model , bc_params
@@ -406,40 +425,6 @@ def _build_lstm_model(
406
425
################
407
426
408
427
409
- class NullContextManager :
410
- """
411
- No-op context manager that does nothing
412
- """
413
-
414
- def __init__ (self ):
415
- pass
416
-
417
- def __enter__ (self ):
418
- pass
419
-
420
- def __exit__ (self , * args ):
421
- pass
422
-
423
-
424
- class TfContextManager :
425
- """
426
- Properly sets the execution graph and session of the keras backend given a "session" object as input
427
-
428
- Used for isolating tf execution in graph mode. Do not use with eager models or with eager mode on
429
- """
430
-
431
- def __init__ (self , session ):
432
- self .session = session
433
-
434
- def __enter__ (self ):
435
- self .ctx = self .session .graph .as_default ()
436
- self .ctx .__enter__ ()
437
- set_session (self .session )
438
-
439
- def __exit__ (self , * args ):
440
- self .ctx .__exit__ (* args )
441
-
442
-
443
428
class BehaviorCloningPolicy (RllibPolicy ):
444
429
def __init__ (self , observation_space , action_space , config ):
445
430
"""
@@ -470,8 +455,6 @@ def __init__(self, observation_space, action_space, config):
470
455
)
471
456
model , bc_params = load_bc_model (config ["model_dir" ])
472
457
473
- # Save the session that the model was loaded into so it is available at inference time if necessary
474
- self ._sess = get_session ()
475
458
self ._setup_shapes ()
476
459
477
460
# Basic check to make sure model dimensions match
@@ -482,8 +465,6 @@ def __init__(self, observation_space, action_space, config):
482
465
self .stochastic = config ["stochastic" ]
483
466
self .use_lstm = bc_params ["use_lstm" ]
484
467
self .cell_size = bc_params ["cell_size" ]
485
- self .eager = config ["eager" ] if "eager" in config else bc_params ["eager" ]
486
- self .context = self ._create_execution_context ()
487
468
488
469
def _setup_shapes (self ):
489
470
# This is here to make the class compatible with both tuples or gymnasium.Space objs for the spaces
@@ -542,11 +523,8 @@ def compute_actions(
542
523
# Cast to np.array if list (no-op if already np.array)
543
524
obs_batch = np .array (obs_batch )
544
525
545
- # Run the model
546
- with self .context :
547
- action_logits , states = self ._forward (obs_batch , state_batches )
526
+ action_logits , states = self ._forward (obs_batch , state_batches )
548
527
549
- # Softmax in numpy to convert logits to probabilities
550
528
action_probs = softmax (action_logits )
551
529
if self .stochastic :
552
530
# Sample according to action_probs for each row in the output
@@ -611,16 +589,6 @@ def _forward(self, obs_batch, state_batches):
611
589
else :
612
590
return self .model .predict (obs_batch , verbose = 0 ), []
613
591
614
- def _create_execution_context (self ):
615
- """
616
- Creates a private execution context for the model
617
-
618
- Necessary if using with rllib in order to isolate this policy model from others
619
- """
620
- if self .eager :
621
- return NullContextManager ()
622
- return TfContextManager (self ._sess )
623
-
624
592
625
593
if __name__ == "__main__" :
626
594
params = get_bc_params ()
0 commit comments