Skip to content

Commit 9f15bf0

Browse files
committed
Giving up on RL training; at some point it worked with RLlib 2.2
1 parent d56526c commit 9f15bf0

File tree

2 files changed

+54
-83
lines changed

2 files changed

+54
-83
lines changed

src/human_aware_rl/imitation/behavior_cloning_tf2.py

Lines changed: 34 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import tensorflow as tf
77
from ray.rllib.policy import Policy as RllibPolicy
88
from tensorflow import keras
9-
from tensorflow.compat.v1.keras.backend import get_session
109

1110
from human_aware_rl.data_dir import DATA_DIR
1211
from human_aware_rl.human.process_dataframes import get_human_human_trajectories
@@ -176,6 +175,9 @@ def build_bc_model(use_lstm=True, eager=False, **kwargs):
176175
def train_bc_model(model_dir, bc_params, verbose=False):
177176
inputs, seq_lens, targets = load_data(bc_params, verbose)
178177

178+
# Ensure targets are int32 for SparseCategoricalCrossentropy
179+
targets = tf.cast(targets, tf.int32)
180+
179181
training_params = bc_params["training_params"]
180182

181183
if training_params["use_class_weights"]:
@@ -220,7 +222,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
220222
),
221223
# Save checkpoints of the models at the end of every epoch (saving only the best one so far)
222224
keras.callbacks.ModelCheckpoint(
223-
filepath=os.path.join(model_dir, "checkpoints"),
225+
filepath=os.path.join(model_dir, "checkpoints", "model.keras"),
224226
monitor="loss",
225227
save_best_only=True,
226228
),
@@ -235,15 +237,25 @@ def train_bc_model(model_dir, bc_params, verbose=False):
235237

236238
# Inputs unique to lstm model
237239
if bc_params["use_lstm"]:
238-
inputs["seq_in"] = seq_lens
239-
inputs["hidden_in"] = np.zeros((N, bc_params["cell_size"]))
240-
inputs["memory_in"] = np.zeros((N, bc_params["cell_size"]))
240+
inputs["seq_in"] = tf.cast(seq_lens, tf.int32)
241+
inputs["hidden_in"] = tf.zeros((N, bc_params["cell_size"]), dtype=tf.float32)
242+
inputs["memory_in"] = tf.zeros((N, bc_params["cell_size"]), dtype=tf.float32)
241243

