Update PPO to support net_arch, and additional fixes #65

Merged · 12 commits · Feb 14, 2025
8 changes: 4 additions & 4 deletions sbx/common/policies.py
@@ -37,15 +37,15 @@ def __init__(self, *args, **kwargs):

@staticmethod
@jax.jit
def sample_action(actor_state, obervations, key):
dist = actor_state.apply_fn(actor_state.params, obervations)
def sample_action(actor_state, observations, key):
dist = actor_state.apply_fn(actor_state.params, observations)
action = dist.sample(seed=key)
return action

@staticmethod
@jax.jit
def select_action(actor_state, obervations):
return actor_state.apply_fn(actor_state.params, obervations).mode()
def select_action(actor_state, observations):
return actor_state.apply_fn(actor_state.params, observations).mode()

@no_type_check
def predict(
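The only substantive change in these jitted helpers is the corrected `observations` parameter name; they remain plain functions of a Flax `TrainState`. A minimal, self-contained sketch of how such a helper is typically driven (the `MyActor` module, sizes, and deterministic head are illustrative stand-ins, not the actual sbx classes, which return a distribution and take `.mode()` or `.sample()`):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training.train_state import TrainState


class MyActor(nn.Module):  # illustrative stand-in for the sbx actor
    action_dim: int = 2

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        x = nn.tanh(nn.Dense(64)(observations))
        return nn.Dense(self.action_dim)(x)  # simplified deterministic action head


@jax.jit
def select_action(actor_state, observations):
    # Same calling convention as the helpers above: params + observations in, actions out.
    return actor_state.apply_fn(actor_state.params, observations)


key = jax.random.PRNGKey(0)
obs = jnp.zeros((1, 4))  # dummy observation batch
actor = MyActor()
actor_state = TrainState.create(
    apply_fn=actor.apply,
    params=actor.init(key, obs),
    tx=optax.adam(3e-4),
)
print(select_action(actor_state, obs).shape)  # (1, 2)
```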
18 changes: 15 additions & 3 deletions sbx/crossq/crossq.py
@@ -219,7 +219,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
(actor_loss_value, qf_loss_value, ent_coef_value),
(actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value),
) = self._train(
self.gamma,
self.target_entropy,
@@ -236,6 +236,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())

@staticmethod
@@ -421,6 +422,7 @@ def _train(
"actor_loss": jnp.array(0.0),
"qf_loss": jnp.array(0.0),
"ent_coef_loss": jnp.array(0.0),
"ent_coef_value": jnp.array(0.0),
},
}

@@ -468,7 +470,12 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
target_entropy,
key,
)
info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}
info = {
"actor_loss": actor_loss_value,
"qf_loss": qf_loss_value,
"ent_coef_loss": ent_coef_loss_value,
"ent_coef_value": ent_coef_value,
}

return {
"actor_state": actor_state,
@@ -485,5 +492,10 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
update_carry["actor_state"],
update_carry["ent_coef_state"],
update_carry["key"],
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
(
update_carry["info"]["actor_loss"],
update_carry["info"]["qf_loss"],
update_carry["info"]["ent_coef_loss"],
update_carry["info"]["ent_coef_value"],
),
)
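The extra `ent_coef_value` entry has to live in the loop carry because the whole update loop runs inside `jax.lax.fori_loop`, which only returns what the carry holds; anything not threaded through it cannot be logged afterwards. The same change is applied to SAC and TQC below, so one stripped-down sketch covers all three (the loss values here are placeholders, not the real CrossQ losses):

```python
import jax
import jax.numpy as jnp


def train_loop(n_updates: int):
    # Everything logged after the loop must be threaded through the carry dict.
    init_carry = {
        "step": jnp.array(0),
        "info": {
            "actor_loss": jnp.array(0.0),
            "qf_loss": jnp.array(0.0),
            "ent_coef_loss": jnp.array(0.0),
            "ent_coef_value": jnp.array(0.0),  # newly carried so it can be logged
        },
    }

    def one_update(i, carry):
        # Placeholder "losses"; the real code computes gradients and updates TrainStates here.
        step_f = jnp.asarray(i, dtype=jnp.float32)
        info = {
            "actor_loss": step_f,
            "qf_loss": 2.0 * step_f,
            "ent_coef_loss": 0.1 * step_f,
            "ent_coef_value": jnp.exp(-0.01 * step_f),
        }
        return {"step": carry["step"] + 1, "info": info}

    carry = jax.lax.fori_loop(0, n_updates, one_update, init_carry)
    # Only values present in the carry survive the loop and reach the logger.
    return carry["info"]["ent_coef_loss"], carry["info"]["ent_coef_value"]


print(train_loop(10))
```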
8 changes: 4 additions & 4 deletions sbx/crossq/policies.py
@@ -425,21 +425,21 @@ def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:

@staticmethod
@jax.jit
def sample_action(actor_state, obervations, key):
def sample_action(actor_state, observations, key):
dist = actor_state.apply_fn(
{"params": actor_state.params, "batch_stats": actor_state.batch_stats},
obervations,
observations,
train=False,
)
action = dist.sample(seed=key)
return action

@staticmethod
@jax.jit
def select_action(actor_state, obervations):
def select_action(actor_state, observations):
return actor_state.apply_fn(
{"params": actor_state.params, "batch_stats": actor_state.batch_stats},
obervations,
observations,
train=False,
).mode()

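Because the CrossQ actor contains batch-normalization layers, its `apply_fn` takes the extra `batch_stats` collection and a `train` flag; `train=False` at inference time selects the running statistics instead of per-batch ones. A generic Flax illustration of that calling convention (the `NormalizedActor` module is made up for the example, it is not the sbx policy):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class NormalizedActor(nn.Module):  # made-up module, just to show the batch_stats plumbing
    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
        x = nn.Dense(32)(x)
        # use_running_average=True when not training -> reads the stored batch_stats
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.Dense(2)(nn.tanh(x))


model = NormalizedActor()
obs = jnp.zeros((1, 4))
variables = model.init(jax.random.PRNGKey(0), obs, train=True)
params, batch_stats = variables["params"], variables["batch_stats"]

# Inference call mirrors select_action above: both collections are passed in,
# and train=False means no mutable batch_stats update is needed.
actions = model.apply({"params": params, "batch_stats": batch_stats}, obs, train=False)
print(actions.shape)  # (1, 2)
```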
40 changes: 21 additions & 19 deletions sbx/ppo/policies.py
@@ -20,23 +20,23 @@


class Critic(nn.Module):
n_units: int = 256
net_arch: Sequence[int]
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)

x = nn.Dense(1)(x)
return x


class Actor(nn.Module):
action_dim: int
n_units: int = 256
net_arch: Sequence[int]
log_std_init: float = 0.0
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
# For Discrete, MultiDiscrete and MultiBinary actions
@@ -60,10 +60,11 @@ def __post_init__(self) -> None:
@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)

for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)

action_logits = nn.Dense(self.action_dim)(x)
if self.num_discrete_choices is None:
# Continuous actions
@@ -131,18 +132,19 @@ def __init__(
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
squash_output=False,
)
self.log_std_init = log_std_init
self.activation_fn = activation_fn
if net_arch is not None:
if isinstance(net_arch, list):
self.n_units = net_arch[0]
self.net_arch_pi = self.net_arch_vf = net_arch
else:
assert isinstance(net_arch, dict)
self.n_units = net_arch["pi"][0]
self.net_arch_pi = net_arch["pi"]
self.net_arch_vf = net_arch["vf"]
else:
self.n_units = 64
self.net_arch_pi = self.net_arch_vf = [64, 64]
self.use_sde = use_sde

self.key = self.noise_key = jax.random.PRNGKey(0)
@@ -188,7 +190,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
raise NotImplementedError(f"{self.action_space}")

self.actor = Actor(
n_units=self.n_units,
net_arch=self.net_arch_pi,
log_std_init=self.log_std_init,
activation_fn=self.activation_fn,
**actor_kwargs, # type: ignore[arg-type]
@@ -208,7 +210,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
),
)

self.vf = Critic(n_units=self.n_units, activation_fn=self.activation_fn)
self.vf = Critic(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)

self.vf_state = TrainState.create(
apply_fn=self.vf.apply,
@@ -249,9 +251,9 @@ def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray:

@staticmethod
@jax.jit
def _predict_all(actor_state, vf_state, obervations, key):
dist = actor_state.apply_fn(actor_state.params, obervations)
def _predict_all(actor_state, vf_state, observations, key):
dist = actor_state.apply_fn(actor_state.params, observations)
actions = dist.sample(seed=key)
log_probs = dist.log_prob(actions)
values = vf_state.apply_fn(vf_state.params, obervations).flatten()
values = vf_state.apply_fn(vf_state.params, observations).flatten()
return actions, log_probs, values
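With `net_arch` now wired through `Actor` and `Critic`, layer count and width follow the same convention as Stable-Baselines3: a shared list for both networks, or a `dict(pi=..., vf=...)` for separate architectures, with `[64, 64]` as the default. A usage sketch (the environment id, layer sizes, and timestep budget are arbitrary examples):

```python
import flax.linen as nn
from sbx import PPO

# Shared architecture: both actor and critic get two hidden layers of 64 units (the default).
model = PPO("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=[64, 64]))

# Separate architectures: smaller policy network, larger value network.
model = PPO(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=dict(pi=[64, 64], vf=[256, 256]), activation_fn=nn.tanh),
)
model.learn(total_timesteps=1_000)
```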
8 changes: 5 additions & 3 deletions sbx/ppo/ppo.py
@@ -293,9 +293,11 @@ def train(self) -> None:
# self.logger.record("train/clip_fraction", np.mean(clip_fractions))
self.logger.record("train/pg_loss", pg_loss.item())
self.logger.record("train/explained_variance", explained_var)
# if hasattr(self.policy, "log_std"):
# self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

try:
log_std = self.policy.actor_state.params["params"]["log_std"]
self.logger.record("train/std", np.exp(log_std).mean().item())
except KeyError:
pass
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
# if self.clip_range_vf is not None:
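The commented-out PyTorch-style logging is replaced by reading `log_std` directly from the actor's Flax parameter tree; the `try/except KeyError` keeps this working for discrete action spaces, where the actor has no `log_std` parameter. Roughly, the lookup behaves like this (the parameter tree below is a hand-built stand-in, not one produced by sbx):

```python
import numpy as np
import jax.numpy as jnp

# Hand-built stand-in for actor_state.params with a continuous-action actor.
params = {"params": {"log_std": jnp.array([-0.5, -0.5]), "Dense_0": {"kernel": jnp.ones((4, 64))}}}

try:
    log_std = params["params"]["log_std"]
    print("train/std:", np.exp(log_std).mean().item())  # exp(log_std) averaged over action dims
except KeyError:
    # Discrete actions: no log_std parameter, so nothing to log.
    pass
```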
18 changes: 15 additions & 3 deletions sbx/sac/sac.py
@@ -220,7 +220,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
(actor_loss_value, qf_loss_value, ent_coef_value),
(actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value),
) = self._train(
self.gamma,
self.tau,
@@ -238,6 +238,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())

@staticmethod
@@ -391,6 +392,7 @@ def _train(
"actor_loss": jnp.array(0.0),
"qf_loss": jnp.array(0.0),
"ent_coef_loss": jnp.array(0.0),
"ent_coef_value": jnp.array(0.0),
},
}

@@ -438,7 +440,12 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
target_entropy,
key,
)
info = {"actor_loss": actor_loss_value, "qf_loss": qf_loss_value, "ent_coef_loss": ent_coef_loss_value}
info = {
"actor_loss": actor_loss_value,
"qf_loss": qf_loss_value,
"ent_coef_loss": ent_coef_loss_value,
"ent_coef_value": ent_coef_value,
}

return {
"actor_state": actor_state,
@@ -455,5 +462,10 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
update_carry["actor_state"],
update_carry["ent_coef_state"],
update_carry["key"],
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
(
update_carry["info"]["actor_loss"],
update_carry["info"]["qf_loss"],
update_carry["info"]["ent_coef_loss"],
update_carry["info"]["ent_coef_value"],
),
)
4 changes: 2 additions & 2 deletions sbx/td3/policies.py
@@ -136,8 +136,8 @@ def forward(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:

@staticmethod
@jax.jit
def select_action(actor_state, obervations) -> np.ndarray:
return actor_state.apply_fn(actor_state.params, obervations)
def select_action(actor_state, observations) -> np.ndarray:
return actor_state.apply_fn(actor_state.params, observations)

def _predict(self, observation: np.ndarray, deterministic: bool = True) -> np.ndarray: # type: ignore[override]
# TD3 is always deterministic
6 changes: 5 additions & 1 deletion sbx/tqc/tqc.py
@@ -224,7 +224,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
(qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value),
(qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value),
) = self._train(
self.gamma,
self.tau,
@@ -244,6 +244,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf1_loss_value.item())
self.logger.record("train/ent_coef_loss", ent_coef_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())

@staticmethod
@@ -455,6 +456,7 @@ def _train(
"qf1_loss": jnp.array(0.0),
"qf2_loss": jnp.array(0.0),
"ent_coef_loss": jnp.array(0.0),
"ent_coef_value": jnp.array(0.0),
},
}

@@ -518,6 +520,7 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
"qf1_loss": qf1_loss_value,
"qf2_loss": qf2_loss_value,
"ent_coef_loss": ent_coef_loss_value,
"ent_coef_value": ent_coef_value,
}

return {
@@ -542,5 +545,6 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
update_carry["info"]["qf2_loss"],
update_carry["info"]["actor_loss"],
update_carry["info"]["ent_coef_loss"],
update_carry["info"]["ent_coef_value"],
),
)
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
0.19.0
0.20.0
4 changes: 2 additions & 2 deletions setup.py
@@ -41,8 +41,8 @@
packages=[package for package in find_packages() if package.startswith("sbx")],
package_data={"sbx": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.4.0,<3.0",
"jax>=0.4.12",
"stable_baselines3>=2.5.0,<3.0",
"jax>=0.4.24",
"jaxlib",
"flax",
"optax",