2
2
# ProrokLab (https://www.proroklab.org/)
3
3
# All rights reserved.
4
4
import contextlib
5
- import functools
6
5
import math
7
6
import random
8
7
from ctypes import byref
@@ -47,22 +46,6 @@ def local_seed(vmas_random_state):
47
46
random .setstate (py_state )
48
47
49
48
50
- def apply_local_seed (cls ):
51
- """Applies the local seed to all the functions."""
52
- for attr_name , attr_value in cls .__dict__ .items ():
53
- if callable (attr_value ):
54
- wrapped = attr_value # Keep reference to original method
55
-
56
- @functools .wraps (wrapped )
57
- def wrapper (self , * args , _wrapped = wrapped , ** kwargs ):
58
- with local_seed (cls .vmas_random_state ):
59
- return _wrapped (self , * args , ** kwargs )
60
-
61
- setattr (cls , attr_name , wrapper )
62
- return cls
63
-
64
-
65
- @apply_local_seed
66
49
class Environment (TorchVectorizedObject ):
67
50
metadata = {
68
51
"render.modes" : ["human" , "rgb_array" ],
@@ -74,6 +57,7 @@ class Environment(TorchVectorizedObject):
74
57
random .getstate (),
75
58
]
76
59
60
+ @local_seed (vmas_random_state )
77
61
def __init__ (
78
62
self ,
79
63
scenario : BaseScenario ,
@@ -108,7 +92,7 @@ def __init__(
108
92
self .grad_enabled = grad_enabled
109
93
self .terminated_truncated = terminated_truncated
110
94
111
- observations = self .reset (seed = seed )
95
+ observations = self ._reset (seed = seed )
112
96
113
97
# configure spaces
114
98
self .multidiscrete_actions = multidiscrete_actions
@@ -121,6 +105,7 @@ def __init__(
121
105
self .visible_display = None
122
106
self .text_lines = None
123
107
108
+ @local_seed (vmas_random_state )
124
109
def reset (
125
110
self ,
126
111
seed : Optional [int ] = None ,
@@ -132,21 +117,112 @@ def reset(
132
117
Resets the environment in a vectorized way
133
118
Returns observations for all envs and agents
134
119
"""
120
+ return self ._reset (
121
+ seed = seed ,
122
+ return_observations = return_observations ,
123
+ return_info = return_info ,
124
+ return_dones = return_dones ,
125
+ )
126
+
127
+ @local_seed (vmas_random_state )
128
+ def reset_at (
129
+ self ,
130
+ index : int ,
131
+ return_observations : bool = True ,
132
+ return_info : bool = False ,
133
+ return_dones : bool = False ,
134
+ ):
135
+ """
136
+ Resets the environment at index
137
+ Returns observations for all agents in that environment
138
+ """
139
+ return self ._reset_at (
140
+ index = index ,
141
+ return_observations = return_observations ,
142
+ return_info = return_info ,
143
+ return_dones = return_dones ,
144
+ )
145
+
146
+ @local_seed (vmas_random_state )
147
+ def get_from_scenario (
148
+ self ,
149
+ get_observations : bool ,
150
+ get_rewards : bool ,
151
+ get_infos : bool ,
152
+ get_dones : bool ,
153
+ dict_agent_names : Optional [bool ] = None ,
154
+ ):
155
+ """
156
+ Get the environment data from the scenario
157
+
158
+ Args:
159
+ get_observations (bool): whether to return the observations
160
+ get_rewards (bool): whether to return the rewards
161
+ get_infos (bool): whether to return the infos
162
+ get_dones (bool): whether to return the dones
163
+ dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
164
+ or in a list
165
+
166
+ Returns:
167
+ The agents' data
168
+
169
+ """
170
+ return self ._get_from_scenario (
171
+ get_observations = get_observations ,
172
+ get_rewards = get_rewards ,
173
+ get_infos = get_infos ,
174
+ get_dones = get_dones ,
175
+ dict_agent_names = dict_agent_names ,
176
+ )
177
+
178
+ @local_seed (vmas_random_state )
179
+ def seed (self , seed = None ):
180
+ """
181
+ Sets the seed for the environment
182
+ Args:
183
+ seed (int, optional): Seed for the environment. Defaults to None.
184
+
185
+ """
186
+ return self ._seed (seed = seed )
187
+
188
+ @local_seed (vmas_random_state )
189
+ def done (self ):
190
+ """
191
+ Get the done flags for the scenario.
192
+
193
+ Returns:
194
+ Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
195
+
196
+ """
197
+ return self ._done ()
198
+
199
+ def _reset (
200
+ self ,
201
+ seed : Optional [int ] = None ,
202
+ return_observations : bool = True ,
203
+ return_info : bool = False ,
204
+ return_dones : bool = False ,
205
+ ):
206
+ """
207
+ Resets the environment in a vectorized way
208
+ Returns observations for all envs and agents
209
+ """
210
+
135
211
if seed is not None :
136
- self .seed (seed )
212
+ self ._seed (seed )
137
213
# reset world
138
214
self .scenario .env_reset_world_at (env_index = None )
139
215
self .steps = torch .zeros (self .num_envs , device = self .device )
140
216
141
- result = self .get_from_scenario (
217
+ result = self ._get_from_scenario (
142
218
get_observations = return_observations ,
143
219
get_infos = return_info ,
144
220
get_rewards = False ,
145
221
get_dones = return_dones ,
146
222
)
147
223
return result [0 ] if result and len (result ) == 1 else result
148
224
149
- def reset_at (
225
+ def _reset_at (
150
226
self ,
151
227
index : int ,
152
228
return_observations : bool = True ,
@@ -161,7 +237,7 @@ def reset_at(
161
237
self .scenario .env_reset_world_at (index )
162
238
self .steps [index ] = 0
163
239
164
- result = self .get_from_scenario (
240
+ result = self ._get_from_scenario (
165
241
get_observations = return_observations ,
166
242
get_infos = return_info ,
167
243
get_rewards = False ,
@@ -170,7 +246,7 @@ def reset_at(
170
246
171
247
return result [0 ] if result and len (result ) == 1 else result
172
248
173
- def get_from_scenario (
249
+ def _get_from_scenario (
174
250
self ,
175
251
get_observations : bool ,
176
252
get_rewards : bool ,
@@ -218,23 +294,30 @@ def get_from_scenario(
218
294
219
295
if self .terminated_truncated :
220
296
if get_dones :
221
- terminated , truncated = self .done ()
297
+ terminated , truncated = self ._done ()
222
298
result = [obs , rewards , terminated , truncated , infos ]
223
299
else :
224
300
if get_dones :
225
- dones = self .done ()
301
+ dones = self ._done ()
226
302
result = [obs , rewards , dones , infos ]
227
303
228
304
return [data for data in result if data is not None ]
229
305
230
- def seed (self , seed = None ):
306
+ def _seed (self , seed = None ):
307
+ """
308
+ Sets the seed for the environment
309
+ Args:
310
+ seed (int, optional): Seed for the environment. Defaults to None.
311
+
312
+ """
231
313
if seed is None :
232
314
seed = 0
233
315
torch .manual_seed (seed )
234
316
np .random .seed (seed )
235
317
random .seed (seed )
236
318
return [seed ]
237
319
320
+ @local_seed (vmas_random_state )
238
321
def step (self , actions : Union [List , Dict ]):
239
322
"""Performs a vectorized step on all sub environments using `actions`.
240
323
Args:
@@ -309,14 +392,21 @@ def step(self, actions: Union[List, Dict]):
309
392
310
393
self .steps += 1
311
394
312
- return self .get_from_scenario (
395
+ return self ._get_from_scenario (
313
396
get_observations = True ,
314
397
get_infos = True ,
315
398
get_rewards = True ,
316
399
get_dones = True ,
317
400
)
318
401
319
- def done (self ):
402
+ def _done (self ):
403
+ """
404
+ Get the done flags for the scenario.
405
+
406
+ Returns:
407
+ Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
408
+
409
+ """
320
410
terminated = self .scenario .done ().clone ()
321
411
322
412
if self .max_steps is not None :
@@ -427,6 +517,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
427
517
f"Invalid type of observation { obs } for agent { agent .name } "
428
518
)
429
519
520
+ @local_seed (vmas_random_state )
430
521
def get_random_action (self , agent : Agent ) -> torch .Tensor :
431
522
"""Returns a random action for the given agent.
432
523
@@ -652,6 +743,7 @@ def _set_action(self, action, agent):
652
743
)
653
744
agent .action .c += noise
654
745
746
+ @local_seed (vmas_random_state )
655
747
def render (
656
748
self ,
657
749
mode = "human" ,
0 commit comments