Commit d56526c

pre tf upgrade
committed · 1 parent 049b2df · commit d56526c

File tree: 1 file changed, +34 −75 lines

src/human_aware_rl/imitation/behavior_cloning_tf2.py

Lines changed: 34 additions & 75 deletions
@@ -6,19 +6,11 @@
 import tensorflow as tf
 from ray.rllib.policy import Policy as RllibPolicy
 from tensorflow import keras
-from tensorflow.compat.v1.keras.backend import get_session, set_session
+from tensorflow.compat.v1.keras.backend import get_session

 from human_aware_rl.data_dir import DATA_DIR
-from human_aware_rl.human.process_dataframes import (
-    get_human_human_trajectories,
-    get_trajs_from_data,
-)
-from human_aware_rl.rllib.rllib import (
-    RlLibAgent,
-    evaluate,
-    get_base_ae,
-    softmax,
-)
+from human_aware_rl.human.process_dataframes import get_human_human_trajectories
+from human_aware_rl.rllib.rllib import evaluate, get_base_ae, softmax
 from human_aware_rl.static import CLEAN_2019_HUMAN_DATA_TRAIN
 from human_aware_rl.utils import get_flattened_keys, recursive_dict_update
 from overcooked_ai_py.mdp.actions import Action
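Note on the surviving import: set_session is dropped while get_session is kept, matching the comment later in this file about saving the session a model was loaded into so it is available at inference time. A minimal standalone sketch of that pattern, not part of the diff, assuming graph-mode (non-eager) execution and a hypothetical toy model:

# Sketch only: capture the v1-compat session a Keras model lives in,
# so it can be re-entered at inference time.
import tensorflow as tf
from tensorflow import keras
from tensorflow.compat.v1.keras.backend import get_session

tf.compat.v1.disable_eager_execution()  # assumption: graph-mode execution
toy = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(2)])
sess = get_session()  # the session holding the model's variables
with sess.as_default(), sess.graph.as_default():
    pass  # run toy.predict(...) here if necessary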
@@ -119,9 +111,7 @@ def get_bc_params(**args_to_override):

     all_keys = get_flattened_keys(params)
     if len(all_keys) != len(set(all_keys)):
-        raise ValueError(
-            "Every key at every level must be distict for BC params!"
-        )
+        raise ValueError("Every key at every level must be distict for BC params!")

     return params

@@ -198,9 +188,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
         class_weights = None

     # Retrieve un-initialized keras model
-    model = build_bc_model(
-        **bc_params, max_seq_len=np.max(seq_lens), verbose=verbose
-    )
+    model = build_bc_model(**bc_params, max_seq_len=np.max(seq_lens), verbose=verbose)

     # Initialize the model
     # Note: have to use lists for multi-output model support and not dicts because of tensorlfow 2.0.0 bug
@@ -225,9 +213,7 @@ def train_bc_model(model_dir, bc_params, verbose=False):
         # Early terminate training if loss doesn't improve for "patience" epochs
         keras.callbacks.EarlyStopping(monitor="loss", patience=20),
         # Reduce lr by "factor" after "patience" epochs of no improvement in loss
-        keras.callbacks.ReduceLROnPlateau(
-            monitor="loss", patience=3, factor=0.1
-        ),
+        keras.callbacks.ReduceLROnPlateau(monitor="loss", patience=3, factor=0.1),
         # Log all metrics model was compiled with to tensorboard every epoch
         keras.callbacks.TensorBoard(
             log_dir=os.path.join(model_dir, "logs"), write_graph=False
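The two reflowed callbacks behave exactly as before. For context, a self-contained sketch (not from this file; toy model and random data) of attaching them to an ordinary model.fit call:

# Sketch only: the reformatted callbacks attached to a toy training run.
import numpy as np
from tensorflow import keras

toy_model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
toy_model.compile(optimizer="adam", loss="mse")
toy_model.fit(
    np.random.rand(32, 4),
    np.random.rand(32, 1),
    epochs=5,
    verbose=0,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor="loss", patience=20),
        keras.callbacks.ReduceLROnPlateau(monitor="loss", patience=3, factor=0.1),
    ],
)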
@@ -329,12 +315,8 @@ def featurize_fn(state):
         return base_env.featurize_state_mdp(state)

     # Wrap Keras models in rllib policies
-    agent_0_policy = BehaviorCloningPolicy.from_model(
-        model, bc_params, stochastic=True
-    )
-    agent_1_policy = BehaviorCloningPolicy.from_model(
-        model, bc_params, stochastic=True
-    )
+    agent_0_policy = BehaviorCloningPolicy.from_model(model, bc_params, stochastic=True)
+    agent_1_policy = BehaviorCloningPolicy.from_model(model, bc_params, stochastic=True)

     # Compute the results of the rollout(s)
     results = evaluate(
@@ -355,21 +337,17 @@ def featurize_fn(state):

 def _build_model(observation_shape, action_shape, mlp_params, **kwargs):
     ## Inputs
-    inputs = keras.Input(
-        shape=observation_shape, name="Overcooked_observation"
-    )
+    inputs = keras.Input(shape=observation_shape, name="Overcooked_observation")
     x = inputs

     ## Build fully connected layers
-    assert (
-        len(mlp_params["net_arch"]) == mlp_params["num_layers"]
-    ), "Invalid Fully Connected params"
+    assert len(mlp_params["net_arch"]) == mlp_params["num_layers"], (
+        "Invalid Fully Connected params"
+    )

     for i in range(mlp_params["num_layers"]):
         units = mlp_params["net_arch"][i]
-        x = keras.layers.Dense(
-            units, activation="relu", name="fc_{0}".format(i)
-        )(x)
+        x = keras.layers.Dense(units, activation="relu", name="fc_{0}".format(i))(x)

     ## output layer
     logits = keras.layers.Dense(action_shape[0], name="logits")(x)
@@ -378,12 +356,7 @@ def _build_model(observation_shape, action_shape, mlp_params, **kwargs):


 def _build_lstm_model(
-    observation_shape,
-    action_shape,
-    mlp_params,
-    cell_size,
-    max_seq_len=20,
-    **kwargs
+    observation_shape, action_shape, mlp_params, cell_size, max_seq_len=20, **kwargs
 ):
     ## Inputs
     obs_in = keras.Input(
@@ -395,21 +368,19 @@ def _build_lstm_model(
     x = obs_in

     ## Build fully connected layers
-    assert (
-        len(mlp_params["net_arch"]) == mlp_params["num_layers"]
-    ), "Invalid Fully Connected params"
+    assert len(mlp_params["net_arch"]) == mlp_params["num_layers"], (
+        "Invalid Fully Connected params"
+    )

     for i in range(mlp_params["num_layers"]):
         units = mlp_params["net_arch"][i]
         x = keras.layers.TimeDistributed(
-            keras.layers.Dense(
-                units, activation="relu", name="fc_{0}".format(i)
-            )
+            keras.layers.Dense(units, activation="relu", name="fc_{0}".format(i))
         )(x)

-    mask = keras.layers.Lambda(
-        lambda x: tf.sequence_mask(x, maxlen=max_seq_len)
-    )(seq_in)
+    mask = keras.layers.Lambda(lambda x: tf.sequence_mask(x, maxlen=max_seq_len))(
+        seq_in
+    )

     ## LSTM layer
     lstm_out, h_out, c_out = keras.layers.LSTM(
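The reflowed Lambda layer wraps tf.sequence_mask, which turns per-example sequence lengths into a boolean mask fed to the LSTM. A standalone illustration, not part of the diff:

# Sketch only: what the tf.sequence_mask call inside the Lambda produces.
import tensorflow as tf

mask = tf.sequence_mask([1, 3], maxlen=4)
print(mask.numpy())
# [[ True False False False]
#  [ True  True  True False]]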
@@ -488,17 +459,15 @@ def __init__(self, observation_space, action_space, config):
         )

         if "bc_model" in config and config["bc_model"]:
-            assert (
-                "bc_params" in config
-            ), "must specify params in addition to model"
-            assert issubclass(
-                type(config["bc_model"]), keras.Model
-            ), "model must be of type keras.Model"
+            assert "bc_params" in config, "must specify params in addition to model"
+            assert issubclass(type(config["bc_model"]), keras.Model), (
+                "model must be of type keras.Model"
+            )
             model, bc_params = config["bc_model"], config["bc_params"]
         else:
-            assert (
-                "model_dir" in config
-            ), "must specify model directory if model not specified"
+            assert "model_dir" in config, (
+                "must specify model directory if model not specified"
+            )
             model, bc_params = load_bc_model(config["model_dir"])

         # Save the session that the model was loaded into so it is available at inference time if necessary
@@ -513,9 +482,7 @@ def __init__(self, observation_space, action_space, config):
         self.stochastic = config["stochastic"]
         self.use_lstm = bc_params["use_lstm"]
         self.cell_size = bc_params["cell_size"]
-        self.eager = (
-            config["eager"] if "eager" in config else bc_params["eager"]
-        )
+        self.eager = config["eager"] if "eager" in config else bc_params["eager"]
         self.context = self._create_execution_context()

     def _setup_shapes(self):
@@ -540,9 +507,7 @@ def from_model_dir(cls, model_dir, stochastic=True):
             "bc_params": bc_params,
             "stochastic": stochastic,
         }
-        return cls(
-            bc_params["observation_shape"], bc_params["action_shape"], config
-        )
+        return cls(bc_params["observation_shape"], bc_params["action_shape"], config)

     @classmethod
     def from_model(cls, model, bc_params, stochastic=True):
@@ -551,9 +516,7 @@ def from_model(cls, model, bc_params, stochastic=True):
             "bc_params": bc_params,
             "stochastic": stochastic,
         }
-        return cls(
-            bc_params["observation_shape"], bc_params["action_shape"], config
-        )
+        return cls(bc_params["observation_shape"], bc_params["action_shape"], config)

     def compute_actions(
         self,
@@ -563,7 +526,7 @@ def compute_actions(
         prev_reward_batch=None,
         info_batch=None,
         episodes=None,
-        **kwargs
+        **kwargs,
     ):
         """
         Computes sampled actions for each of the corresponding OvercookedEnv states in obs_batch
@@ -641,9 +604,7 @@ def _forward(self, obs_batch, state_batches):
         if self.use_lstm:
             obs_batch = np.expand_dims(obs_batch, 1)
             seq_lens = np.ones(len(obs_batch))
-            model_out = self.model.predict(
-                [obs_batch, seq_lens] + state_batches
-            )
+            model_out = self.model.predict([obs_batch, seq_lens] + state_batches)
             logits, states = model_out[0], model_out[1:]
             logits = logits.reshape((logits.shape[0], -1))
             return logits, states
@@ -663,8 +624,6 @@ def _create_execution_context(self):

 if __name__ == "__main__":
     params = get_bc_params()
-    model = train_bc_model(
-        os.path.join(BC_SAVE_DIR, "default"), params, verbose=True
-    )
+    model = train_bc_model(os.path.join(BC_SAVE_DIR, "default"), params, verbose=True)
     # Evaluate our model's performance in a rollout
     evaluate_bc_model(model, params)
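The entry point above trains a BC model into BC_SAVE_DIR/"default" and evaluates it in a rollout. A sketch, not part of the diff, of loading that trained model back as a policy via the from_model_dir classmethod reformatted earlier in this file:

# Sketch only: reload the trained model as a policy for later use.
import os

from human_aware_rl.imitation.behavior_cloning_tf2 import (
    BC_SAVE_DIR,
    BehaviorCloningPolicy,
)

policy = BehaviorCloningPolicy.from_model_dir(
    os.path.join(BC_SAVE_DIR, "default"), stochastic=True
)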

0 commit comments