242244
# Batch size doesn't include time dimension (seq_len) so it should be smaller for rnn model
243245
batch_size = 1 if bc_params["use_lstm"] else training_params["batch_size"]
246+
model_inputs = (
247+
inputs
248+
if not bc_params["use_lstm"]
249+
else {
250+
"Overcooked_observation": inputs,
251+
"seq_in": seq_lens,
252+
"hidden_in": np.zeros((N, bc_params["cell_size"])),
253+
"memory_in": np.zeros((N, bc_params["cell_size"])),
254+
}
255+
)
244256
model.fit(
245-
inputs,
246-
targets,
257+
model_inputs,
258+
targets["logits"],
247259
callbacks=callbacks,
248260
batch_size=batch_size,
249261
epochs=training_params["epochs"],
@@ -260,18 +272,20 @@ def train_bc_model(model_dir, bc_params, verbose=False):
260272

261273
def save_bc_model(model_dir, model, bc_params, verbose=False):
262274
"""
263-
Saves the specified model under the directory model_dir. This creates three items
264-
265-
assets/ stores information essential to reconstructing the context and tf graph
266-
variables/ stores the model's trainable weights
267-
saved_model.pd the saved state of the model object
275+
Saves the specified model under the directory model_dir. This creates a .keras file
276+
containing the model's architecture, weights, and optimizer state.
268277
269278
Additionally, saves a pickled dictionary containing all the parameters used to construct this model
270279
at model_dir/metadata.pickle
271280
"""
272281
if verbose:
273282
print("Saving bc model at ", model_dir)
274-
model.save(model_dir, save_format="tf")
283+
284+
# Save model with .keras extension
285+
model_path = os.path.join(model_dir, "model.keras")
286+
model.save(model_path)
287+
288+
# Save metadata
275289
with open(os.path.join(model_dir, "metadata.pickle"), "wb") as f:
276290
pickle.dump(bc_params, f)
277291

@@ -283,7 +297,12 @@ def load_bc_model(model_dir, verbose=False):
283297
"""
284298
if verbose:
285299
print("Loading bc model from ", model_dir)
286-
model = keras.models.load_model(model_dir, custom_objects={"tf": tf})
300+
301+
# Load model from .keras file
302+
model_path = os.path.join(model_dir, "model.keras")
303+
model = keras.models.load_model(model_path, custom_objects={"tf": tf})
304+
305+
# Load metadata
287306
with open(os.path.join(model_dir, "metadata.pickle"), "rb") as f:
288307
bc_params = pickle.load(f)
289308
return model, bc_params
@@ -406,40 +425,6 @@ def _build_lstm_model(
406425
################
407426

408427

409-
class NullContextManager:
410-
"""
411-
No-op context manager that does nothing
412-
"""
413-
414-
def __init__(self):
415-
pass
416-
417-
def __enter__(self):
418-
pass
419-
420-
def __exit__(self, *args):
421-
pass
422-
423-
424-
class TfContextManager:
425-
"""
426-
Properly sets the execution graph and session of the keras backend given a "session" object as input
427-
428-
Used for isolating tf execution in graph mode. Do not use with eager models or with eager mode on
429-
"""
430-
431-
def __init__(self, session):
432-
self.session = session
433-
434-
def __enter__(self):
435-
self.ctx = self.session.graph.as_default()
436-
self.ctx.__enter__()
437-
set_session(self.session)
438-
439-
def __exit__(self, *args):
440-
self.ctx.__exit__(*args)
441-
442-
443428
class BehaviorCloningPolicy(RllibPolicy):
444429
def __init__(self, observation_space, action_space, config):
445430
"""
@@ -470,8 +455,6 @@ def __init__(self, observation_space, action_space, config):
470455
)
471456
model, bc_params = load_bc_model(config["model_dir"])
472457

473-
# Save the session that the model was loaded into so it is available at inference time if necessary
474-
self._sess = get_session()
475458
self._setup_shapes()
476459

477460
# Basic check to make sure model dimensions match
@@ -482,8 +465,6 @@ def __init__(self, observation_space, action_space, config):
482465
self.stochastic = config["stochastic"]
483466
self.use_lstm = bc_params["use_lstm"]
484467
self.cell_size = bc_params["cell_size"]
485-
self.eager = config["eager"] if "eager" in config else bc_params["eager"]
486-
self.context = self._create_execution_context()
487468

