9
9
from stable_baselines3 .common .torch_layers import (
10
10
BaseFeaturesExtractor ,
11
11
FlattenExtractor ,
12
+ CombinedExtractor ,
12
13
create_mlp ,
13
14
get_actor_critic_arch ,
14
15
)
@@ -529,3 +530,77 @@ def set_training_mode(self, mode: bool) -> None:
529
530
530
531
531
532
MlpPolicy = CrossQPolicy
533
+
534
+ class MultiInputPolicy (CrossQPolicy ):
535
+ """
536
+ Policy class (with both actor and critic) for CrossQ.
537
+
538
+ :param observation_space: Observation space
539
+ :param action_space: Action space
540
+ :param lr_schedule: Learning rate schedule (could be constant)
541
+ :param net_arch: The specification of the policy and value networks.
542
+ :param activation_fn: Activation function
543
+ :param use_sde: Whether to use State Dependent Exploration or not
544
+ :param log_std_init: Initial value for the log standard deviation
545
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
546
+ a positive standard deviation (cf paper). It allows to keep variance
547
+ above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
548
+ :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
549
+ :param features_extractor_class: Features extractor to use.
550
+ :param normalize_images: Whether to normalize images or not,
551
+ dividing by 255.0 (True by default)
552
+ :param optimizer_class: The optimizer to use,
553
+ ``th.optim.Adam`` by default
554
+ :param optimizer_kwargs: Additional keyword arguments,
555
+ excluding the learning rate, to pass to the optimizer
556
+ :param n_quantiles: Number of quantiles for the critic.
557
+ :param n_critics: Number of critic networks to create.
558
+ :param share_features_extractor: Whether to share or not the features extractor
559
+ between the actor and the critic (this saves computation time)
560
+ """
561
+
562
+ def __init__ (
563
+ self ,
564
+ observation_space : spaces .Space ,
565
+ action_space : spaces .Box ,
566
+ lr_schedule : Schedule ,
567
+ net_arch : Optional [Union [list [int ], dict [str , list [int ]]]] = None ,
568
+ activation_fn : type [nn .Module ] = nn .ReLU ,
569
+ batch_norm : bool = True ,
570
+ batch_norm_momentum : float = 0.01 , # Note: Jax implementation is 1 - momentum = 0.99
571
+ batch_norm_eps : float = 0.001 ,
572
+ renorm_warmup_steps : int = 100_000 ,
573
+ use_sde : bool = False ,
574
+ log_std_init : float = - 3 ,
575
+ use_expln : bool = False ,
576
+ clip_mean : float = 2.0 ,
577
+ features_extractor_class : type [BaseFeaturesExtractor ] = CombinedExtractor ,
578
+ features_extractor_kwargs : Optional [dict [str , Any ]] = None ,
579
+ normalize_images : bool = True ,
580
+ optimizer_class : type [th .optim .Optimizer ] = th .optim .Adam ,
581
+ optimizer_kwargs : Optional [dict [str , Any ]] = None ,
582
+ n_critics : int = 2 ,
583
+ share_features_extractor : bool = False ,
584
+ ):
585
+ super ().__init__ (
586
+ observation_space ,
587
+ action_space ,
588
+ lr_schedule ,
589
+ net_arch ,
590
+ activation_fn ,
591
+ batch_norm ,
592
+ batch_norm_momentum ,
593
+ batch_norm_eps ,
594
+ renorm_warmup_steps ,
595
+ use_sde ,
596
+ log_std_init ,
597
+ use_expln ,
598
+ clip_mean ,
599
+ features_extractor_class ,
600
+ features_extractor_kwargs ,
601
+ normalize_images ,
602
+ optimizer_class ,
603
+ optimizer_kwargs ,
604
+ n_critics ,
605
+ share_features_extractor ,
606
+ )
0 commit comments