diff --git a/agents/__init__.py b/agents/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/agents/follower.py b/agents/follower.py
new file mode 100644
index 00000000..ba4804d5
--- /dev/null
+++ b/agents/follower.py
@@ -0,0 +1,180 @@
+"""Script demonstrating the use of `gym_pybullet_drones`'s Gymnasium interface.
+
+Class FollowerAviary is used as the learning env for the PPO algorithm.
+
+Example
+-------
+In a terminal, run as:
+
+    $ python agents/follower.py
+
+Notes
+-----
+This is a minimal working example integrating `gym-pybullet-drones` with
+reinforcement learning library `stable-baselines3`.
+
+"""
+import os
+import time
+from datetime import datetime
+import argparse
+import numpy as np
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
+from stable_baselines3.common.evaluation import evaluate_policy
+
+from gym_pybullet_drones.utils.Logger import Logger
+from aviaries.FollowerAviary import FollowerAviary
+from gym_pybullet_drones.utils.utils import sync, str2bool
+from gym_pybullet_drones.utils.enums import ObservationType, ActionType
+
+DEFAULT_GUI = True
+DEFAULT_RECORD_VIDEO = False
+DEFAULT_OUTPUT_FOLDER = 'results'
+
+DEFAULT_OBS = ObservationType('kin')  # 'kin' or 'rgb'
+DEFAULT_ACT = ActionType('one_d_rpm')  # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
+
+def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_video=DEFAULT_RECORD_VIDEO, local=True):
+
+    filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
+    if not os.path.exists(filename):
+        os.makedirs(filename+'/')
+
+    train_env = make_vec_env(
+        FollowerAviary,
+        env_kwargs=dict(obs=DEFAULT_OBS, act=DEFAULT_ACT),
+        n_envs=20,
+        seed=0
+    )
+    eval_env = FollowerAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
+
+    #### Check the environment's spaces ########################
+    print('[INFO] Action space:', train_env.action_space)
+    print('[INFO] Observation space:', train_env.observation_space)
+
+    #### Train the model #######################################
+    model = PPO(
+        'MlpPolicy',
+        train_env,
+        # tensorboard_log=filename+'/tb/',
+        verbose=1
+    )
+
+    #### Target cumulative rewards (problem-dependent) ##########
+    target_reward = 474.15
+    callback_on_best = StopTrainingOnRewardThreshold(
+        reward_threshold=target_reward,
+        verbose=1
+    )
+    eval_callback = EvalCallback(
+        eval_env,
+        callback_on_new_best=callback_on_best,
+        verbose=1,
+        best_model_save_path=filename+'/',
+        log_path=filename+'/',
+        eval_freq=int(1000),
+        deterministic=True,
+        render=False
+    )
+    model.learn(
+        total_timesteps=int(1e7) if local else int(1e2),  # shorter training in GitHub Actions pytest
+        callback=eval_callback,
+        log_interval=100
+    )
+
+    #### Save the model ########################################
+    model.save(filename+'/final_model.zip')
+    print(filename)
+
+    #### Print training progression ############################
+    with np.load(filename+'/evaluations.npz') as data:
+        for j in range(data['timesteps'].shape[0]):
+            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
+
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+    ############################################################
+
+    if local:
+        input("Press Enter to continue...")
+
+    # if os.path.isfile(filename+'/final_model.zip'):
+    #     path = filename+'/final_model.zip'
+    if os.path.isfile(filename+'/best_model.zip'):
+        path = filename+'/best_model.zip'
+    else:
+        print("[ERROR]: no model under the specified path", filename)
+        return
+    model = PPO.load(path)
+
+    #### Show (and record a video of) the model's performance ##
+    #### Evaluate in the same environment the policy was trained on
+    test_env = FollowerAviary(
+        gui=gui,
+        obs=DEFAULT_OBS,
+        act=DEFAULT_ACT,
+        record=record_video
+    )
+    test_env_nogui = FollowerAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
+    logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
+                    num_drones=1,
+                    output_folder=output_folder,
+                    colab=False
+                    )
+
+    mean_reward, std_reward = evaluate_policy(
+        model,
+        test_env_nogui,
+        n_eval_episodes=10
+    )
+    print("\n\n\nMean reward ", mean_reward, " +- ", std_reward, "\n\n")
+
+    obs, info = test_env.reset(seed=42, options={})
+    start = time.time()
+    for i in range((test_env.EPISODE_LEN_SEC+2)*test_env.CTRL_FREQ):
+        action, _states = model.predict(obs,
+                                        deterministic=True
+                                        )
+        obs, reward, terminated, truncated, info = test_env.step(action)
+        obs2 = obs.squeeze()
+        act2 = action.squeeze()
+        print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
+        if DEFAULT_OBS == ObservationType.KIN:
+            logger.log(drone=0,
+                       timestamp=i/test_env.CTRL_FREQ,
+                       state=np.hstack([obs2[0:3],
+                                        np.zeros(4),
+                                        obs2[3:15],
+                                        act2
+                                        ]),
+                       control=np.zeros(12)
+                       )
+        test_env.render()
+        print(terminated)
+        sync(i, start, test_env.CTRL_TIMESTEP)
+        if terminated:
+            obs, info = test_env.reset(seed=42, options={})
+    test_env.close()
+
+    if plot and DEFAULT_OBS == ObservationType.KIN:
+        logger.plot()
+
+if __name__ == '__main__':
+    #### Define and parse (optional) arguments for the script ##
+    parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
+    parser.add_argument('--gui',           default=DEFAULT_GUI,          type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
+    parser.add_argument('--record_video',  default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
+    parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str,     help='Folder where to save logs (default: "results")', metavar='')
+    ARGS = parser.parse_args()
+
+    run(**vars(ARGS))
diff --git a/aviaries/FollowerAviary.py b/aviaries/FollowerAviary.py
new file mode 100644
index 00000000..5e6d0de9
--- /dev/null
+++ b/aviaries/FollowerAviary.py
@@ -0,0 +1,132 @@
+import numpy as np
+
+from gym_pybullet_drones.envs.BaseRLAviary import BaseRLAviary
+from gym_pybullet_drones.utils.enums import DroneModel, Physics, ActionType, ObservationType
+
+class FollowerAviary(BaseRLAviary):
+    """Single agent RL problem: hover at a target position."""
+
+    ################################################################################
+
+    def __init__(self,
+                 drone_model: DroneModel=DroneModel.CF2X,
+                 initial_xyzs=None,
+                 initial_rpys=None,
+                 physics: Physics=Physics.PYB,
+                 pyb_freq: int = 240,
+                 ctrl_freq: int = 30,
+                 gui=False,
+                 record=False,
+                 obs: ObservationType=ObservationType.KIN,
+                 act: ActionType=ActionType.RPM
+                 ):
+        """Initialization of a single agent RL environment.
+
+        Using the generic single agent RL superclass.
+
+        Parameters
+        ----------
+        drone_model : DroneModel, optional
+            The desired drone type (detailed in an .urdf file in folder `assets`).
+        initial_xyzs: ndarray | None, optional
+            (NUM_DRONES, 3)-shaped array containing the initial XYZ position of the drones.
+        initial_rpys: ndarray | None, optional
+            (NUM_DRONES, 3)-shaped array containing the initial orientations of the drones (in radians).
+        physics : Physics, optional
+            The desired implementation of PyBullet physics/custom dynamics.
+        pyb_freq : int, optional
+            The frequency at which PyBullet steps (a multiple of ctrl_freq).
+        ctrl_freq : int, optional
+            The frequency at which the environment steps.
+        gui : bool, optional
+            Whether to use PyBullet's GUI.
+        record : bool, optional
+            Whether to save a video of the simulation.
+        obs : ObservationType, optional
+            The type of observation space (kinematic information or vision).
+        act : ActionType, optional
+            The type of action space (1 or 3D; RPMs, thrust and torques, or waypoint with PID control).
+
+        """
+        self.TARGET_POS = np.array([0,0,1])
+        self.EPISODE_LEN_SEC = 8
+        super().__init__(drone_model=drone_model,
+                         num_drones=1,
+                         initial_xyzs=initial_xyzs,
+                         initial_rpys=initial_rpys,
+                         physics=physics,
+                         pyb_freq=pyb_freq,
+                         ctrl_freq=ctrl_freq,
+                         gui=gui,
+                         record=record,
+                         obs=obs,
+                         act=act
+                         )
+
+    ################################################################################
+
+    def _computeReward(self):
+        """Computes the current reward value.
+
+        Returns
+        -------
+        float
+            The reward.
+
+        """
+        state = self._getDroneStateVector(0)
+        ret = max(0, 2 - np.linalg.norm(self.TARGET_POS-state[0:3])**4)
+        return ret
+
+    ################################################################################
+
+    def _computeTerminated(self):
+        """Computes the current done value.
+
+        Returns
+        -------
+        bool
+            Whether the current episode is done.
+
+        """
+        state = self._getDroneStateVector(0)
+        if np.linalg.norm(self.TARGET_POS-state[0:3]) < .0001:
+            return True
+        else:
+            return False
+
+    ################################################################################
+
+    def _computeTruncated(self):
+        """Computes the current truncated value.
+
+        Returns
+        -------
+        bool
+            Whether the current episode timed out.
+
+        """
+        state = self._getDroneStateVector(0)
+        if (abs(state[0]) > 1.5 or abs(state[1]) > 1.5 or state[2] > 2.0  # Truncate when the drone is too far away
+             or abs(state[7]) > .4 or abs(state[8]) > .4  # Truncate when the drone is too tilted
+        ):
+            return True
+        if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
+            return True
+        else:
+            return False
+
+    ################################################################################
+
+    def _computeInfo(self):
+        """Computes the current info dict(s).
+
+        Unused.
+
+        Returns
+        -------
+        dict[str, int]
+            Dummy value.
+
+        """
+        return {"answer": 42}  #### Calculated by the Deep Thought supercomputer in 7.5M years
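
Not part of the diff: a minimal sketch of how the new FollowerAviary environment could be sanity-checked with random actions before launching a full PPO run. It assumes the repository root is on PYTHONPATH (so that `aviaries` is importable) and that `gym-pybullet-drones` is installed; the file name `sanity_check_follower.py` is hypothetical.

    # sanity_check_follower.py -- hypothetical helper, not included in this PR
    from aviaries.FollowerAviary import FollowerAviary
    from gym_pybullet_drones.utils.enums import ObservationType, ActionType

    # Instantiate the env with the same obs/act types used by agents/follower.py
    env = FollowerAviary(obs=ObservationType.KIN, act=ActionType.ONE_D_RPM)
    obs, info = env.reset(seed=0)
    print("Observation space:", env.observation_space)
    print("Action space:", env.action_space)

    total_reward = 0.0
    for _ in range(env.EPISODE_LEN_SEC * env.CTRL_FREQ):
        action = env.action_space.sample()  # random rotor command
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        if terminated or truncated:
            break
    env.close()
    print("Cumulative reward of a random policy:", total_reward)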
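
Also not part of the diff: a short sketch of reloading a checkpoint written by agents/follower.py and re-evaluating it, assuming a `best_model.zip` exists under one of the `results/save-.../` folders (the path below is a placeholder).

    # evaluate_follower.py -- hypothetical helper, not included in this PR
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy

    from aviaries.FollowerAviary import FollowerAviary
    from gym_pybullet_drones.utils.enums import ObservationType, ActionType

    MODEL_PATH = "results/save-.../best_model.zip"  # placeholder: point this at an actual run folder

    env = FollowerAviary(obs=ObservationType.KIN, act=ActionType.ONE_D_RPM)
    model = PPO.load(MODEL_PATH)
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, deterministic=True)
    print(f"Mean reward over 10 episodes: {mean_reward:.2f} +/- {std_reward:.2f}")
    env.close()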