488469
def _setup_shapes(self):
489470
# This is here to make the class compatible with both tuples or gymnasium.Space objs for the spaces
@@ -542,11 +523,8 @@ def compute_actions(
542523
# Cast to np.array if list (no-op if already np.array)
543524
obs_batch = np.array(obs_batch)
544525

545-
# Run the model
546-
with self.context:
547-
action_logits, states = self._forward(obs_batch, state_batches)
526+
action_logits, states = self._forward(obs_batch, state_batches)
548527

549-
# Softmax in numpy to convert logits to probabilities
550528
action_probs = softmax(action_logits)
551529
if self.stochastic:
552530
# Sample according to action_probs for each row in the output
@@ -611,16 +589,6 @@ def _forward(self, obs_batch, state_batches):
611589
else:
612590
return self.model.predict(obs_batch, verbose=0), []
613591

614-
def _create_execution_context(self):
615-
"""
616-
Creates a private execution context for the model
617-
618-
Necessary if using with rllib in order to isolate this policy model from others
619-
"""
620-
if self.eager:
621-
return NullContextManager()
622-
return TfContextManager(self._sess)
623-
624592

625593
if __name__ == "__main__":
626594
params = get_bc_params()

src/human_aware_rl/imitation/behavior_cloning_tf2_test.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import argparse
2+
import gc
23
import os
34
import pickle
45
import shutil
56
import sys
7+
import time
68
import unittest
79
import warnings
810

@@ -32,7 +34,6 @@ def _clear_pickle():
3234

3335

3436
class TestBCTraining(unittest.TestCase):
35-
3637
"""
3738
Unittests for behavior cloning training and utilities
3839
@@ -48,9 +49,9 @@ def __init__(self, test_name):
4849
self.compute_pickle = False
4950
self.strict = False
5051
self.min_performance = 0
51-
assert not (
52-
self.compute_pickle and self.strict
53-
), "Cannot compute pickle and run strict reproducibility tests at same time"
52+
assert not (self.compute_pickle and self.strict), (
53+
"Cannot compute pickle and run strict reproducibility tests at same time"
54+
)
5455
if self.compute_pickle:
5556
_clear_pickle()
5657

@@ -94,7 +95,17 @@ def tearDown(self):
9495
with open(BC_EXPECTED_DATA_PATH, "wb") as f:
9596
pickle.dump(self.expected, f)
9697

97-
shutil.rmtree(self.model_dir)
98+
# Force garbage collection to close any open files
99+
gc.collect()
100+
101+
# Add a small delay to ensure files are released
102+
time.sleep(0.1)
103+
104+
try:
105+
# ignore_errors=True suppresses removal errors (e.g. files still locked); it does not force deletion, so leftovers may remain and the except below will not fire for them
106+
shutil.rmtree(self.model_dir, ignore_errors=True)
107+
except Exception as e:
108+
print(f"Warning: Could not fully remove directory {self.model_dir}: {e}")
98109

99110
def test_model_construction(self):
100111
model = build_bc_model(**self.bc_params)
@@ -115,9 +126,7 @@ def test_save_and_load(self):
115126
loaded_model, loaded_params = load_bc_model(self.model_dir)
116127
self.assertDictEqual(self.bc_params, loaded_params)
117128
self.assertTrue(
118-
np.allclose(
119-
model(self.dummy_input), loaded_model(self.dummy_input)
120-
)
129+
np.allclose(model(self.dummy_input), loaded_model(self.dummy_input))
121130
)
122131

123132
def test_training(self):
@@ -127,9 +136,7 @@ def test_training(self):
127136
self.expected["test_training"] = model(self.dummy_input)
128137
if self.strict:
129138
self.assertTrue(
130-
np.allclose(
131-
model(self.dummy_input), self.expected["test_training"]
132-
)
139+
np.allclose(model(self.dummy_input), self.expected["test_training"])
133140
)
134141

135142
def test_agent_evaluation(self):
@@ -143,9 +150,7 @@ def test_agent_evaluation(self):
143150
if self.compute_pickle:
144151
self.expected["test_agent_evaluation"] = results
145152
if self.strict:
146-
self.assertAlmostEqual(
147-
results, self.expected["test_agent_evaluation"]
148-
)
153+
self.assertAlmostEqual(results, self.expected["test_agent_evaluation"])
149154

150155

151156
class TestBCTrainingLSTM(TestBCTraining):
@@ -190,9 +195,7 @@ def test_lstm_evaluation(self):
190195
if self.compute_pickle:
191196
self.expected["test_lstm_evaluation"] = results
192197
if self.strict:
193-
self.assertAlmostEqual(
194-
results, self.expected["test_lstm_evaluation"]
195-
)
198+
self.assertAlmostEqual(results, self.expected["test_lstm_evaluation"])
196199

197200
def test_lstm_save_and_load(self):
198201
self.bc_params["use_lstm"] = True

0 commit comments

Comments
 (0)