Skip to content

Commit c2f00ea

Browse files
ernestum and araffin authored
Support environments with slash ('/') in their name (#257)
* Fix bug in get_last_run_id which would ignore runs when the env name contains a slash. * Remove any slashes from the HF repo name. * Fix formatting in utils.py * Update CHANGELOG.md * Construct correct repo id when loading from huggingface hub. * Use environment name instead of environment id to ensure slashes are replaced by dashes. * Fix get_trained_models util by loading the env_id from metadata instead of parsing it from the model path. * Fix get_hf_trained_models util by loading the env_id and algo from the model card instead of parsing it from the repo id. * Remove unused lines from migrate_to_hub.py * Fix formatting in utils.py * Change help text of --env parameter back to `environment ID` * Add comments to explain naming scheme in. * Make `get_trained_models()` use the `args.yml` file instead of the monitor log file to determine the used environment. * Introduce usage of EnvironmentName to record_training.py and record_video.py * Restrict huggingface_sb3 version to avoid breaking changes. * Fix formatting in utils.py * Add missing seaborn requirement. * Pass gym_id instead of env_name to is_atari and crete_test_env. * Use EnvironmentName in the ExperimentManger to properly construct folder names. * Fix formatting in exp_manager.py * Disable slow check and fix recurrent ppo alias Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
1 parent f1064a7 commit c2f00ea

11 files changed

+143
-110
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
- Fix `Reacher-v3` name in PPO hyperparameter file
2020
- Pinned ale-py==0.7.4 until new SB3 version is released
2121
- Fix enjoy / record videos with LSTM policy
22+
- Fix bug with environments that have a slash in their name (@ernestum)
2223
- Changed `optimize_memory_usage` to `False` for DQN/QR-DQN on Atari games,
2324
if you want to save RAM, you need to deactivate `handle_timeout_termination`
2425
in the `replay_buffer_kwargs`

enjoy.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import torch as th
88
import yaml
9+
from huggingface_sb3 import EnvironmentName
910
from stable_baselines3.common.utils import set_random_seed
1011

1112
import utils.import_envs # noqa: F401 pylint: disable=unused-import
@@ -17,7 +18,7 @@
1718

1819
def main(): # noqa: C901
1920
parser = argparse.ArgumentParser()
20-
parser.add_argument("--env", help="environment ID", type=str, default="CartPole-v1")
21+
parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
2122
parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
2223
parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
2324
parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
@@ -67,7 +68,7 @@ def main(): # noqa: C901
6768
for env_module in args.gym_packages:
6869
importlib.import_module(env_module)
6970

70-
env_id = args.env
71+
env_name: EnvironmentName = args.env
7172
algo = args.algo
7273
folder = args.folder
7374

@@ -76,7 +77,7 @@ def main(): # noqa: C901
7677
args.exp_id,
7778
folder,
7879
algo,
79-
env_id,
80+
env_name,
8081
args.load_best,
8182
args.load_checkpoint,
8283
args.load_last_checkpoint,
@@ -91,7 +92,7 @@ def main(): # noqa: C901
9192
# Auto-download
9293
download_from_hub(
9394
algo=algo,
94-
env_id=env_id,
95+
env_name=env_name,
9596
exp_id=args.exp_id,
9697
folder=folder,
9798
organization="sb3",
@@ -103,7 +104,7 @@ def main(): # noqa: C901
103104
args.exp_id,
104105
folder,
105106
algo,
106-
env_id,
107+
env_name,
107108
args.load_best,
108109
args.load_checkpoint,
109110
args.load_last_checkpoint,
@@ -124,14 +125,14 @@ def main(): # noqa: C901
124125
print(f"Setting torch.num_threads to {args.num_threads}")
125126
th.set_num_threads(args.num_threads)
126127

127-
is_atari = ExperimentManager.is_atari(env_id)
128+
is_atari = ExperimentManager.is_atari(env_name.gym_id)
128129

129-
stats_path = os.path.join(log_path, env_id)
130+
stats_path = os.path.join(log_path, env_name)
130131
hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)
131132

132133
# load env_kwargs if existing
133134
env_kwargs = {}
134-
args_path = os.path.join(log_path, env_id, "args.yml")
135+
args_path = os.path.join(log_path, env_name, "args.yml")
135136
if os.path.isfile(args_path):
136137
with open(args_path) as f:
137138
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) # pytype: disable=module-attr
@@ -144,7 +145,7 @@ def main(): # noqa: C901
144145
log_dir = args.reward_log if args.reward_log != "" else None
145146

