[Feature Request] VecMaskWrapper for MaskablePPO #279

Open
2 tasks done
CAI23sbP opened this issue Feb 20, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@CAI23sbP

CAI23sbP commented Feb 20, 2025

🚀 Feature

From the issue.

@araffin !
I am planning to use a Masknet (a custom neural network for masking invalid actions) that requires batch processing.

Here is my test code. The environment is 'CartPole-v1'.

My library versions are 'sb3_contrib==2.1.0, stable-baselines3==2.1.0'.

  1. Modify the original code
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Returns the invalid action masks exposed by the gym env
    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """

    if isinstance(env, VecEnv):
        # VecMaskWrapper.get_attr() intercepts this name and returns the batched masks
        return env.get_attr(EXPECTED_METHOD_NAME)
    else:
        # Call the method; without the parentheses the bound method itself is returned
        return getattr(env, EXPECTED_METHOD_NAME)()
    
    
def is_masking_supported(env: GymEnv) -> bool:
    """
    Checks whether gym env exposes a method returning invalid action masks

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """
    if isinstance(env, VecEnv):
        try:
            env.get_attr(EXPECTED_METHOD_NAME)
            
            return True
        except AttributeError:
            return False
    else:
        return hasattr(env, EXPECTED_METHOD_NAME)
  2. Add a VecMaskWrapper that masks invalid actions
from typing import List
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
import numpy as np 

EXPECTED_METHOD_NAME = 'action_masks'
class VecMaskWrapper(VecEnvWrapper):
    def __init__(self, venv: VecEnv):
        # VecEnvWrapper.__init__ already sets self.num_envs from venv
        super().__init__(venv)
        # Assumes a Discrete action space
        n_actions = venv.action_space.n
        self.possible_actions = np.arange(n_actions)
        self.all_valid_mask = np.ones((self.num_envs, n_actions), dtype=np.bool_)

    def reset(self) -> VecEnvObs:
        self.observations = self.venv.reset()
        return self.observations

    def action_masks(self) -> np.ndarray:
        """
        Build a boolean mask of shape (num_envs, n_actions) from the latest observations.

        See https://www.gymlibrary.dev/environments/classic_control/cart_pole/:
        the pole angle can be observed between (-.418, .418) radians (or ±24°),
        but the episode terminates if the pole angle is not in the range (-.2095, .2095) (or ±12°).
        """
        masks = np.ones_like(self.all_valid_mask, dtype=np.bool_)
        # observations[:, 2] is the pole angle; action 0 pushes left, action 1 pushes right
        condition_1 = np.where(self.observations[:, 2] <= -0.05)[0]  # pole leaning left
        condition_2 = np.where(self.observations[:, 2] >= 0.05)[0]  # pole leaning right
        masks[condition_1, 1] = False  # forbid pushing right
        masks[condition_2, 0] = False  # forbid pushing left
        return masks
        
            
    def get_attr(self, attr_name, indices=None):
        if attr_name == EXPECTED_METHOD_NAME:
            return self.action_masks()
        else:
            return super().get_attr(attr_name, indices)

    def step_wait(self) -> VecEnvStepReturn:
        self.observations, rews, dones, infos = self.venv.step_wait()
        return self.observations, rews, dones, infos
  3. Example test code


from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.common.maskable.utils import get_action_masks  # the modified version shown above
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO

# VecMaskWrapper is the class defined above

if __name__ == "__main__":
    env_id = "CartPole-v1"
    vec_env = make_vec_env(env_id, n_envs=2)
    vec_env = VecMaskWrapper(vec_env)
    # Reset first so the wrapper has observations to build masks from
    vec_env.reset()
    model = MaskablePPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25_000)

    obs = vec_env.reset()
    for _ in range(1000):
        action_masks = get_action_masks(vec_env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, rewards, dones, infos = vec_env.step(action)
        vec_env.render()
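
To make the masking rule in the wrapper easier to check in isolation, here is a minimal sketch that applies the same thresholds to a hand-written batch of observations. The helper name `cartpole_action_masks` and the `angle_threshold` parameter are made up for this illustration; only NumPy is needed.

```python
import numpy as np


def cartpole_action_masks(observations: np.ndarray, angle_threshold: float = 0.05) -> np.ndarray:
    """Hypothetical helper: one row of [push_left_allowed, push_right_allowed] per env."""
    pole_angle = observations[:, 2]  # third CartPole observation entry is the pole angle
    masks = np.ones((observations.shape[0], 2), dtype=np.bool_)
    masks[pole_angle <= -angle_threshold, 1] = False  # leaning left: forbid pushing right
    masks[pole_angle >= angle_threshold, 0] = False  # leaning right: forbid pushing left
    return masks


# Three fake environments: leaning left, upright, leaning right
obs = np.array(
    [
        [0.0, 0.0, -0.10, 0.0],
        [0.0, 0.0, 0.00, 0.0],
        [0.0, 0.0, 0.20, 0.0],
    ]
)
masks = cartpole_action_masks(obs)
print(masks)
```

The whole batch is handled with vectorized indexing, which is the point of computing masks at the VecEnv level instead of once per sub-environment.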

Motivation

No response

Pitch

No response

Alternatives

No response

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo
  • If I'm requesting a new feature, I have proposed alternatives
@araffin
Member

araffin commented Mar 31, 2025

Original issue: #68

@CAI23sbP I would simplify by having a hasattr() check inside the isinstance(env, VecEnv) branch, or an is_vec_env_wrapped(env, VecMaskWrapper) check

Could you submit a PR with tests and doc?
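
One possible reading of the hasattr() suggestion above, as a rough sketch only: if the vectorized env itself exposes a batched `action_masks()`, use it; otherwise fall back to querying each sub-environment. `FakeVecEnv` and `FakeVecMaskWrapper` are hypothetical stand-ins written for this example, not Stable-Baselines3 classes, so the snippet runs with NumPy alone; a real patch would keep the `isinstance(env, VecEnv)` branch and use the library's own types.

```python
import numpy as np

EXPECTED_METHOD_NAME = "action_masks"


class FakeVecEnv:
    """Hypothetical stand-in for a VecEnv whose sub-envs each expose action_masks()."""

    def __init__(self, per_env_masks):
        self._per_env_masks = per_env_masks

    def env_method(self, method_name):
        # Mimics the per-env call pattern: one result per sub-environment
        assert method_name == EXPECTED_METHOD_NAME
        return [np.asarray(m, dtype=np.bool_) for m in self._per_env_masks]


class FakeVecMaskWrapper(FakeVecEnv):
    """Hypothetical stand-in for a wrapper that computes masks for the whole batch."""

    def action_masks(self):
        return np.asarray(self._per_env_masks, dtype=np.bool_)


def get_action_masks(env) -> np.ndarray:
    # hasattr() check: a batched action_masks() on the vec env takes priority
    if hasattr(env, EXPECTED_METHOD_NAME):
        return getattr(env, EXPECTED_METHOD_NAME)()
    # Otherwise query every sub-environment and stack the per-env masks
    return np.stack(env.env_method(EXPECTED_METHOD_NAME))


example_masks = [[True, False], [False, True]]
plain_result = get_action_masks(FakeVecEnv(example_masks))
wrapped_result = get_action_masks(FakeVecMaskWrapper(example_masks))
print(plain_result.shape, wrapped_result.shape)
```

With this shape of dispatch, the wrapper would no longer need to override `get_attr()` to intercept the `action_masks` name.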
