sacdis for mindspore
wenzhangliu committed Dec 14, 2023
1 parent 9a49ebb commit bee5e11
Showing 4 changed files with 123 additions and 119 deletions.
6 changes: 3 additions & 3 deletions demo_mindspore.py
@@ -4,9 +4,9 @@

def parse_args():
parser = argparse.ArgumentParser("Run a demo.")
parser.add_argument("--method", type=str, default="qtran")
parser.add_argument("--env", type=str, default="mpe")
parser.add_argument("--env-id", type=str, default="simple_spread_v3")
parser.add_argument("--method", type=str, default="sac")
parser.add_argument("--env", type=str, default="classic_control")
parser.add_argument("--env-id", type=str, default="CartPole-v1")
parser.add_argument("--test", type=int, default=0)
parser.add_argument("--device", type=str, default="GPU")
parser.add_argument("--dl_toolbox", type=str, default="mindspore")
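For quick reference, the changed defaults mean a plain `python demo_mindspore.py` now runs discrete-action SAC on CartPole-v1 instead of QTRAN on the MPE simple_spread task. A minimal sketch of the resulting configuration (illustrative only; the Namespace below just mirrors the argparse defaults shown above):

from argparse import Namespace

# Equivalent to running the demo with no command-line flags after this change.
defaults = Namespace(method="sac", env="classic_control", env_id="CartPole-v1",
                     test=0, device="GPU", dl_toolbox="mindspore")
print(defaults.method, defaults.env, defaults.env_id)  # sac classic_control CartPole-v1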
200 changes: 97 additions & 103 deletions xuance/mindspore/agents/policy_gradient/sacdis_agent.py
@@ -5,21 +5,14 @@
class SACDIS_Agent(Agent):
def __init__(self,
config: Namespace,
envs: VecEnv,
envs: DummyVecEnv_Gym,
policy: nn.Cell,
optimizer: Sequence[nn.Optimizer],
scheduler):
self.config = config
self.comm = MPI.COMM_WORLD
self.nenvs = envs.num_envs
self.render = config.render
self.n_envs = envs.num_envs

self.gamma = config.gamma
self.use_obsnorm = config.use_obsnorm
self.use_rewnorm = config.use_rewnorm
self.obsnorm_range = config.obsnorm_range
self.rewnorm_range = config.rewnorm_range

self.train_frequency = config.training_frequency
self.start_training = config.start_training
self.start_noise = config.start_noise
@@ -28,111 +21,112 @@ def __init__(self,

self.observation_space = envs.observation_space
self.action_space = envs.action_space
self.representation_info_shape = policy.representation.output_shapes
self.auxiliary_info_shape = {}

writer = SummaryWriter(config.logdir)
memory = DummyOffPolicyBuffer(self.observation_space,
self.action_space,
self.representation_info_shape,
self.auxiliary_info_shape,
self.nenvs,
config.nsize,
config.batchsize)
self.atari = True if config.env_name == "Atari" else False
Buffer = DummyOffPolicyBuffer_Atari if self.atari else DummyOffPolicyBuffer
memory = Buffer(self.observation_space,
self.action_space,
self.auxiliary_info_shape,
self.n_envs,
config.n_size,
config.batch_size)
learner = SACDIS_Learner(policy,
optimizer,
scheduler,
writer,
config.modeldir,
config.gamma,
config.tau)

self.obs_rms = RunningMeanStd(shape=space2shape(self.observation_space), comm=self.comm, use_mpi=False)
self.ret_rms = RunningMeanStd(shape=(), comm=self.comm, use_mpi=False)
super(SACDIS_Agent, self).__init__(envs, policy, memory, learner, writer, config.logdir, config.modeldir)

def _process_observation(self, observations):
if self.use_obsnorm:
if isinstance(self.observation_space, gym.spaces.Dict):
for key in self.observation_space.spaces.keys():
observations[key] = np.clip(
(observations[key] - self.obs_rms.mean[key]) / (self.obs_rms.std[key] + EPS),
-self.obsnorm_range, self.obsnorm_range)
else:
observations = np.clip((observations - self.obs_rms.mean) / (self.obs_rms.std + EPS),
-self.obsnorm_range, self.obsnorm_range)
return observations
return observations

def _process_reward(self, rewards):
if self.use_rewnorm:
std = np.clip(self.ret_rms.std, 0.1, 100)
return np.clip(rewards / std, -self.rewnorm_range, self.rewnorm_range)
return rewards
optimizer,
scheduler,
config.model_dir,
config.gamma,
config.tau)
super(SACDIS_Agent, self).__init__(config, envs, policy, memory, learner, config.log_dir, config.model_dir)

def _action(self, obs):
states, act_probs = self.policy(ms.Tensor(obs))
_, act_probs = self.policy(ms.Tensor(obs))
acts = self.policy.actor.sample(act_probs).asnumpy()
return acts

if context._get_mode() == 0:
return {"state": states[0].asnumpy()}, acts
else:
for key in states.keys():
states[key] = states[key].asnumpy()
return states, acts

def train(self, train_steps=10000):
episodes = np.zeros((self.nenvs,), np.int32)
scores = np.zeros((self.nenvs,), np.float32)
returns = np.zeros((self.nenvs,), np.float32)

obs = self.envs.reset()
for step in tqdm(range(train_steps)):
def train(self, train_steps):
obs = self.envs.buf_obs
for _ in tqdm(range(train_steps)):
step_info = {}
self.obs_rms.update(obs)
obs = self._process_observation(obs)
states, acts = self._action(obs)
# if step < self.start_training:
# acts = np.clip(np.random.randn(self.nenvs, self.action_space.shape[0]), -1, 1)
next_obs, rewards, dones, infos = self.envs.step(acts)
if self.render: self.envs.render()
self.memory.store(obs, acts, self._process_reward(rewards), dones, self._process_observation(next_obs),
states, {})
if step > self.start_training and step % self.train_frequency == 0:
obs_batch, act_batch, rew_batch, terminal_batch, next_batch, _, _ = self.memory.sample()
self.learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
scores += rewards
returns = self.gamma * returns + rewards
obs = next_obs
self.noise_scale = self.start_noise - (self.start_noise - self.end_noise) / train_steps
for i in range(self.nenvs):
if dones[i] == True:
self.ret_rms.update(returns[i:i + 1])
self.writer.add_scalars("returns-episode", {"env-%d" % i: scores[i]}, episodes[i])
self.writer.add_scalars("returns-step", {"env-%d" % i: scores[i]}, step)
scores[i] = 0
returns[i] = 0
episodes[i] += 1
acts = self._action(obs)
next_obs, rewards, terminals, trunctions, infos = self.envs.step(acts)
self.memory.store(obs, acts, self._process_reward(rewards), terminals, self._process_observation(next_obs))
if self.current_step > self.start_training and self.current_step % self.train_frequency == 0:
obs_batch, act_batch, rew_batch, terminal_batch, next_batch = self.memory.sample()
step_info = self.learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
self.log_infos(step_info, self.current_step)

if step % 50000 == 0 or step == train_steps - 1:
self.save_model()
np.save(self.modeldir + "/obs_rms.npy",
{'mean': self.obs_rms.mean, 'std': self.obs_rms.std, 'count': self.obs_rms.count})
self.returns = self.gamma * self.returns + rewards
obs = next_obs
for i in range(self.n_envs):
if terminals[i] or trunctions[i]:
if self.atari and (~trunctions[i]):
pass
else:
obs[i] = infos[i]["reset_obs"]
self.ret_rms.update(self.returns[i:i + 1])
self.returns[i] = 0.0
self.current_episode[i] += 1
if self.use_wandb:
step_info["Episode-Steps/env-%d" % i] = infos[i]["episode_step"]
step_info["Train-Episode-Rewards/env-%d" % i] = infos[i]["episode_score"]
else:
step_info["Episode-Steps"] = {"env-%d" % i: infos[i]["episode_step"]}
step_info["Train-Episode-Rewards"] = {"env-%d" % i: infos[i]["episode_score"]}
self.log_infos(step_info, self.current_step)
self.current_step += self.n_envs

def test(self, test_steps=10000, load_model=None):
self.load_model(self.modeldir)
scores = np.zeros((self.nenvs,), np.float32)
returns = np.zeros((self.nenvs,), np.float32)
def test(self, env_fn, test_episodes):
test_envs = env_fn()
num_envs = test_envs.num_envs
videos, episode_videos = [[] for _ in range(num_envs)], []
current_episode, scores, best_score = 0, [], -np.inf
obs, infos = test_envs.reset()
if self.config.render_mode == "rgb_array" and self.render:
images = test_envs.render(self.config.render_mode)
for idx, img in enumerate(images):
videos[idx].append(img)

obs = self.envs.reset()
for _ in tqdm(range(test_steps)):
while current_episode < test_episodes:
self.obs_rms.update(obs)
obs = self._process_observation(obs)
states, acts = self._action(obs)
next_obs, rewards, dones, infos = self.envs.step(acts)
self.envs.render()
scores += rewards
returns = self.gamma * returns + rewards
acts = self._action(obs)
next_obs, rewards, terminals, trunctions, infos = test_envs.step(acts)
if self.config.render_mode == "rgb_array" and self.render:
images = test_envs.render(self.config.render_mode)
for idx, img in enumerate(images):
videos[idx].append(img)

obs = next_obs
for i in range(self.nenvs):
if dones[i] == True:
scores[i], returns[i] = 0, 0
for i in range(num_envs):
if terminals[i] or trunctions[i]:
if self.atari and (~trunctions[i]):
pass
else:
obs[i] = infos[i]["reset_obs"]
scores.append(infos[i]["episode_score"])
current_episode += 1
if best_score < infos[i]["episode_score"]:
best_score = infos[i]["episode_score"]
episode_videos = videos[i].copy()
if self.config.test_mode:
print("Episode: %d, Score: %.2f" % (current_episode, infos[i]["episode_score"]))

if self.config.render_mode == "rgb_array" and self.render:
# time, height, width, channel -> time, channel, height, width
videos_info = {"Videos_Test": np.array([episode_videos], dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
self.log_videos(info=videos_info, fps=50, x_index=self.current_step)

if self.config.test_mode:
print("Best Score: %.2f" % (best_score))

test_info = {
"Test-Episode-Rewards/Mean-Score": np.mean(scores),
"Test-Episode-Rewards/Std-Score": np.std(scores)
}
self.log_infos(test_info, self.current_step)

test_envs.close()

return scores
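For context, `_action` above asks the policy for per-action probabilities and samples one discrete action per parallel environment from the resulting categorical distribution. A minimal NumPy sketch of that sampling step (illustrative only; the real agent does this through the MindSpore policy's `actor.sample`):

import numpy as np

def sample_discrete_actions(act_probs: np.ndarray) -> np.ndarray:
    # act_probs: shape (n_envs, n_actions), each row a probability distribution over actions.
    n_envs, n_actions = act_probs.shape
    acts = np.empty(n_envs, dtype=np.int64)
    for i in range(n_envs):
        acts[i] = np.random.choice(n_actions, p=act_probs[i])
    return acts

# Example: 2 parallel envs, 3 discrete actions.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
print(sample_discrete_actions(probs))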
22 changes: 13 additions & 9 deletions xuance/mindspore/learners/policy_gradient/sacdis_learner.py
@@ -8,7 +8,7 @@ def __init__(self, backbone):
self._backbone = backbone

def construct(self, x):
_, action_prob, log_pi, policy_q = self._backbone.Qpolicy(x)
action_prob, log_pi, policy_q = self._backbone.Qpolicy(x)
inside_term = 0.01 * log_pi - policy_q
p_loss = (action_prob * inside_term).sum(axis=1).mean()
return p_loss
@@ -30,13 +30,12 @@
policy: nn.Cell,
optimizers: nn.Optimizer,
schedulers: Optional[nn.exponential_decay_lr] = None,
summary_writer: Optional[SummaryWriter] = None,
modeldir: str = "./",
model_dir: str = "./",
gamma: float = 0.99,
tau: float = 0.01):
self.tau = tau
self.gamma = gamma
super(SACDIS_Learner, self).__init__(policy, optimizers, schedulers, summary_writer, modeldir)
super(SACDIS_Learner, self).__init__(policy, optimizers, schedulers, model_dir)
# define mindspore trainers
self.actor_loss_net = self.ActorNetWithLossCell(policy)
self.actor_train = nn.TrainOneStepCell(self.actor_loss_net, optimizers['actor'])
@@ -55,7 +54,7 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
ter_batch = Tensor(terminal_batch).view(-1, 1)
act_batch = self._unsqueeze(act_batch, -1)

_, action_prob_next, log_pi_next, target_q = self.policy.Qtarget(next_batch)
action_prob_next, log_pi_next, target_q = self.policy.Qtarget(next_batch)
target_q = action_prob_next * (target_q - 0.01 * log_pi_next)
target_q = self._unsqueeze(target_q.sum(axis=1), -1)
rew = self._unsqueeze(rew_batch, -1)
@@ -68,7 +67,12 @@ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):

actor_lr = self.scheduler['actor'](self.iterations).asnumpy()
critic_lr = self.scheduler['critic'](self.iterations).asnumpy()
self.writer.add_scalar("Qloss", q_loss.asnumpy(), self.iterations)
self.writer.add_scalar("Ploss", p_loss.asnumpy(), self.iterations)
self.writer.add_scalar("actor_lr", actor_lr, self.iterations)
self.writer.add_scalar("critic_lr", critic_lr, self.iterations)

info = {
"Qloss": q_loss.asnumpy(),
"Ploss": p_loss.asnumpy(),
"actor_lr": actor_lr,
"critic_lr": critic_lr
}

return info
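To make the update above easier to follow: the critic target uses the soft value of the next state, V(s') = sum_a pi(a|s') * (Q_target(s',a) - alpha * log pi(a|s')), and the actor minimizes sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s,a)), with the entropy coefficient alpha hard-coded to 0.01 in this learner. A minimal NumPy sketch of those two quantities (a sketch of the usual discrete-SAC formulation, not the MindSpore cells above; the (1 - done) masking is assumed, since that part of the hunk is collapsed):

import numpy as np

ALPHA, GAMMA = 0.01, 0.99  # entropy coefficient and discount, matching the values above

def sacdis_losses(rew, done, act_prob_next, log_pi_next, target_q_next,
                  q_taken, act_prob, log_pi, policy_q):
    # Soft value of the next state: V(s') = sum_a pi(a|s') * (Q_target(s',a) - alpha * log pi(a|s'))
    v_next = (act_prob_next * (target_q_next - ALPHA * log_pi_next)).sum(axis=1, keepdims=True)
    backup = rew + (1.0 - done) * GAMMA * v_next      # TD target for the taken action
    q_loss = np.mean((q_taken - backup) ** 2)         # critic regression loss
    # Policy loss: E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s,a)) ]
    p_loss = np.mean((act_prob * (ALPHA * log_pi - policy_q)).sum(axis=1))
    return q_loss, p_loss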
14 changes: 10 additions & 4 deletions xuance/mindspore/policies/categorical.py
@@ -257,6 +257,7 @@ def __init__(self,
super(SACDISPolicy, self).__init__()
self.action_dim = action_space.n
self.representation = representation
self.representation_critic = copy.deepcopy(representation)
self.representation_info_shape = self.representation.output_shapes
try:
self.representation_params = self.representation.trainable_params()
@@ -267,6 +268,7 @@
normalize, initialize, activation)
self.critic = CriticNet_SACDIS(representation.output_shapes['state'][0], self.action_dim, critic_hidden_size,
initialize, activation)
self.target_representation_critic = copy.deepcopy(self.representation_critic)
self.target_critic = copy.deepcopy(self.critic)
self.actor_params = self.representation_params + self.actor.trainable_params()
self._log = ms.ops.Log()
@@ -283,20 +285,24 @@ def action(self, observation: ms.tensor):

def Qtarget(self, observation: ms.tensor):
outputs = self.representation(observation)
act_prob = self.actor(outputs[0])
outputs_critic = self.target_representation_critic(observation)
act_prob = self.actor(outputs['state'])
log_action_prob = self._log(act_prob + 1e-10)
return outputs, act_prob, log_action_prob, self.target_critic(outputs[0])
return act_prob, log_action_prob, self.target_critic(outputs_critic['state'])

def Qaction(self, observation: ms.tensor):
outputs = self.representation(observation)
outputs = self.representation_critic(observation)
return outputs, self.critic(outputs['state'])

def Qpolicy(self, observation: ms.tensor):
outputs = self.representation(observation)
outputs_critic = self.representation_critic(observation)
act_prob = self.actor(outputs['state'])
log_action_prob = self._log(act_prob + 1e-10)
return outputs, act_prob, log_action_prob, self.critic(outputs['state'])
return act_prob, log_action_prob, self.critic(outputs_critic['state'])

def soft_update(self, tau=0.005):
for ep, tp in zip(self.representation_critic.trainable_params(), self.target_representation_critic.trainable_params()):
tp.assign_value((tau * ep.data + (1 - tau) * tp.data))
for ep, tp in zip(self.critic.trainable_params(), self.target_critic.trainable_params()):
tp.assign_value((tau * ep.data + (1 - tau) * tp.data))
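The `soft_update` above (now also covering the separate critic representation) is standard Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target. A minimal NumPy sketch of the same rule on plain arrays (illustrative; MindSpore parameters are updated via `assign_value` as in the code above):

import numpy as np

def soft_update(params, target_params, tau=0.005):
    # In-place Polyak averaging: target <- tau * source + (1 - tau) * target
    for p, tp in zip(params, target_params):
        tp *= (1.0 - tau)
        tp += tau * p

w, w_target = [np.ones(3)], [np.zeros(3)]
soft_update(w, w_target, tau=0.01)
print(w_target[0])  # [0.01 0.01 0.01]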