146147
env = create_test_env(
147-
env_id,
148+
env_name.gym_id,
148149
n_envs=args.n_envs,
149150
stats_path=stats_path,
150151
seed=args.seed,

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
1515
rliable>=1.0.5
1616
wandb
1717
ale-py==0.7.4 # tmp fix: until new SB3 version is released
18-
# TODO: replace with release
19-
git+https://github.com/huggingface/huggingface_sb3
18+
huggingface_sb3>=2.2.1, <3.*
19+
seaborn

scripts/migrate_to_hub.py

-4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,4 @@
1616
if algo == "her":
1717
continue
1818

19-
# if model doesn't exist already
20-
repo_name = f"{algo}-{env_id}"
21-
repo_id = f"{orga}/{repo_name}"
22-
2319
return_code = subprocess.call(["python", "-m", "utils.push_to_hub"] + args)

utils/exp_manager.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import optuna
1313
import torch as th
1414
import yaml
15+
from huggingface_sb3 import EnvironmentName
1516
from optuna.integration.skopt import SkoptSampler
1617
from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner
1718
from optuna.samplers import BaseSampler, RandomSampler, TPESampler
@@ -95,7 +96,7 @@ def __init__(
9596
):
9697
super().__init__()
9798
self.algo = algo
98-
self.env_id = env_id
99+
self.env_name = EnvironmentName(env_id)
99100
# Custom params
100101
self.custom_hyperparams = hyperparams
101102
self.env_kwargs = {} if env_kwargs is None else env_kwargs
@@ -144,22 +145,22 @@ def __init__(
144145
self.pruner = pruner
145146
self.n_startup_trials = n_startup_trials
146147
self.n_evaluations = n_evaluations
147-
self.deterministic_eval = not self.is_atari(self.env_id)
148+
self.deterministic_eval = not self.is_atari(env_id)
148149
self.device = device
149150

150151
# Logging
151152
self.log_folder = log_folder
152-
self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, env_id)
153+
self.tensorboard_log = None if tensorboard_log == "" else os.path.join(tensorboard_log, self.env_name)
153154
self.verbose = verbose
154155
self.args = args
155156
self.log_interval = log_interval
156157
self.save_replay_buffer = save_replay_buffer
157158

158159
self.log_path = f"{log_folder}/{self.algo}/"
159160
self.save_path = os.path.join(
160-
self.log_path, f"{self.env_id}_{get_latest_run_id(self.log_path, self.env_id) + 1}{uuid_str}"
161+
self.log_path, f"{self.env_name}_{get_latest_run_id(self.log_path, self.env_name) + 1}{uuid_str}"
161162
)
162-
self.params_path = f"{self.save_path}/{self.env_id}"
163+
self.params_path = f"{self.save_path}/{self.env_name}"
163164

164165
def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
165166
"""
@@ -235,7 +236,7 @@ def save_trained_model(self, model: BaseAlgorithm) -> None:
235236
:param model:
236237
"""
237238
print(f"Saving to {self.save_path}")
238-
model.save(f"{self.save_path}/{self.env_id}")
239+
model.save(f"{self.save_path}/{self.env_name}")
239240

240241
if hasattr(model, "save_replay_buffer") and self.save_replay_buffer:
241242
print("Saving replay buffer")
@@ -267,12 +268,12 @@ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
267268
# Load hyperparameters from yaml file
268269
with open(f"hyperparams/{self.algo}.yml") as f:
269270
hyperparams_dict = yaml.safe_load(f)
270-
if self.env_id in list(hyperparams_dict.keys()):
271-
hyperparams = hyperparams_dict[self.env_id]
271+
if self.env_name.gym_id in list(hyperparams_dict.keys()):
272+
hyperparams = hyperparams_dict[self.env_name.gym_id]
272273
elif self._is_atari:
273274
hyperparams = hyperparams_dict["atari"]
274275
else:
275-
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_id}")
276+
raise ValueError(f"Hyperparameters not found for {self.algo}-{self.env_name.gym_id}")
276277

