Commit e0335d7

- added CrossQ support for MultiInputPolicy
1 parent e1ca24a

2 files changed (+86 -2)

sb3_contrib/crossq/crossq.py (+11 -2)
@@ -10,7 +10,7 @@
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from torch.nn import functional as F
 
-from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy
+from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy, MultiInputPolicy
 
 SelfCrossQ = TypeVar("SelfCrossQ", bound="CrossQ")
 
@@ -67,6 +67,7 @@ class CrossQ(OffPolicyAlgorithm):
 
     policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
+        "MultiInputPolicy": MultiInputPolicy,
         # TODO: Implement CnnPolicy and MultiInputPolicy
     }
     policy: CrossQPolicy
@@ -235,7 +236,14 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         #
         # 2. From a computational perspective a single forward pass is simply more efficient than
         # two sequential forward passes.
-        all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
+
+        if isinstance(replay_data.observations, dict):
+            all_obs = {
+                key: th.cat([replay_data.observations[key], replay_data.next_observations[key]], dim=0)
+                for key in replay_data.observations.keys()
+            }
+        else:
+            all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
         all_actions = th.cat([replay_data.actions, next_actions], dim=0)
         # Update critic BN stats
         self.critic.set_bn_training_mode(True)
@@ -331,3 +339,4 @@ def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         else:
             saved_pytorch_variables = ["ent_coef_tensor"]
         return state_dicts, saved_pytorch_variables
+
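
The dict branch above is the whole change to the training loop: when the replay buffer yields dict observations (as it does for MultiInputPolicy), current and next observations are concatenated key by key instead of as one flat tensor, so the critic still performs the single joint forward pass described in the surrounding comment and its BatchNorm statistics are updated over the combined batch. A minimal standalone sketch of the same pattern in plain PyTorch; the batch shapes and keys below are made up for illustration:

import torch as th


def cat_current_and_next(obs, next_obs):
    # Stack current and next observations along the batch dimension.
    # Dict batches (MultiInputPolicy) are concatenated key by key;
    # flat tensor batches (MlpPolicy) are concatenated directly.
    if isinstance(obs, dict):
        return {key: th.cat([obs[key], next_obs[key]], dim=0) for key in obs}
    return th.cat([obs, next_obs], dim=0)


# Hypothetical dict observation batches of size 32:
obs = {"image": th.rand(32, 3, 84, 84), "state": th.rand(32, 7)}
next_obs = {"image": th.rand(32, 3, 84, 84), "state": th.rand(32, 7)}

all_obs = cat_current_and_next(obs, next_obs)
assert all_obs["image"].shape[0] == 64  # current + next in one joint batch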

sb3_contrib/crossq/policies.py (+75)
@@ -9,6 +9,7 @@
 from stable_baselines3.common.torch_layers import (
     BaseFeaturesExtractor,
     FlattenExtractor,
+    CombinedExtractor,
     create_mlp,
     get_actor_critic_arch,
 )
@@ -529,3 +530,77 @@ def set_training_mode(self, mode: bool) -> None:
 
 
 MlpPolicy = CrossQPolicy
+
+class MultiInputPolicy(CrossQPolicy):
+    """
+    Policy class (with both actor and critic) for CrossQ.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param use_sde: Whether to use State Dependent Exploration or not
+    :param log_std_init: Initial value for the log standard deviation
+    :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
+        a positive standard deviation (cf paper). It allows to keep variance
+        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+    :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
+    :param features_extractor_class: Features extractor to use.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    :param n_quantiles: Number of quantiles for the critic.
+    :param n_critics: Number of critic networks to create.
+    :param share_features_extractor: Whether to share or not the features extractor
+        between the actor and the critic (this saves computation time)
+    """
+
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
+        lr_schedule: Schedule,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        batch_norm: bool = True,
+        batch_norm_momentum: float = 0.01,  # Note: Jax implementation is 1 - momentum = 0.99
+        batch_norm_eps: float = 0.001,
+        renorm_warmup_steps: int = 100_000,
+        use_sde: bool = False,
+        log_std_init: float = -3,
+        use_expln: bool = False,
+        clip_mean: float = 2.0,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
+        n_critics: int = 2,
+        share_features_extractor: bool = False,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            lr_schedule,
+            net_arch,
+            activation_fn,
+            batch_norm,
+            batch_norm_momentum,
+            batch_norm_eps,
+            renorm_warmup_steps,
+            use_sde,
+            log_std_init,
+            use_expln,
+            clip_mean,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+            n_critics,
+            share_features_extractor,
+        )
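
With the alias registered in crossq.py and CombinedExtractor as the default features extractor, CrossQ can now be constructed with "MultiInputPolicy" on a Dict observation space (previously only "MlpPolicy" was available). A smoke-test sketch, assuming an sb3_contrib build that includes this commit; the toy DictObsEnv below is hypothetical and exists only to supply dict observations with continuous actions:

import numpy as np
import gymnasium as gym
from gymnasium import spaces

from sb3_contrib import CrossQ


class DictObsEnv(gym.Env):
    # Toy continuous-action environment with a Dict observation space.
    def __init__(self):
        self.observation_space = spaces.Dict(
            {
                "vec": spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
                "aux": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
        self._steps = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._steps = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        self._steps += 1
        reward = -float(np.abs(action).sum())  # favor small actions
        truncated = self._steps >= 50  # episode time limit
        return self.observation_space.sample(), reward, False, truncated, {}


# SB3 selects DictReplayBuffer automatically for Dict observation spaces.
model = CrossQ("MultiInputPolicy", DictObsEnv(), buffer_size=10_000, learning_starts=100)
model.learn(total_timesteps=500)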
