
Commit 9cad1d0

Add SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL (#59)
* Start testing simba
* Quick try with CrossQ
* Add actor for CrossQ
* Add simba net for TQC
* Remove unused param
* Add parameter resets for TQC
* Fix reset
* Add missing param
* Update documentation
* Add parameter resets
* Reformat pyproject.toml
* Refactor: share actor between SAC and TQC
* Add run tests for simba
* Upgrade to python 3.9 (#64)
* Fix mypy error, update version
1 parent 1c79684 commit 9cad1d0

26 files changed: +702 −227 lines

.github/workflows/ci.yml

+1 −3

@@ -20,7 +20,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]

     steps:
       - uses: actions/checkout@v3
@@ -52,8 +52,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # skip mypy, jax doesn't have its latest version for python 3.8
-        if: "!(matrix.python-version == '3.8')"
      - name: Test with pytest
        run: |
          make pytest

README.md

+44

@@ -18,8 +18,11 @@ Implemented algorithms:
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
 - [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
+- [Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)](https://openreview.net/forum?id=jXLiDKsuDo)


+Note: parameter resets for off-policy algorithms can be activated by passing a list of timesteps to the model constructor (e.g. `param_resets=[int(1e5), int(5e5)]` to reset parameters and optimizers after 100_000 and 500_000 timesteps).
+
 ### Install using pip

 For the latest master version:
@@ -132,6 +135,47 @@ Having a higher learning rate for the q-value function is also helpful: `qf_lear

 Note: when using the DroQ configuration with CrossQ, you should set `layer_norm=False` as there is already batch normalization.

+## Note about SimBa
+
+[SimBa](https://openreview.net/forum?id=jXLiDKsuDo) is a special network architecture for off-policy algorithms (SAC, TQC, ...).
+
+Some recommended hyperparameters (tested on MuJoCo and PyBullet environments):
+```python
+import optax
+
+
+default_hyperparams = dict(
+    n_envs=1,
+    n_timesteps=int(1e6),
+    policy="SimbaPolicy",
+    learning_rate=3e-4,
+    # qf_learning_rate=1e-3,
+    policy_kwargs={
+        "optimizer_class": optax.adamw,
+        # "optimizer_kwargs": {"weight_decay": 0.01},
+        # Note: here [128] represents a residual block, not just a single layer
+        "net_arch": {"pi": [128], "qf": [256, 256]},
+        "n_critics": 2,
+    },
+    learning_starts=10_000,
+    # Important: input normalization using VecNormalize
+    normalize={"norm_obs": True, "norm_reward": False},
+)
+
+hyperparams = {}
+
+# You can also loop over gym.registry
+for env_id in [
+    "HalfCheetah-v4",
+    "HalfCheetahBulletEnv-v0",
+    "Ant-v4",
+]:
+    hyperparams[env_id] = default_hyperparams
+```
+
+and then use the RL Zoo script defined above: `python train.py --algo tqc --env HalfCheetah-v4 -c simba.py -P`.
+
+
 ## Benchmark

 A partial benchmark can be found on [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sbx) where you can also find several [reports](https://wandb.ai/openrlbenchmark/sbx/reportlist).
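The README additions above target RL Zoo usage. As a complement, here is a minimal sketch of how `SimbaPolicy` and the new `param_resets` option could be combined when calling sbx directly; the environment, reset timesteps, and architecture values are illustrative only, and unlike the README example no `VecNormalize` observation normalization is applied here.

```python
import optax

from sbx import TQC

# Illustrative only: SimBa architecture plus scheduled parameter resets,
# assuming TQC forwards `param_resets` to the shared off-policy base class.
model = TQC(
    "SimbaPolicy",
    "HalfCheetah-v4",
    learning_rate=3e-4,
    learning_starts=10_000,
    # Reset parameters and optimizers after 100_000 and 500_000 timesteps
    param_resets=[int(1e5), int(5e5)],
    policy_kwargs={
        "optimizer_class": optax.adamw,
        # net_arch entries denote residual blocks (see the README note above)
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    verbose=1,
)
model.learn(total_timesteps=1_000_000)
```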

pyproject.toml

+9 −7

@@ -1,8 +1,8 @@
 [tool.ruff]
 # Same as Black.
 line-length = 127
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.9
+target-version = "py39"

 [tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/
@@ -28,9 +28,7 @@ show_error_codes = true

 [tool.pytest.ini_options]
 # Deterministic ordering for tests; useful for pytest-xdist.
-env = [
-    "PYTHONHASHSEED=0"
-]
+env = ["PYTHONHASHSEED=0"]

 filterwarnings = [
     # Tensorboard warnings
@@ -41,7 +39,7 @@ filterwarnings = [
     "ignore:rich is experimental",
 ]
 markers = [
-    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
+    "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
 ]

 [tool.coverage.run]
@@ -50,4 +48,8 @@ branch = false
 omit = ["tests/*", "setup.py"]

 [tool.coverage.report]
-exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
+exclude_lines = [
+    "pragma: no cover",
+    "raise NotImplementedError()",
+    "if typing.TYPE_CHECKING:",
+]

sbx/__init__.py

+1 −1

@@ -23,11 +23,11 @@ def DroQ(*args, **kwargs):


 __all__ = [
-    "CrossQ",
     "DDPG",
     "DQN",
     "PPO",
     "SAC",
     "TD3",
     "TQC",
+    "CrossQ",
 ]

sbx/common/jax_layers.py

+23 −2

@@ -1,5 +1,7 @@
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union

+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from flax.linen.module import Module, compact, merge_param
@@ -8,7 +10,7 @@

 PRNGKey = Any
 Array = Any
-Shape = Tuple[int, ...]
+Shape = tuple[int, ...]
 Dtype = Any  # this could be a real type?
 Axes = Union[int, Sequence[int]]

@@ -204,3 +206,22 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
             self.bias_init,
             self.scale_init,
         )
+
+
+# Adapted from simba: https://github.com/SonyResearch/simba
+class SimbaResidualBlock(nn.Module):
+    hidden_dim: int
+    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
+    # "the MLP is structured with an inverted bottleneck, where the hidden
+    # dimension is expanded to 4 * hidden_dim"
+    scale_factor: int = 4
+    norm_layer: type[nn.Module] = nn.LayerNorm
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        residual = x
+        x = self.norm_layer()(x)
+        x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x)
+        x = self.activation_fn(x)
+        x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
+        return residual + x
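The `SimbaResidualBlock` added here is a pre-norm, inverted-bottleneck MLP block with a skip connection (LayerNorm, expand by `scale_factor`, activation, project back, add the residual). The sketch below shows one way such blocks could be stacked into a feature extractor; it is not the `SimbaPolicy` added by this commit, and the `SimbaEncoder` name, the embedding/post-norm layout, and the shapes are assumptions for illustration.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from sbx.common.jax_layers import SimbaResidualBlock


class SimbaEncoder(nn.Module):
    """Hypothetical feature extractor: Dense embedding -> residual blocks -> LayerNorm."""

    hidden_dim: int = 128
    num_blocks: int = 1  # "pi": [128] in the README corresponds to one residual block

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Project the (already normalized) observation to the residual stream width
        x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
        # Each block: LayerNorm -> 4x expansion -> ReLU -> projection, plus a skip connection
        for _ in range(self.num_blocks):
            x = SimbaResidualBlock(self.hidden_dim)(x)
        # Final normalization before the actor/critic heads
        return nn.LayerNorm()(x)


obs = jnp.zeros((1, 17))  # e.g. one HalfCheetah observation
params = SimbaEncoder().init(jax.random.PRNGKey(0), obs)
features = SimbaEncoder().apply(params, obs)  # shape (1, 128)
```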

sbx/common/off_policy_algorithm.py

+25 −8

@@ -1,6 +1,6 @@
 import io
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Optional, Union

 import jax
 import numpy as np
@@ -17,7 +17,7 @@
 class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
     def __init__(
         self,
-        policy: Type[BasePolicy],
+        policy: type[BasePolicy],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule],
         qf_learning_rate: Optional[float] = None,
@@ -26,13 +26,13 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
+        train_freq: Union[int, tuple[int, str]] = (1, "step"),
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
         device: str = "auto",
@@ -43,7 +43,9 @@ def __init__(
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        stats_window_size: int = 100,
+        param_resets: Optional[list[int]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ):
         super().__init__(
             policy=policy,
@@ -62,6 +64,7 @@ def __init__(
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
+            stats_window_size=stats_window_size,
             policy_kwargs=policy_kwargs,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
@@ -74,11 +77,25 @@ def __init__(
         self.key = jax.random.PRNGKey(0)
         # Note: we do not allow schedule for it
         self.qf_learning_rate = qf_learning_rate
+        self.param_resets = param_resets
+        self.reset_idx = 0
+
+    def _maybe_reset_params(self) -> None:
+        # Maybe reset the parameters
+        if (
+            self.param_resets
+            and self.reset_idx < len(self.param_resets)
+            and self.num_timesteps >= self.param_resets[self.reset_idx]
+        ):
+            # Note: we are not resetting the entropy coeff
+            assert isinstance(self.qf_learning_rate, float)
+            self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
+            self.reset_idx += 1

     def _get_torch_save_params(self):
         return [], []

-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         excluded = super()._excluded_save_params()
         excluded.remove("policy")
         return excluded
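The new `_maybe_reset_params` hook fires at most once per entry in `param_resets`: the first time `num_timesteps` crosses a threshold, the policy and its optimizers are rebuilt via `self.policy.build(...)` and the index advances. A standalone sketch of just that scheduling condition, with a print standing in for the rebuild:

```python
# Standalone sketch of the condition used in _maybe_reset_params (illustration only).
param_resets = [int(1e5), int(5e5)]  # reset thresholds, in timesteps
reset_idx = 0

for num_timesteps in range(0, 600_001, 25_000):  # stand-in for the training loop
    if (
        param_resets
        and reset_idx < len(param_resets)
        and num_timesteps >= param_resets[reset_idx]
    ):
        # In the real method: self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
        print(f"reset parameters and optimizers at step {num_timesteps:_}")
        reset_idx += 1
# -> fires once each, at steps 100_000 and 500_000
```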

sbx/common/on_policy_algorithm.py

+5 −5

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union

 import gymnasium as gym
 import jax
@@ -24,7 +24,7 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm):

     def __init__(
         self,
-        policy: Union[str, Type[BasePolicy]],
+        policy: Union[str, type[BasePolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule],
         n_steps: int,
@@ -37,12 +37,12 @@ def __init__(
         sde_sample_freq: int,
         tensorboard_log: Optional[str] = None,
         monitor_wrapper: bool = True,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: str = "auto",
         _init_setup_model: bool = True,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ):
         super().__init__(
             policy=policy,  # type: ignore[arg-type]
@@ -70,7 +70,7 @@ def __init__(
     def _get_torch_save_params(self):
         return [], []

-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         excluded = super()._excluded_save_params()
         excluded.remove("policy")
         return excluded
