Skip to content

Commit 7f98df9

Browse files
authored
Release v2.1.0 (#395)
* Release v2.1.0 * Fix mypy
1 parent 660f2d3 commit 7f98df9

10 files changed

+20
-32
lines changed

CHANGELOG.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
## Release 2.1.0a0 (WIP)
1+
## Release 2.1.0 (2023-08-17)
22

33
### Breaking Changes
44
- Dropped python 3.7 support
55
- SB3 now requires PyTorch 1.13+
6+
- Upgraded to SB3 >= 2.1.0
7+
- Upgraded to Huggingface-SB3 >= 2.3
8+
- Upgraded to Optuna >= 3.0
9+
- Upgraded to cloudpickle >= 2.2.1
610

711
### New Features
812
- Added python 3.11 support

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py
1+
LINT_PATHS = *.py tests/ scripts/ rl_zoo3/ hyperparams/python/*.py docs/conf.py
22

33
# Run pytest and coverage report
44
pytest:

docs/conf.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
#
1414
import os
1515
import sys
16-
from typing import Dict, List
17-
from unittest.mock import MagicMock
16+
from typing import Dict
1817

1918
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
2019
# PyEnchant.
@@ -37,21 +36,6 @@
3736
sys.path.insert(0, os.path.abspath(".."))
3837

3938

40-
class Mock(MagicMock):
41-
__subclasses__ = [] # type: ignore
42-
43-
@classmethod
44-
def __getattr__(cls, name):
45-
return MagicMock()
46-
47-
48-
# Mock modules that requires C modules
49-
# Note: because of that we cannot test examples using CI
50-
# 'torch', 'torch.nn', 'torch.nn.functional',
51-
# DO not mock modules for now, we will need to do that for read the docs later
52-
MOCK_MODULES: List[str] = []
53-
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
54-
5539
# Read version from file
5640
version_file = os.path.join(os.path.dirname(__file__), "../rl_zoo3", "version.txt")
5741
with open(version_file) as file_handler:

requirements.txt

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
gym==0.26.2
2-
stable-baselines3[extra_no_roms,tests,docs]>=2.0.0
3-
sb3-contrib>=2.0.0
2+
stable-baselines3[extra_no_roms,tests,docs]>=2.1.0
3+
sb3-contrib>=2.1.0
44
box2d-py==2.3.8
55
pybullet
66
# minigrid
77
# scikit-optimize
88
optuna~=3.0
99
pytablewriter~=0.64
1010
pyyaml>=5.1
11-
cloudpickle>=1.5.0
11+
cloudpickle>=2.2.1
1212
plotly
1313
# need to upgrade to gymnasium:
1414
# panda-gym~=3.0.1
1515
rliable>=1.0.5
1616
wandb
17-
huggingface_sb3>=2.2.5
17+
huggingface_sb3>=2.3
1818
seaborn
1919
tqdm
2020
rich

rl_zoo3/gym_patches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,4 @@ def step(self, action):
8989
# Patch Gymnasium TimeLimit
9090
gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc]
9191
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
92-
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc]
92+
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc,attr-defined]

rl_zoo3/plots/plot_from_file.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def plot_from_file(): # noqa: C901
9898
for new_key in results_2[key].keys():
9999
results[key][new_key] = results_2[key][new_key]
100100

101-
keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys]
101+
keys = [key for key in results[next(iter(results.keys()))].keys() if key not in args.skip_keys]
102102
print(f"keys: {keys}")
103103
if len(args.keep_keys) > 0:
104104
keys = [key for key in keys if key in args.keep_keys]

rl_zoo3/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def get_class_name(wrapper_name):
9595
"You should check the indentation."
9696
)
9797
wrapper_dict = wrapper_name
98-
wrapper_name = list(wrapper_dict.keys())[0]
98+
wrapper_name = next(iter(wrapper_dict.keys()))
9999
kwargs = wrapper_dict[wrapper_name]
100100
else:
101101
kwargs = {}
@@ -178,7 +178,7 @@ def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
178178
"You should check the indentation."
179179
)
180180
callback_dict = callback_name
181-
callback_name = list(callback_dict.keys())[0]
181+
callback_name = next(iter(callback_dict.keys()))
182182
kwargs = callback_dict[callback_name]
183183
else:
184184
kwargs = {}

rl_zoo3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.1.0a0
1+
2.1.0

setup.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
},
2828
entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]},
2929
install_requires=[
30-
"sb3_contrib>=2.0.0",
30+
"sb3_contrib>=2.1.0",
3131
"gym==0.26.2", # for patches to make gym backward compat
32-
"huggingface_sb3>=2.2.5",
32+
"huggingface_sb3>=2.3",
3333
"tqdm",
3434
"rich",
35-
"optuna",
35+
"optuna>=3.0",
3636
"pyyaml>=5.1",
3737
"pytablewriter~=0.64",
3838
# TODO: add test dependencies

tests/test_hyperparams_opt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_optimize_log_path(tmp_path):
108108
assert os.path.isdir(os.path.join(optimization_log_path, "trial_1"))
109109
assert os.path.isfile(os.path.join(optimization_log_path, "trial_1", "evaluations.npz"))
110110

111-
study_path = list(glob.glob(str(tmp_path / algo / "report_*.pkl")))[0]
111+
study_path = next(iter(glob.glob(str(tmp_path / algo / "report_*.pkl"))))
112112
print(study_path)
113113
# Test reading best trials
114114
args = [

0 commit comments

Comments
 (0)