
Commit 7586908

Authored Jan 18, 2022

Add Augmented Random Search (ARS) support (#201)

* Add support for ARS
* Update params
* Update hyperparams
* Update params
* Update params
* Update ARS multienvs
* Update params
* Add hyperparam optimization for ARS
* Fix ARS multi envs
* Remove unused param
* Update pendulum params
* Tuned classic control envs
* Update params
* Add episode length plot
* Add delta_std for the schedules
* Update search range
* Save params
* Add A1
* Add jumping env
* Update image and readme
* Update changelog
* Update benchmark
* Add pre-trained agents
* Update requirements
* Update trained agents

1 parent: c40cea6
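
For background, Augmented Random Search (ARS) optimizes directly in policy-parameter space: each update samples random perturbation directions, evaluates the policy shifted along each direction in both signs, and steps along the best-performing directions. Below is a minimal NumPy sketch of one update, after Mania et al. (2018); it assumes a hypothetical `evaluate(params)` helper that runs a rollout and returns the episodic return, and it illustrates the algorithm rather than reproducing the sb3-contrib implementation:

```python
import numpy as np

def ars_step(params, evaluate, n_delta=8, n_top=4,
             delta_std=0.025, learning_rate=0.02, rng=None):
    """One ARS update. `evaluate(p)` is a hypothetical helper that
    returns the episodic return for flattened policy parameters `p`."""
    rng = rng or np.random.default_rng()
    deltas = rng.standard_normal((n_delta, params.size))
    # Evaluate each perturbation direction in both signs.
    r_plus = np.array([evaluate(params + delta_std * d) for d in deltas])
    r_minus = np.array([evaluate(params - delta_std * d) for d in deltas])
    # Keep the n_top directions whose best return (either sign) is highest.
    top = np.argsort(np.maximum(r_plus, r_minus))[-n_top:]
    # Normalize the step size by the std of the returns actually used.
    sigma_r = np.concatenate([r_plus[top], r_minus[top]]).std() + 1e-8
    step = (r_plus[top] - r_minus[top]) @ deltas[top]
    return params + learning_rate / (n_top * sigma_r) * step
```

The `n_delta`, `n_top`, `delta_std`, and `learning_rate` names deliberately mirror the keys used in `hyperparams/ars.yml` below.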

File tree

26 files changed: +6368 −21 lines

CHANGELOG.md (+5 −3)
@@ -1,19 +1,21 @@
-## Release 1.3.1a7 (WIP)
+## Release 1.3.1a9 (WIP)

 ### Breaking Changes
 - Dropped python 3.6 support
-- Upgrade to Stable-Baselines3 (SB3) >= 1.3.1a8
-- Upgrade to sb3-contrib >= 1.3.1a7
+- Upgrade to Stable-Baselines3 (SB3) >= 1.3.1a9
+- Upgrade to sb3-contrib >= 1.3.1a9

 ### New Features
 - Added mujoco hyperparameters
 - Added MuJoCo pre-trained agents
 - Added script to parse best hyperparameters of an optuna study
 - Added TRPO support
+- Added ARS support and pre-trained agents

 ### Bug fixes

 ### Documentation
+- Replace front image

 ### Other

README.md (+5 −1)
@@ -4,7 +4,7 @@
 # RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents

-<img src="images/panda_pick.gif" align="right" width="35%"/>
+<img src="images/car.jpg" align="right" width="40%"/>

 RL Baselines3 Zoo is a training framework for Reinforcement Learning (RL), using [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3).

@@ -322,6 +322,7 @@ Additional Atari Games (to be completed):
 | RL Algo | CartPole-v1 | MountainCar-v0 | Acrobot-v1 | Pendulum-v0 | MountainCarContinuous-v0 |
 |----------|--------------|----------------|------------|--------------|--------------------------|
+| ARS | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | DQN | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | N/A |

@@ -337,6 +338,7 @@ Additional Atari Games (to be completed):
 | RL Algo | BipedalWalker-v3 | LunarLander-v2 | LunarLanderContinuous-v2 | BipedalWalkerHardcore-v3 | CarRacing-v0 |
 |----------|--------------|----------------|------------|--------------|--------------------------|
+| ARS | | :heavy_check_mark: | | :heavy_check_mark: | |
 | A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | DQN | N/A | :heavy_check_mark: | N/A | N/A | N/A |

@@ -356,6 +358,7 @@ Note: those environments are derived from [Roboschool](https://github.com/openai
 | RL Algo | Walker2D | HalfCheetah | Ant | Reacher | Hopper | Humanoid |
 |----------|-----------|-------------|-----|---------|---------|----------|
+| ARS | | | | | | |
 | A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | DDPG | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |

@@ -379,6 +382,7 @@ PyBullet Envs (Continued)
 | RL Algo | Walker2d | HalfCheetah | Ant | Swimmer | Hopper | Humanoid |
 |----------|-----------|-------------|-----|---------|---------|----------|
+| ARS | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | A2C | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: | |
 | PPO | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | DDPG | | | | | | |

benchmark.md (+11)
@@ -43,6 +43,17 @@ and also allow users to have access to pretrained agents.*
 |a2c |SpaceInvadersNoFrameskip-v4| 627.160| 201.974|10M | 604848| 162|
 |a2c |Swimmer-v3 | 200.627| 2.544|1M | 150000| 150|
 |a2c |Walker2DBulletEnv-v0 | 858.209| 333.116|2M | 149156| 173|
+|ars |Acrobot-v1 | -82.884| 23.825|500k | 149985| 1788|
+|ars |Ant-v3 | 2333.773| 20.597|75M | 150000| 150|
+|ars |CartPole-v1 | 500.000| 0.000|50k | 150000| 300|
+|ars |HalfCheetah-v3 | 4815.192| 1340.752|12M | 150000| 150|
+|ars |Hopper-v3 | 3343.919| 5.730|7M | 150000| 150|
+|ars |LunarLanderContinuous-v2 | 167.959| 147.071|2M | 149883| 562|
+|ars |MountainCar-v0 | -122.000| 33.456|500k | 149938| 1229|
+|ars |MountainCarContinuous-v0 | 96.672| 0.784|500k | 149990| 621|
+|ars |Pendulum-v0 | -212.540| 160.444|2M | 150000| 750|
+|ars |Swimmer-v3 | 355.267| 12.796|2M | 150000| 150|
+|ars |Walker2d-v3 | 2993.582| 166.289|75M | 149821| 152|
 |ddpg |AntBulletEnv-v0 | 2399.147| 75.410|1M | 150000| 150|
 |ddpg |BipedalWalker-v3 | 197.486| 141.580|1M | 149237| 227|
 |ddpg |HalfCheetahBulletEnv-v0 | 2078.325| 208.379|1M | 150000| 150|
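
(For reference, the columns here follow the table header earlier in benchmark.md: algo, env_id, mean_reward, std_reward, n_timesteps, eval_timesteps, eval_episodes.)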

hyperparams/ars.yml (+218, new file)
@@ -0,0 +1,218 @@
+# Tuned
+CartPole-v1:
+  n_envs: 1
+  n_timesteps: !!float 5e4
+  policy: 'LinearPolicy'
+  n_delta: 2
+
+# Tuned
+Pendulum-v0: &pendulum-params
+  n_envs: 1
+  n_timesteps: !!float 2e6
+  policy: 'MlpPolicy'
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  learning_rate: !!float 0.018
+  n_delta: 4
+  n_top: 1
+  delta_std: 0.1
+  policy_kwargs: "dict(net_arch=[16])"
+  zero_policy: False
+
+# TO BE Tuned
+LunarLander-v2:
+  <<: *pendulum-params
+  n_delta: 6
+  n_top: 1
+  n_timesteps: !!float 2e6
+
+# Tuned
+LunarLanderContinuous-v2:
+  <<: *pendulum-params
+  n_timesteps: !!float 2e6
+
+# Tuned
+Acrobot-v1:
+  <<: *pendulum-params
+  n_timesteps: !!float 5e5
+
+# Tuned
+MountainCar-v0:
+  <<: *pendulum-params
+  n_delta: 8
+  n_timesteps: !!float 5e5
+
+# Tuned
+MountainCarContinuous-v0:
+  <<: *pendulum-params
+  n_timesteps: !!float 5e5
+  delta_std: 0.2
+
+# === Pybullet Envs ===
+
+# Almost tuned
+HalfCheetahBulletEnv-v0: &pybullet-defaults
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.02
+  delta_std: !!float 0.03
+  n_delta: 8
+  n_top: 8
+  alive_bonus_offset: 0
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[64, 64])"
+  zero_policy: False
+
+# To be tuned
+AntBulletEnv-v0:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.02
+  delta_std: !!float 0.03
+  n_delta: 32
+  n_top: 32
+  alive_bonus_offset: 0
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[128, 64])"
+  zero_policy: False
+
+
+Walker2DBulletEnv-v0:
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.03
+  delta_std: !!float 0.025
+  n_delta: 40
+  n_top: 30
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[64, 64])"
+  zero_policy: False
+
+# Tuned
+HopperBulletEnv-v0:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7e6
+  learning_rate: !!float 0.01
+  delta_std: !!float 0.025
+  n_delta: 8
+  n_top: 4
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+ReacherBulletEnv-v0:
+  <<: *pybullet-defaults
+  n_timesteps: !!float 1e6
+
+# === Mujoco Envs ===
+# Params closest to original paper
+Swimmer-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 2e6
+  learning_rate: !!float 0.02
+  delta_std: !!float 0.01
+  n_delta: 1
+  n_top: 1
+  alive_bonus_offset: 0
+  # normalize: "dict(norm_obs=True, norm_reward=False)"
+
+Hopper-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7e6
+  learning_rate: !!float 0.01
+  delta_std: !!float 0.025
+  n_delta: 8
+  n_top: 4
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+HalfCheetah-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 1.25e7
+  learning_rate: !!float 0.02
+  delta_std: !!float 0.03
+  n_delta: 32
+  n_top: 4
+  alive_bonus_offset: 0
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+Walker2d-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.03
+  delta_std: !!float 0.025
+  n_delta: 40
+  n_top: 30
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+Ant-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.015
+  delta_std: !!float 0.025
+  n_delta: 60
+  n_top: 20
+  alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+
+Humanoid-v3:
+  n_envs: 1
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 2.5e8
+  learning_rate: 0.02
+  delta_std: 0.0075
+  n_delta: 256
+  n_top: 256
+  alive_bonus_offset: -5
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+
+# Almost tuned
+BipedalWalker-v3:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 1e8
+  learning_rate: 0.02
+  delta_std: 0.0075
+  n_delta: 64
+  n_top: 32
+  alive_bonus_offset: -0.1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[16])"
+
+# TO Be Tuned
+BipedalWalkerHardcore-v3:
+  n_envs: 1
+  policy: 'MlpPolicy'
+  n_timesteps: !!float 5e8
+  learning_rate: 0.02
+  delta_std: 0.0075
+  n_delta: 64
+  n_top: 32
+  alive_bonus_offset: -0.1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  policy_kwargs: "dict(net_arch=[16])"
+
+A1Walking-v0:
+  <<: *pendulum-params
+  n_timesteps: !!float 2e6
+
+A1Jumping-v0:
+  policy: 'LinearPolicy'
+  n_timesteps: !!float 7.5e7
+  learning_rate: !!float 0.03
+  delta_std: !!float 0.025
+  n_delta: 40
+  n_top: 30
+  # alive_bonus_offset: -1
+  normalize: "dict(norm_obs=True, norm_reward=False)"
+  # policy_kwargs: "dict(net_arch=[16])"
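
Each top-level key above is a Gym environment id; the zoo's training script consumes the zoo-specific entries (`n_envs`, `n_timesteps`, `normalize`, ...) and passes the remaining keys to the algorithm constructor. As a rough, hand-written sketch of what the `Pendulum-v0` entry resolves to with sb3-contrib's `ARS` class (the `VecNormalize(norm_obs=True, norm_reward=False)` wrapping that the zoo performs is omitted here):

```python
from sb3_contrib import ARS

# Approximate translation of the Pendulum-v0 entry above.
model = ARS(
    "MlpPolicy",
    "Pendulum-v0",
    learning_rate=0.018,
    n_delta=4,          # perturbation directions sampled per update
    n_top=1,            # only the best direction contributes to the step
    delta_std=0.1,      # std of the parameter noise
    policy_kwargs=dict(net_arch=[16]),
    zero_policy=False,  # start from random rather than all-zero weights
    verbose=1,
)
model.learn(total_timesteps=int(2e6))
```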

hyperparams/tqc.yml (+3)
@@ -254,3 +254,6 @@ parking-v0:
     n_sampled_goal=4,
     max_episode_length=100
   )"
+
+A1Walking-v0:
+  <<: *pybullet-defaults
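
The `<<: *pybullet-defaults` line uses a YAML merge key: the new entry inherits every key of the anchored mapping and may override individual ones. A tiny self-contained illustration with hypothetical values, loaded with PyYAML (whose YAML 1.1 resolver also explains the `!!float` tags seen throughout these files: without the tag, a literal like `2e6` is parsed as a string, not a number):

```python
import yaml

doc = """
defaults: &pybullet-defaults   # anchor a mapping once...
  policy: 'MlpPolicy'
  n_timesteps: !!float 1e6

A1Walking-v0:
  <<: *pybullet-defaults       # ...inherit all of its keys,
  n_timesteps: !!float 2e6     # overriding just this one.
"""

print(yaml.safe_load(doc)["A1Walking-v0"])
# {'policy': 'MlpPolicy', 'n_timesteps': 2000000.0}
```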

images/car.jpg (added, 132 KB, binary file not shown)

images/panda_pick.gif (removed, 642 KB, binary file not shown)
