You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Fetch the current invalid-action masks from an environment.

    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """
    if isinstance(env, VecEnv):
        # Relies on the vectorized env answering ``get_attr`` for this name
        # with the masks themselves (VecMaskWrapper.get_attr does exactly that).
        return env.get_attr(EXPECTED_METHOD_NAME)

    masks = getattr(env, EXPECTED_METHOD_NAME)
    # On a plain env the name normally refers to a method; returning it
    # uncalled would hand the caller a bound method instead of an array.
    return masks() if callable(masks) else masks
def is_masking_supported(env: GymEnv) -> bool:
    """
    Tell whether ``env`` exposes the attribute used to obtain action masks.

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """
    # Plain (non-vectorized) env: a simple attribute lookup is enough.
    if not isinstance(env, VecEnv):
        return hasattr(env, EXPECTED_METHOD_NAME)

    # Vectorized env: probe through get_attr and treat AttributeError as
    # "not supported".
    try:
        env.get_attr(EXPECTED_METHOD_NAME)
    except AttributeError:
        return False
    return True
Add a VecMaskWrapper: masking invalid actions
from typing import List

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
import numpy as np

EXPECTED_METHOD_NAME = 'action_masks'


class VecMaskWrapper(VecEnvWrapper):
    """
    Vectorized wrapper that computes invalid-action masks for a whole batch of
    environments at once, from the most recent observations.

    The masking rule is written for CartPole
    (https://www.gymlibrary.dev/environments/classic_control/cart_pole/):
    the pole angle can be observed between (-.418, .418) radians (or ±24°),
    but the episode terminates if the pole angle is not in the range
    (-.2095, .2095) (or ±12°).

    :param venv: the vectorized environment to wrap; its action space must
        expose ``n`` (NOTE(review): i.e. presumably a Discrete space)
    """

    def __init__(
        self,
        venv: VecEnv,
    ):
        VecEnvWrapper.__init__(self, venv)
        n_actions = venv.action_space.n
        self.num_envs = venv.num_envs
        self.possible_actions = np.arange(n_actions)
        self.all_valid_mask = np.ones((self.num_envs, n_actions), dtype=np.bool_)
        # No observation exists until the first reset(). Keeping an explicit
        # None here lets action_masks() answer instead of raising
        # AttributeError, which is_masking_supported() would otherwise
        # misread as "masking not supported".
        self.observations = None

    def reset(self) -> VecEnvObs:
        """Reset the wrapped envs and remember the resulting observations."""
        self.observations = self.venv.reset()
        return self.observations

    def action_masks(self) -> np.ndarray:
        """
        Build the boolean action masks for every sub-environment.

        Observation column 2 is read as the pole angle. Once it reaches
        -0.05 or lower, action 1 is masked out; once it reaches 0.05 or
        higher, action 0 is masked out. Before the first ``reset()`` every
        action is reported as valid.

        :return: boolean array with the same shape as ``all_valid_mask``,
            True where the action is allowed
        """
        masks = np.ones_like(self.all_valid_mask, dtype=np.bool_)
        if self.observations is None:
            return masks

        # assumes observations is a 2-D array (one row per env) — TODO confirm
        # for observation spaces other than CartPole's.
        pole_angle = self.observations[:, 2]
        # Angle at or past the negative threshold ("left" in the original
        # comments): forbid action 1.
        masks[pole_angle <= -0.05, 1] = False
        # Angle at or past the positive threshold ("right"): forbid action 0.
        masks[pole_angle >= 0.05, 0] = False
        return masks

    def get_attr(self, attr_name, indices=None):
        """
        Return an attribute of the wrapped envs, special-casing the mask name.

        For ``EXPECTED_METHOD_NAME`` the batched masks computed by this
        wrapper are returned (restricted to ``indices`` when given); every
        other name is delegated to the parent implementation.

        :param attr_name: name of the attribute to fetch
        :param indices: None for all envs, a single int, or an iterable of ints
        :return: the masks array, or whatever the parent ``get_attr`` returns
        """
        if attr_name != EXPECTED_METHOD_NAME:
            return super().get_attr(attr_name, indices)

        masks = self.action_masks()
        if indices is None:
            return masks
        if isinstance(indices, int):
            indices = [indices]
        # Honour the requested subset instead of always returning every row.
        return masks[list(indices)]

    def step_wait(self) -> VecEnvStepReturn:
        """Finish the pending step and cache the new observations for masking."""
        self.observations, rews, dones, infos = self.venv.step_wait()
        return self.observations, rews, dones, infos
Example test code
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
if __name__ == "__main__":
    env_id = "CartPole-v1"
    vec_env = VecMaskWrapper(make_vec_env(env_id, n_envs=2))
    # Reset once before building the model so the wrapper already holds
    # observations to derive masks from.
    vec_env.reset()

    model = MaskablePPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25_000)

    # Roll out the trained policy, feeding it fresh masks at every step.
    obs = vec_env.reset()
    for _ in range(1000):
        action_masks = get_action_masks(vec_env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, rewards, dones, info = vec_env.step(action)
        vec_env.render()
Motivation
No response
Pitch
No response
Alternatives
No response
Additional context
No response
Checklist
I have checked that there is no similar issue in the repo
If I'm requesting a new feature, I have proposed alternatives
The text was updated successfully, but these errors were encountered:
🚀 Feature
From the issue.
@araffin !
I am planning to use a MaskNet (a custom neural network for masking invalid actions) that requires batch processing.
Here is my test code. The environment name is 'CartPole-v1'.
My library versions are 'sb3_contrib==2.1.0' and 'stable-baselines3==2.1.0'.
Motivation
No response
Pitch
No response
Alternatives
No response
Additional context
No response
Checklist
The text was updated successfully, but these errors were encountered: