Skip to content

Commit 0212591

Browse files
committed
local seed
1 parent 0e23acd commit 0212591

File tree

2 files changed

+139
-29
lines changed

2 files changed

+139
-29
lines changed

tests/test_vmas.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2024.
1+
# Copyright (c) 2022-2025.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import math
@@ -302,3 +302,21 @@ def test_vmas_differentiable(scenario, n_steps=10, n_envs=10):
302302

303303
loss = obs[-1].mean() + rews[-1].mean()
304304
grad = torch.autograd.grad(loss, first_action)
305+
306+
307+
def test_seeding():
    """Check that env seeding is reproducible and independent of torch's global RNG.

    Part one: re-seeding the env with the same value yields the same sampled
    reset value, even after the global torch seed has been changed.
    Part two: seeding and resetting the env does not change the next draw
    from the global torch generator.
    """
    env = make_env(scenario="balance", num_envs=2, seed=0)

    def sample_after_reset():
        # Element [0][0, 0] of the reset output
        # (presumably agent 0, env 0, first observation feature -- TODO confirm).
        return env.reset()[0][0, 0]

    # Same env seed -> same value after reset.
    env.seed(0)
    reference = sample_after_reset()
    env.seed(0)
    assert reference == sample_after_reset()

    # Changing the global torch seed in between must not affect the env.
    env.seed(0)
    torch.manual_seed(1)
    assert reference == sample_after_reset()

    # Conversely, env seeding/resetting must not disturb the global torch stream.
    torch.manual_seed(0)
    expected_draw = torch.randn(1)
    torch.manual_seed(0)
    env.seed(1)
    env.reset()
    assert expected_draw == torch.randn(1)

vmas/simulator/environment/environment.py

Lines changed: 120 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44
import contextlib
5-
import functools
65
import math
76
import random
87
from ctypes import byref
@@ -47,22 +46,6 @@ def local_seed(vmas_random_state):
4746
random.setstate(py_state)
4847

4948

50-
def apply_local_seed(cls):
51-
"""Applies the local seed to all the functions."""
52-
for attr_name, attr_value in cls.__dict__.items():
53-
if callable(attr_value):
54-
wrapped = attr_value # Keep reference to original method
55-
56-
@functools.wraps(wrapped)
57-
def wrapper(self, *args, _wrapped=wrapped, **kwargs):
58-
with local_seed(cls.vmas_random_state):
59-
return _wrapped(self, *args, **kwargs)
60-
61-
setattr(cls, attr_name, wrapper)
62-
return cls
63-
64-
65-
@apply_local_seed
6649
class Environment(TorchVectorizedObject):
6750
metadata = {
6851
"render.modes": ["human", "rgb_array"],
@@ -74,6 +57,7 @@ class Environment(TorchVectorizedObject):
7457
random.getstate(),
7558
]
7659

60+
@local_seed(vmas_random_state)
7761
def __init__(
7862
self,
7963
scenario: BaseScenario,
@@ -108,7 +92,7 @@ def __init__(
10892
self.grad_enabled = grad_enabled
10993
self.terminated_truncated = terminated_truncated
11094

111-
observations = self.reset(seed=seed)
95+
observations = self._reset(seed=seed)
11296

11397
# configure spaces
11498
self.multidiscrete_actions = multidiscrete_actions
@@ -121,6 +105,7 @@ def __init__(
121105
self.visible_display = None
122106
self.text_lines = None
123107

108+
@local_seed(vmas_random_state)
124109
def reset(
125110
self,
126111
seed: Optional[int] = None,
@@ -132,21 +117,112 @@ def reset(
132117
Resets the environment in a vectorized way
133118
Returns observations for all envs and agents
134119
"""
120+
return self._reset(
121+
seed=seed,
122+
return_observations=return_observations,
123+
return_info=return_info,
124+
return_dones=return_dones,
125+
)
126+
127+
@local_seed(vmas_random_state)
def reset_at(
    self,
    index: int,
    return_observations: bool = True,
    return_info: bool = False,
    return_dones: bool = False,
):
    """
    Resets the environment at index
    Returns observations for all agents in that environment

    Public entry point: runs under ``local_seed(vmas_random_state)`` and
    forwards all arguments unchanged to ``_reset_at``, which does the work.
    """
    forwarded = {
        "index": index,
        "return_observations": return_observations,
        "return_info": return_info,
        "return_dones": return_dones,
    }
    return self._reset_at(**forwarded)
145+
146+
@local_seed(vmas_random_state)
def get_from_scenario(
    self,
    get_observations: bool,
    get_rewards: bool,
    get_infos: bool,
    get_dones: bool,
    dict_agent_names: Optional[bool] = None,
):
    """
    Get the environment data from the scenario

    Public entry point: runs under ``local_seed(vmas_random_state)`` and
    forwards all arguments unchanged to ``_get_from_scenario``.

    Args:
        get_observations (bool): whether to return the observations
        get_rewards (bool): whether to return the rewards
        get_infos (bool): whether to return the infos
        get_dones (bool): whether to return the dones
        dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
            or in a list

    Returns:
        The agents' data

    """
    query = {
        "get_observations": get_observations,
        "get_rewards": get_rewards,
        "get_infos": get_infos,
        "get_dones": get_dones,
        "dict_agent_names": dict_agent_names,
    }
    return self._get_from_scenario(**query)
177+
178+
@local_seed(vmas_random_state)
def seed(self, seed=None):
    """
    Sets the seed for the environment

    Public entry point: runs under ``local_seed(vmas_random_state)`` and
    delegates to ``_seed``.

    Args:
        seed (int, optional): Seed for the environment. Defaults to None.

    Returns:
        Whatever ``_seed`` returns for the given seed.

    """
    applied = self._seed(seed=seed)
    return applied
187+
188+
@local_seed(vmas_random_state)
def done(self):
    """
    Get the done flags for the scenario.

    Public entry point: runs under ``local_seed(vmas_random_state)`` and
    delegates to ``_done``.

    Returns:
        Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)

    """
    flags = self._done()
    return flags
198+
199+
def _reset(
200+
self,
201+
seed: Optional[int] = None,
202+
return_observations: bool = True,
203+
return_info: bool = False,
204+
return_dones: bool = False,
205+
):
206+
"""
207+
Resets the environment in a vectorized way
208+
Returns observations for all envs and agents
209+
"""
210+
135211
if seed is not None:
136-
self.seed(seed)
212+
self._seed(seed)
137213
# reset world
138214
self.scenario.env_reset_world_at(env_index=None)
139215
self.steps = torch.zeros(self.num_envs, device=self.device)
140216

141-
result = self.get_from_scenario(
217+
result = self._get_from_scenario(
142218
get_observations=return_observations,
143219
get_infos=return_info,
144220
get_rewards=False,
145221
get_dones=return_dones,
146222
)
147223
return result[0] if result and len(result) == 1 else result
148224

149-
def reset_at(
225+
def _reset_at(
150226
self,
151227
index: int,
152228
return_observations: bool = True,
@@ -161,7 +237,7 @@ def reset_at(
161237
self.scenario.env_reset_world_at(index)
162238
self.steps[index] = 0
163239

164-
result = self.get_from_scenario(
240+
result = self._get_from_scenario(
165241
get_observations=return_observations,
166242
get_infos=return_info,
167243
get_rewards=False,
@@ -170,7 +246,7 @@ def reset_at(
170246

171247
return result[0] if result and len(result) == 1 else result
172248

173-
def get_from_scenario(
249+
def _get_from_scenario(
174250
self,
175251
get_observations: bool,
176252
get_rewards: bool,
@@ -218,23 +294,30 @@ def get_from_scenario(
218294

219295
if self.terminated_truncated:
220296
if get_dones:
221-
terminated, truncated = self.done()
297+
terminated, truncated = self._done()
222298
result = [obs, rewards, terminated, truncated, infos]
223299
else:
224300
if get_dones:
225-
dones = self.done()
301+
dones = self._done()
226302
result = [obs, rewards, dones, infos]
227303

228304
return [data for data in result if data is not None]
229305

230-
def seed(self, seed=None):
306+
def _seed(self, seed=None):
307+
"""
308+
Sets the seed for the environment
309+
Args:
310+
seed (int, optional): Seed for the environment. Defaults to None.
311+
312+
"""
231313
if seed is None:
232314
seed = 0
233315
torch.manual_seed(seed)
234316
np.random.seed(seed)
235317
random.seed(seed)
236318
return [seed]
237319

320+
@local_seed(vmas_random_state)
238321
def step(self, actions: Union[List, Dict]):
239322
"""Performs a vectorized step on all sub environments using `actions`.
240323
Args:
@@ -309,14 +392,21 @@ def step(self, actions: Union[List, Dict]):
309392

310393
self.steps += 1
311394

312-
return self.get_from_scenario(
395+
return self._get_from_scenario(
313396
get_observations=True,
314397
get_infos=True,
315398
get_rewards=True,
316399
get_dones=True,
317400
)
318401

319-
def done(self):
402+
def _done(self):
403+
"""
404+
Get the done flags for the scenario.
405+
406+
Returns:
407+
Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
408+
409+
"""
320410
terminated = self.scenario.done().clone()
321411

322412
if self.max_steps is not None:
@@ -427,6 +517,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
427517
f"Invalid type of observation {obs} for agent {agent.name}"
428518
)
429519

520+
@local_seed(vmas_random_state)
430521
def get_random_action(self, agent: Agent) -> torch.Tensor:
431522
"""Returns a random action for the given agent.
432523
@@ -652,6 +743,7 @@ def _set_action(self, action, agent):
652743
)
653744
agent.action.c += noise
654745

746+
@local_seed(vmas_random_state)
655747
def render(
656748
self,
657749
mode="human",

0 commit comments

Comments
 (0)