Skip to content

Commit dc25cc6

Browse files
authored
Update SB3 and remove gSDE resampling (#251)
1 parent 25b4326 commit dc25cc6

File tree

6 files changed

+29
-12
lines changed

6 files changed

+29
-12
lines changed

Diff for: Makefile

+1-1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ lint:
1414
# see https://www.flake8rules.com/
1515
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
1616
# exit-zero treats all errors as warnings.
17-
ruff check ${LINT_PATHS} --exit-zero
17+
ruff check ${LINT_PATHS} --exit-zero --output-format=concise
1818

1919
format:
2020
# Sort imports

Diff for: docs/misc/changelog.rst

+26
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,32 @@
33
Changelog
44
==========
55

6+
7+
Release 2.4.0a4 (WIP)
8+
--------------------------
9+
10+
Breaking Changes:
11+
^^^^^^^^^^^^^^^^^
12+
- Upgraded to Stable-Baselines3 >= 2.4.0
13+
14+
New Features:
15+
^^^^^^^^^^^^^
16+
17+
Bug Fixes:
18+
^^^^^^^^^^
19+
20+
Deprecations:
21+
^^^^^^^^^^^^^
22+
23+
Others:
24+
^^^^^^^
25+
- Updated PyTorch version on CI to 2.3.1
26+
- Removed unnecessary SDE noise resampling in RecurrentPPO/TRPO update
27+
28+
Documentation:
29+
^^^^^^^^^^^^^^
30+
31+
632
Release 2.3.0 (2024-03-31)
733
--------------------------
834

Diff for: sb3_contrib/ppo_recurrent/ppo_recurrent.py

-4
Original file line number | Diff line number | Diff line change
@@ -342,10 +342,6 @@ def train(self) -> None:
342342
# Convert mask from float to bool
343343
mask = rollout_data.mask > 1e-8
344344

345-
# Re-sample the noise matrix because the log_std has changed
346-
if self.use_sde:
347-
self.policy.reset_noise(self.batch_size)
348-
349345
values, log_prob, entropy = self.policy.evaluate_actions(
350346
rollout_data.observations,
351347
actions,

Diff for: sb3_contrib/trpo/trpo.py

-5
Original file line number | Diff line number | Diff line change
@@ -261,11 +261,6 @@ def train(self) -> None:
261261
# Convert discrete action from float to long
262262
actions = rollout_data.actions.long().flatten()
263263

264-
# Re-sample the noise matrix because the log_std has changed
265-
if self.use_sde:
266-
# batch_size is only used for the value function
267-
self.policy.reset_noise(actions.shape[0])
268-
269264
with th.no_grad():
270265
# Note: is copy enough, no need for deepcopy?
271266
# If using gSDE and deepcopy, we need to use `old_distribution.distribution`

Diff for: sb3_contrib/version.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
2.3.0
1+
2.4.0a4

Diff for: setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@
6565
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
6666
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
6767
install_requires=[
68-
"stable_baselines3>=2.3.0,<3.0",
68+
"stable_baselines3>=2.4.0a4,<3.0",
6969
],
7070
description="Contrib package of Stable Baselines3, experimental code.",
7171
author="Antonin Raffin",

0 commit comments

Comments
 (0)