Skip to content

Commit

Permalink
communicator in categorical_marl policy
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhangliu committed Feb 12, 2025
1 parent d22ae8c commit bf9c48c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions xuance/torch/policies/categorical_marl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class MAAC_Policy(Module):
representation_actor (ModuleDict): A dict of representation modules for each agent's actor.
representation_critic (ModuleDict): A dict of representation modules for each agent's critic.
mixer (Module): The mixer module that mix together the individual values to the total value.
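communicator (Optional[Module]): The optional communicator module. When provided, it is wrapped for distributed training like the mixer, and its parameters are included in `parameters_model`.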
actor_hidden_size (Sequence[int]): A list of hidden layer sizes for actor network.
critic_hidden_size (Sequence[int]): A list of hidden layer sizes for critic network.
normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs.
Expand All @@ -40,6 +41,7 @@ def __init__(self,
representation_actor: ModuleDict,
representation_critic: ModuleDict,
mixer: Optional[Module] = None,
communicator: Optional[Module] = None,
actor_hidden_size: Sequence[int] = None,
critic_hidden_size: Sequence[int] = None,
normalize: Optional[ModuleType] = None,
Expand Down Expand Up @@ -74,6 +76,7 @@ def __init__(self,
self.critic[key] = CriticNet(dim_critic_in, critic_hidden_size, normalize, initialize, activation, device)

self.mixer = mixer
self.comunicator = communicator

# Prepare DDP module.
self.distributed_training = use_distributed_training
Expand All @@ -90,15 +93,18 @@ def __init__(self,
self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
if self.mixer is not None:
self.mixer = DistributedDataParallel(module=self.mixer, device_ids=[self.rank])
if self.comunicator is not None:
self.comunicator = DistributedDataParallel(module=self.comunicator, device_ids=[self.rank])

@property
def parameters_model(self):
    """Collect the parameters of every sub-module owned by this policy.

    The list always contains the parameters of the actor representation,
    the actor, the critic representation and the critic (in that order).
    The parameters of the mixer and of the communicator are appended only
    when the corresponding module is not None.

    Returns:
        list: The concatenated parameters of all available sub-modules.
    """
    parameters = list(self.actor_representation.parameters()) + list(self.actor.parameters()) + list(
        self.critic_representation.parameters()) + list(self.critic.parameters())
    # Accumulate the optional modules instead of returning early, so that
    # having (or lacking) a mixer never hides the communicator's parameters.
    if self.mixer is not None:
        parameters += list(self.mixer.parameters())
    # NOTE: the attribute is spelled `comunicator` (sic) to match the name
    # assigned in __init__; keep the two in sync if it is ever renamed.
    if self.comunicator is not None:
        parameters += list(self.comunicator.parameters())
    return parameters

def _get_actor_critic_input(self, dim_action, dim_actor_rep, dim_critic_rep, n_agents):
"""
Expand Down

0 comments on commit bf9c48c

Please sign in to comment.