base_comm & MAAC_Policy_With_Communication
wenzhangliu committed Feb 14, 2025
1 parent db19bbd commit 42f250c
Showing 2 changed files with 175 additions and 13 deletions.
32 changes: 27 additions & 5 deletions xuance/torch/communications/base_comm.py
@@ -1,14 +1,36 @@
import torch
import torch.nn as nn
from xuance.common import Optional, Callable, Union, Sequence
from xuance.torch import Module, Tensor
from xuance.torch.utils import mlp_block, ModuleType


class BaseComm(Module):
    def __init__(self, n_agents, msg_dims, **kwargs):
    def __init__(self,
                 state_dim: int,
                 n_agents: int,
                 hidden_sizes_comm: Sequence[int],
                 msg_dim: int,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 **kwargs):
        super().__init__()
        self.n_agents = n_agents
        self.msg_dims = msg_dims

    def forward(self, msg: Tensor, **kwargs):
        raise NotImplementedError
        self.msg_dim = msg_dim
        self.hidden_sizes_comm = hidden_sizes_comm
        layers_ = []
        input_shape = (state_dim,)
        for h in hidden_sizes_comm:
            mlp, input_shape = mlp_block(input_shape[0], h, normalize, activation, initialize, device)
            layers_.extend(mlp)
        layers_.extend(mlp_block(input_shape[0], msg_dim, None, None, initialize, device)[0])
        self.msg_encoder = nn.Sequential(*layers_)

    def forward(self, hidden_features: Tensor):
        encoded_msg = self.msg_encoder(hidden_features)
        return encoded_msg


class NoneComm(Module):
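For orientation, here is a minimal usage sketch (not part of the commit) of the revised BaseComm: it now builds an MLP message encoder from hidden_sizes_comm and maps an agent's hidden features to a fixed-size message. All sizes and the activation choice below are hypothetical, and a concrete subclass is still expected to define how messages are exchanged (e.g. a receive step), which this hunk does not show.

import torch
import torch.nn as nn
from xuance.torch.communications.base_comm import BaseComm

# Hypothetical sizes: 64-dim hidden features, two 64-unit encoder layers, 16-dim messages.
comm = BaseComm(state_dim=64, n_agents=4, hidden_sizes_comm=[64, 64],
                msg_dim=16, activation=nn.ReLU)

hidden_features = torch.randn(32, 64)   # a batch of per-agent hidden states
encoded_msg = comm(hidden_features)     # passes through the MLP message encoder
print(encoded_msg.shape)                # expected: torch.Size([32, 16])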
156 changes: 148 additions & 8 deletions xuance/torch/policies/categorical_marl.py
@@ -5,7 +5,6 @@
from copy import deepcopy
from operator import itemgetter
from gym.spaces import Discrete
from torch.distributions import Categorical
from xuance.common import Sequence, Optional, Callable, Union, Dict, List
from xuance.torch.policies import CategoricalActorNet, ActorNet
from xuance.torch.policies.core import CriticNet, BasicQhead
@@ -24,7 +23,6 @@ class MAAC_Policy(Module):
        representation_actor (ModuleDict): A dict of representation modules for each agent's actor.
        representation_critic (ModuleDict): A dict of representation modules for each agent's critic.
        mixer (Module): The mixer module that mixes the individual values into the total value.
        communicator (Optional[Module]): The communicator module.
        actor_hidden_size (Sequence[int]): A list of hidden layer sizes for the actor network.
        critic_hidden_size (Sequence[int]): A list of hidden layer sizes for the critic network.
        normalize (Optional[ModuleType]): The layer normalization over a minibatch of inputs.
@@ -41,7 +39,6 @@ def __init__(self,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 communicator: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
@@ -76,7 +73,6 @@ def __init__(self,
            self.critic[key] = CriticNet(dim_critic_in, critic_hidden_size, normalize, initialize, activation, device)

        self.mixer = mixer
        self.comunicator = communicator

        # Prepare DDP module.
        self.distributed_training = use_distributed_training
@@ -93,17 +89,13 @@ def __init__(self,
                self.critic[key] = DistributedDataParallel(module=self.critic[key], device_ids=[self.rank])
            if self.mixer is not None:
                self.mixer = DistributedDataParallel(module=self.mixer, device_ids=[self.rank])
            if self.comunicator is not None:
                self.comunicator = DistributedDataParallel(module=self.comunicator, device_ids=[self.rank])

    @property
    def parameters_model(self):
        parameters = list(self.actor_representation.parameters()) + list(self.actor.parameters()) + list(
            self.critic_representation.parameters()) + list(self.critic.parameters())
        if self.mixer is not None:
            parameters += list(self.mixer.parameters())
        if self.comunicator is not None:
            parameters += list(self.comunicator.parameters())
        return parameters

    def _get_actor_critic_input(self, dim_action, dim_actor_rep, dim_critic_rep, n_agents):
@@ -284,6 +276,154 @@ def value_tot(self, values_n: Tensor, global_state=None):
        return values_n if self.mixer is None else self.mixer(values_n, global_state)


class MAAC_Policy_With_Communication(MAAC_Policy):
    def __init__(self,
                 action_space: Optional[Dict[str, Discrete]],
                 n_agents: int,
                 representation_actor: ModuleDict,
                 representation_critic: ModuleDict,
                 mixer: Optional[Module] = None,
                 communicators: Optional[Module] = None,
                 actor_hidden_size: Sequence[int] = None,
                 critic_hidden_size: Sequence[int] = None,
                 normalize: Optional[ModuleType] = None,
                 initialize: Optional[Callable[..., Tensor]] = None,
                 activation: Optional[ModuleType] = None,
                 device: Optional[Union[str, int, torch.device]] = None,
                 use_distributed_training: bool = False,
                 **kwargs):
        self.communicators = communicators
        self.msg_dim = self.communicators.msg_dim
        super().__init__(action_space, n_agents, representation_actor, representation_critic, mixer,
                         actor_hidden_size, critic_hidden_size, normalize, initialize,
                         activation, device, use_distributed_training, **kwargs)
        if self.distributed_training:
            self.communicators = DistributedDataParallel(module=self.communicators, device_ids=[self.rank])

    @property
    def parameters_model(self):
        parameters = list(self.actor_representation.parameters()) + list(self.actor.parameters()) + list(
            self.critic_representation.parameters()) + list(self.critic.parameters())
        if self.mixer is not None:
            parameters += list(self.mixer.parameters())
        parameters += list(self.communicators.parameters())
        return parameters

    def _get_actor_critic_input(self, dim_action, dim_actor_rep, dim_critic_rep, n_agents):
        """
        Returns the input dimensions of the actor and critic networks.
        Parameters:
            dim_action: The dimension of actions.
            dim_actor_rep: The dimension of the output of the actor representation.
            dim_critic_rep: The dimension of the output of the critic representation.
            n_agents: The number of agents.
        Returns:
            dim_actor_in: The dimension of the input of the actor networks.
            dim_actor_out: The dimension of the output of the actor networks.
            dim_critic_in: The dimension of the input of the critic networks.
            dim_critic_out: The dimension of the output of the critic networks.
        """
        dim_actor_in, dim_actor_out = dim_actor_rep + self.msg_dim, dim_action
        dim_critic_in, dim_critic_out = dim_critic_rep + self.msg_dim, dim_action
        if self.use_parameter_sharing:
            dim_actor_in += n_agents
            dim_critic_in += n_agents
        return dim_actor_in, dim_actor_out, dim_critic_in, dim_critic_out

    def forward(self, observation: Dict[str, Tensor], agent_ids: Optional[Tensor] = None,
                avail_actions: Dict[str, Tensor] = None, agent_key: str = None,
                rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Returns actions of the policy.
        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            avail_actions (Dict[str, Tensor]): Actions mask values, default is None.
            agent_key (str): Calculate actions for specified agent.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of actor representation.
        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of actor representation.
            pi_dists (dict): The stochastic policy distributions.
        """
        rnn_hidden_new, pi_dists = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]

        if avail_actions is not None:
            avail_actions = {key: Tensor(avail_actions[key]) for key in agent_list}

        for key in agent_list:
            if self.use_rnn:
                outputs = self.actor_representation[key](observation[key], *rnn_hidden[key])
                rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell'])
            else:
                outputs = self.actor_representation[key](observation[key])
                rnn_hidden_new[key] = [None, None]

            if self.use_parameter_sharing:
                communicator_input = torch.concat([outputs['state'], agent_ids], dim=-1)
            else:
                communicator_input = outputs['state']

            msg_to_send = self.communicators(communicator_input)
            msg_receive = self.communicators.receive(msg_to_send, key)

            if self.use_parameter_sharing:
                actor_input = torch.concat([outputs['state'], msg_receive, agent_ids], dim=-1)
            else:
                actor_input = torch.concat([outputs['state'], msg_receive], dim=-1)

            avail_actions_input = None if avail_actions is None else avail_actions[key]
            pi_dists[key] = self.actor[key](actor_input, avail_actions_input)
        return rnn_hidden_new, pi_dists

    def get_values(self, observation: Dict[str, Tensor], agent_ids: Tensor = None, agent_key: str = None,
                   rnn_hidden: Optional[Dict[str, List[Tensor]]] = None):
        """
        Get critic values via critic networks.
        Parameters:
            observation (Dict[str, Tensor]): The input observations for the policies.
            agent_ids (Tensor): The agents' ids (for parameter sharing).
            agent_key (str): Calculate values for the specified agent.
            rnn_hidden (Optional[Dict[str, List[Tensor]]]): The RNN hidden states of critic representation.
        Returns:
            rnn_hidden_new (Optional[Dict[str, List[Tensor]]]): The new RNN hidden states of critic representation.
            values (dict): The evaluated critic values.
        """
        rnn_hidden_new, values = {}, {}
        agent_list = self.model_keys if agent_key is None else [agent_key]

        for key in agent_list:
            if self.use_rnn:
                outputs = self.critic_representation[key](observation[key], *rnn_hidden[key])
                rnn_hidden_new[key] = (outputs['rnn_hidden'], outputs['rnn_cell'])
            else:
                outputs = self.critic_representation[key](observation[key])
                rnn_hidden_new[key] = [None, None]

            if self.use_parameter_sharing:
                communicator_input = torch.concat([outputs['state'], agent_ids], dim=-1)
            else:
                communicator_input = outputs['state']

            msg_to_send = self.communicators(communicator_input)
            msg_receive = self.communicators.receive(msg_to_send, key)

            if self.use_parameter_sharing:
                critic_input = torch.concat([outputs['state'], msg_receive, agent_ids], dim=-1)
            else:
                critic_input = torch.concat([outputs['state'], msg_receive], dim=-1)

            values[key] = self.critic[key](critic_input)

        return rnn_hidden_new, values


class COMA_Policy(Module):
    """
    COMA_Policy: Counterfactual Multi-Agent Actor-Critic Policy with categorical distributions.
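To make the data flow concrete: in MAAC_Policy_With_Communication, each agent's representation output (concatenated with its one-hot id under parameter sharing) is first encoded into a message by the communicator, a receive step aggregates the messages addressed to that agent (its implementation is not shown in this commit), and the received message is concatenated back onto the representation output before the actor or critic head. That is why _get_actor_critic_input adds msg_dim to both input dimensions. The standalone sketch below mirrors that flow with plain nn.Linear stand-ins and hypothetical sizes; it is not the xuance implementation, and the identity receive is only a placeholder.

import torch
import torch.nn as nn

# Hypothetical sizes for a single agent under parameter sharing.
n_agents, dim_state, msg_dim, dim_action, batch = 3, 64, 16, 5, 8

# Matches _get_actor_critic_input: representation output + received message + one-hot ids.
dim_actor_in = dim_state + msg_dim + n_agents

encoder = nn.Sequential(nn.Linear(dim_state + n_agents, 64), nn.ReLU(), nn.Linear(64, msg_dim))
actor = nn.Sequential(nn.Linear(dim_actor_in, 64), nn.ReLU(), nn.Linear(64, dim_action))

state = torch.randn(batch, dim_state)                 # stands in for outputs['state']
agent_ids = torch.eye(n_agents)[0].repeat(batch, 1)   # one-hot id of agent 0

communicator_input = torch.cat([state, agent_ids], dim=-1)
msg_to_send = encoder(communicator_input)             # encoded message, shape (8, 16)
msg_receive = msg_to_send                             # placeholder for communicators.receive(...)

actor_input = torch.cat([state, msg_receive, agent_ids], dim=-1)
logits = actor(actor_input)                           # categorical logits, shape (8, 5)
print(logits.shape)                                   # expected: torch.Size([8, 5])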