277278
if self.custom_hyperparams is not None:
278279
# Overwrite hyperparams if needed
@@ -486,7 +487,7 @@ def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
486487
:return:
487488
"""
488489
# Pretrained model, load normalization
489-
path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_id)
490+
path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_name)
490491
path_ = os.path.join(path_, "vecnormalize.pkl")
491492

492493
if os.path.exists(path_):
@@ -530,13 +531,17 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
530531

531532
monitor_kwargs = {}
532533
# Special case for GoalEnvs: log success rate too
533-
if "Neck" in self.env_id or self.is_robotics_env(self.env_id) or "parking-v0" in self.env_id:
534+
if (
535+
"Neck" in self.env_name.gym_id
536+
or self.is_robotics_env(self.env_name.gym_id)
537+
or "parking-v0" in self.env_name.gym_id
538+
):
534539
monitor_kwargs = dict(info_keywords=("is_success",))
535540

536541
# On most env, SubprocVecEnv does not help and is quite memory hungry
537542
# therefore we use DummyVecEnv by default
538543
env = make_vec_env(
539-
env_id=self.env_id,
544+
env_id=self.env_name.gym_id,
540545
n_envs=n_envs,
541546
seed=self.seed,
542547
env_kwargs=self.env_kwargs,
@@ -797,7 +802,7 @@ def hyperparameters_optimization(self) -> None:
797802
print(f" {key}: {value}")
798803

799804
report_name = (
800-
f"report_{self.env_id}_{self.n_trials}-trials-{self.n_timesteps}"
805+
f"report_{self.env_name}_{self.n_trials}-trials-{self.n_timesteps}"
801806
f"-{self.sampler}-{self.pruner}_{int(time.time())}"
802807
)
803808

utils/load_from_hub.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from pathlib import Path
66
from typing import Optional
77

8-
from huggingface_sb3 import load_from_hub
8+
from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub
99
from requests.exceptions import HTTPError
1010

1111
from utils import ALGOS, get_latest_run_id
1212

1313

1414
def download_from_hub(
1515
algo: str,
16-
env_id: str,
16+
env_name: EnvironmentName,
1717
exp_id: int,
1818
folder: str,
1919
organization: str,
@@ -27,7 +27,7 @@ def download_from_hub(
2727
where repo_name = {algo}-{env_id}
2828
2929
:param algo: Algorithm
30-
:param env_id: Environment id
30+
:param env_name: Environment name
3131
:param exp_id: Experiment id
3232
:param folder: Log folder
3333
:param organization: Huggingface organization
@@ -36,15 +36,16 @@ def download_from_hub(
3636
if it already exists.
3737
"""
3838

39+
model_name = ModelName(algo, env_name)
40+
3941
if repo_name is None:
40-
repo_name = f"{algo}-{env_id}"
42+
repo_name = model_name # Note: model name is {algo}-{env_name}
4143

42-
repo_id = f"{organization}/{repo_name}"
44+
# Note: repo id is {organization}/{repo_name}
45+
repo_id = ModelRepoId(organization, repo_name)
4346
print(f"Downloading from https://huggingface.co/{repo_id}")
4447

45-
model_name = f"{algo}-{env_id}"
46-
47-
checkpoint = load_from_hub(repo_id, f"{model_name}.zip")
48+
checkpoint = load_from_hub(repo_id, model_name.filename)
4849
config_path = load_from_hub(repo_id, "config.yml")
4950

5051
# If VecNormalize, download
@@ -59,10 +60,10 @@ def download_from_hub(
5960
train_eval_metrics = load_from_hub(repo_id, "train_eval_metrics.zip")
6061

6162
if exp_id == 0:
62-
exp_id = get_latest_run_id(os.path.join(folder, algo), env_id) + 1
63+
exp_id = get_latest_run_id(os.path.join(folder, algo), env_name) + 1
6364
# Sanity checks
6465
if exp_id > 0:
65-
log_path = os.path.join(folder, algo, f"{env_id}_{exp_id}")
66+
log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}")
6667
else:
6768
log_path = os.path.join(folder, algo)
6869

@@ -82,11 +83,11 @@ def download_from_hub(
8283
print(f"Saving to {log_path}")
8384
# Create folder structure
8485
os.makedirs(log_path, exist_ok=True)
85-
config_folder = os.path.join(log_path, env_id)
86+
config_folder = os.path.join(log_path, env_name)
8687
os.makedirs(config_folder, exist_ok=True)
8788

8889
# Copy config files and saved stats
89-
shutil.copy(checkpoint, os.path.join(log_path, f"{env_id}.zip"))
90+
shutil.copy(checkpoint, os.path.join(log_path, f"{env_name}.zip"))
9091
shutil.copy(saved_args, os.path.join(config_folder, "args.yml"))
9192
shutil.copy(config_path, os.path.join(config_folder, "config.yml"))
9293
shutil.copy(env_kwargs, os.path.join(config_folder, "env_kwargs.yml"))
@@ -100,7 +101,7 @@ def download_from_hub(
100101

101102
if __name__ == "__main__":
102103
parser = argparse.ArgumentParser()
103-
parser.add_argument("--env", help="environment ID", type=str, required=True)
104+
parser.add_argument("--env", help="environment ID", type=EnvironmentName, required=True)
104105
parser.add_argument("-f", "--folder", help="Log folder", type=str, required=True)
105106
parser.add_argument("-orga", "--organization", help="Huggingface hub organization", default="sb3")
106107
parser.add_argument("-name", "--repo-name", help="Huggingface hub repository name, by default 'algo-env_id'", type=str)
@@ -114,7 +115,7 @@ def download_from_hub(
114115

115116
download_from_hub(
116117
algo=args.algo,
117-
env_id=args.env,
118+
env_name=args.env,
118119
exp_id=args.exp_id,
119120
folder=args.folder,
120121
organization=args.organization,

0 commit comments

Comments (0)