Commit: Merge remote-tracking branch 'origin/main' into traj
Showing 3 changed files with 312 additions and 0 deletions. (The third changed file is empty.)
learn.py
@@ -0,0 +1,180 @@
"""Script demonstrating the use of `gym_pybullet_drones`'s Gymnasium interface. | ||
Classes HoverAviary and MultiHoverAviary are used as learning envs for the PPO algorithm. | ||
Example | ||
------- | ||
In a terminal, run as: | ||
$ python learn.py --multiagent false | ||
$ python learn.py --multiagent true | ||
Notes | ||
----- | ||
This is a minimal working example integrating `gym-pybullet-drones` with | ||
reinforcement learning library `stable-baselines3`. | ||
""" | ||
import os | ||
import time | ||
from datetime import datetime | ||
import argparse | ||
import gymnasium as gym | ||
import numpy as np | ||
import torch | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold | ||
from stable_baselines3.common.evaluation import evaluate_policy | ||
|
||
from gym_pybullet_drones.utils.Logger import Logger | ||
from gym_pybullet_drones.envs.HoverAviary import HoverAviary | ||
from gym_pybullet_drones.envs.MultiHoverAviary import MultiHoverAviary | ||
from aviaries.FollowerAviary import FollowerAviary | ||
from gym_pybullet_drones.utils.utils import sync, str2bool | ||
from gym_pybullet_drones.utils.enums import ObservationType, ActionType | ||
|
||
DEFAULT_GUI = True
DEFAULT_RECORD_VIDEO = False
DEFAULT_OUTPUT_FOLDER = 'results'

DEFAULT_OBS = ObservationType('kin')    # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('one_d_rpm')   # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'

def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_video=DEFAULT_RECORD_VIDEO, local=True):

    filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
    if not os.path.exists(filename):
        os.makedirs(filename+'/')

    train_env = make_vec_env(
        FollowerAviary,
        env_kwargs=dict(obs=DEFAULT_OBS, act=DEFAULT_ACT),
        n_envs=20,
        seed=0
    )
    eval_env = FollowerAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
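    # Twenty parallel FollowerAviary instances collect rollouts for PPO, while
    # a single separate instance with the same obs/act configuration is
    # reserved for the periodic evaluations run by EvalCallback below.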
    #### Check the environment's spaces ########################
    print('[INFO] Action space:', train_env.action_space)
    print('[INFO] Observation space:', train_env.observation_space)

    #### Train the model #######################################
    model = PPO(
        'MlpPolicy',
        train_env,
        # tensorboard_log=filename+'/tb/',
        verbose=1
    )

    #### Target cumulative rewards (problem-dependent) #########
    target_reward = 474.15
    callback_on_best = StopTrainingOnRewardThreshold(
        reward_threshold=target_reward,
        verbose=1
    )
    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        verbose=1,
        best_model_save_path=filename+'/',
        log_path=filename+'/',
        eval_freq=int(1000),
        deterministic=True,
        render=False
    )
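    # NOTE: EvalCallback counts eval_freq in vec-env steps, so with n_envs=20
    # an evaluation is triggered roughly every 20,000 individual env steps.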
    model.learn(
        total_timesteps=int(1e7) if local else int(1e2),  # shorter training in GitHub Actions pytest
        callback=eval_callback,
        log_interval=100
    )

    #### Save the model ########################################
    model.save(filename+'/final_model.zip')
    print(filename)

    #### Print training progression ############################
    with np.load(filename+'/evaluations.npz') as data:
        for j in range(data['timesteps'].shape[0]):
            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))
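    # evaluations.npz is written by EvalCallback; it stores 'timesteps',
    # 'results' (one row of evaluation-episode returns per evaluation), and
    # 'ep_lengths'.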
    ############################################################

    if local:
        input("Press Enter to continue...")

    # if os.path.isfile(filename+'/final_model.zip'):
    #     path = filename+'/final_model.zip'
    if os.path.isfile(filename+'/best_model.zip'):
        path = filename+'/best_model.zip'
    else:
        raise FileNotFoundError("[ERROR]: no model under the specified path "+filename)
    model = PPO.load(path)
    #### Show (and record a video of) the model's performance ##
    # NOTE: the test environment is HoverAviary rather than the FollowerAviary
    # used for training; both expose the same observation/action spaces.
    test_env = HoverAviary(
        gui=gui,
        obs=DEFAULT_OBS,
        act=DEFAULT_ACT,
        record=record_video
    )
    test_env_nogui = HoverAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
    logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
                    num_drones=1,
                    output_folder=output_folder,
                    colab=False
                    )

    mean_reward, std_reward = evaluate_policy(
        model,
        test_env_nogui,
        n_eval_episodes=10
    )
    print("\n\n\nMean reward ", mean_reward, " +- ", std_reward, "\n\n")
    obs, info = test_env.reset(seed=42, options={})
    start = time.time()
    for i in range((test_env.EPISODE_LEN_SEC+2)*test_env.CTRL_FREQ):
        action, _states = model.predict(obs,
                                        deterministic=True
                                        )
        obs, reward, terminated, truncated, info = test_env.step(action)
        obs2 = obs.squeeze()
        act2 = action.squeeze()
        print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
        if DEFAULT_OBS == ObservationType.KIN:
            logger.log(drone=0,
                       timestamp=i/test_env.CTRL_FREQ,
                       state=np.hstack(
                           [obs2[0:3],    # XYZ position
                            np.zeros(4),  # quaternion placeholder (not part of the kin obs)
                            obs2[3:15],   # remaining kinematic entries (rpy, velocities, ...)
                            act2          # applied action
                            ]
                           ),
                       control=np.zeros(12)
                       )
        test_env.render()
        print(terminated)
        sync(i, start, test_env.CTRL_TIMESTEP)
        if terminated:
            obs, info = test_env.reset(seed=42, options={})
    test_env.close()

    if plot and DEFAULT_OBS == ObservationType.KIN:
        logger.plot()

if __name__ == '__main__':
    #### Define and parse (optional) arguments for the script ##
    parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
    parser.add_argument('--gui',           default=DEFAULT_GUI,          type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
    parser.add_argument('--record_video',  default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
    parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
    ARGS = parser.parse_args()

    run(**vars(ARGS))
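For reference, a minimal sketch of reloading the trained policy outside this script. The save directory below is a hypothetical example of the timestamped path the run prints; imports follow the diff above.

    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy
    from aviaries.FollowerAviary import FollowerAviary
    from gym_pybullet_drones.utils.enums import ObservationType, ActionType

    # Hypothetical save directory from an earlier run of learn.py
    path = 'results/save-01.01.2024_12.00.00/best_model.zip'
    model = PPO.load(path)
    env = FollowerAviary(obs=ObservationType('kin'), act=ActionType('one_d_rpm'))
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
    print(mean_reward, std_reward)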
aviaries/FollowerAviary.py
@@ -0,0 +1,132 @@
import numpy as np

from gym_pybullet_drones.envs.BaseRLAviary import BaseRLAviary
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ActionType, ObservationType

class FollowerAviary(BaseRLAviary):
    """Single agent RL problem: hover at position."""

    ################################################################################

    def __init__(self,
                 drone_model: DroneModel=DroneModel.CF2X,
                 initial_xyzs=None,
                 initial_rpys=None,
                 physics: Physics=Physics.PYB,
                 pyb_freq: int = 240,
                 ctrl_freq: int = 30,
                 gui=False,
                 record=False,
                 obs: ObservationType=ObservationType.KIN,
                 act: ActionType=ActionType.RPM
                 ):
"""Initialization of a single agent RL environment. | ||
Using the generic single agent RL superclass. | ||
Parameters | ||
---------- | ||
drone_model : DroneModel, optional | ||
The desired drone type (detailed in an .urdf file in folder `assets`). | ||
initial_xyzs: ndarray | None, optional | ||
(NUM_DRONES, 3)-shaped array containing the initial XYZ position of the drones. | ||
initial_rpys: ndarray | None, optional | ||
(NUM_DRONES, 3)-shaped array containing the initial orientations of the drones (in radians). | ||
physics : Physics, optional | ||
The desired implementation of PyBullet physics/custom dynamics. | ||
pyb_freq : int, optional | ||
The frequency at which PyBullet steps (a multiple of ctrl_freq). | ||
ctrl_freq : int, optional | ||
The frequency at which the environment steps. | ||
gui : bool, optional | ||
Whether to use PyBullet's GUI. | ||
record : bool, optional | ||
Whether to save a video of the simulation. | ||
obs : ObservationType, optional | ||
The type of observation space (kinematic information or vision) | ||
act : ActionType, optional | ||
The type of action space (1 or 3D; RPMS, thurst and torques, or waypoint with PID control) | ||
""" | ||
self.TARGET_POS = np.array([0,0,1]) | ||
self.EPISODE_LEN_SEC = 8 | ||
super().__init__(drone_model=drone_model, | ||
num_drones=1, | ||
initial_xyzs=initial_xyzs, | ||
initial_rpys=initial_rpys, | ||
physics=physics, | ||
pyb_freq=pyb_freq, | ||
ctrl_freq=ctrl_freq, | ||
gui=gui, | ||
record=record, | ||
obs=obs, | ||
act=act | ||
) | ||
    ################################################################################

    def _computeReward(self):
        """Computes the current reward value.

        Returns
        -------
        float
            The reward.

        """
        state = self._getDroneStateVector(0)
        # Quartic penalty on the distance to TARGET_POS, clipped at zero.
        ret = max(0, 2 - np.linalg.norm(self.TARGET_POS-state[0:3])**4)
        return ret
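    # A quick sanity check of the reward shape (illustrative values):
    #   distance 0.0 m  -> reward 2.0 (maximum, exactly at the target)
    #   distance 0.5 m  -> reward 2 - 0.5**4 = 1.9375
    #   distance 1.0 m  -> reward 1.0
    #   distance beyond 2**0.25 (~1.19 m) -> reward 0.0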
    ################################################################################

    def _computeTerminated(self):
        """Computes the current done value.

        Returns
        -------
        bool
            Whether the current episode is done.

        """
        state = self._getDroneStateVector(0)
        # Terminate once the drone is within 1e-4 m of the target position.
        return np.linalg.norm(self.TARGET_POS-state[0:3]) < .0001
    ################################################################################

    def _computeTruncated(self):
        """Computes the current truncated value.

        Returns
        -------
        bool
            Whether the current episode timed out.

        """
        state = self._getDroneStateVector(0)
        if (abs(state[0]) > 1.5 or abs(state[1]) > 1.5 or state[2] > 2.0  # Truncate when the drone is too far away
                or abs(state[7]) > .4 or abs(state[8]) > .4               # Truncate when the drone is too tilted
                ):
            return True
        return self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC
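    # In the state vector, state[0:3] is the XYZ position and state[7:9] are
    # the roll/pitch angles, so episodes are cut off outside a 3 m x 3 m x 2 m
    # box, beyond 0.4 rad (~23 degrees) of tilt, or after EPISODE_LEN_SEC
    # seconds.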
    ################################################################################

    def _computeInfo(self):
        """Computes the current info dict(s).

        Unused.

        Returns
        -------
        dict[str, int]
            Dummy value.

        """
        return {"answer": 42} #### Calculated by the Deep Thought supercomputer in 7.5M years