Commit

added basics, refactored, parallel works
strakam committed Jan 20, 2024
1 parent 707b50b commit 99b2b95
Showing 3 changed files with 312 additions and 0 deletions.
Empty file added agents/__init__.py
180 changes: 180 additions & 0 deletions agents/follower.py
@@ -0,0 +1,180 @@
"""Script demonstrating the use of `gym_pybullet_drones`'s Gymnasium interface.
Classes HoverAviary and MultiHoverAviary are used as learning envs for the PPO algorithm.
Example
-------
In a terminal, run as:
$ python learn.py --multiagent false
$ python learn.py --multiagent true
Notes
-----
This is a minimal working example integrating `gym-pybullet-drones` with
reinforcement learning library `stable-baselines3`.
"""
import os
import time
from datetime import datetime
import argparse
import gymnasium as gym
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.evaluation import evaluate_policy

from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.envs.HoverAviary import HoverAviary
from gym_pybullet_drones.envs.MultiHoverAviary import MultiHoverAviary
from aviaries.FollowerAviary import FollowerAviary
from gym_pybullet_drones.utils.utils import sync, str2bool
from gym_pybullet_drones.utils.enums import ObservationType, ActionType

DEFAULT_GUI = True
DEFAULT_RECORD_VIDEO = False
DEFAULT_OUTPUT_FOLDER = 'results'

DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('one_d_rpm') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_AGENTS = 2
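# DEFAULT_AGENTS is currently unused: FollowerAviary below is a single-drone environment.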

def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_video=DEFAULT_RECORD_VIDEO, local=True):

    filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
    if not os.path.exists(filename):
        os.makedirs(filename+'/')

    train_env = make_vec_env(
        FollowerAviary,
        env_kwargs=dict(obs=DEFAULT_OBS, act=DEFAULT_ACT),
        n_envs=20,
        seed=0
    )
    eval_env = FollowerAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
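    # make_vec_env wraps 20 independent FollowerAviary instances into one vectorized env
    # (a DummyVecEnv by default), so PPO collects rollouts from all of them in parallel,
    # while the single eval_env is reserved for periodic evaluation.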

    #### Check the environment's spaces ########################
    print('[INFO] Action space:', train_env.action_space)
    print('[INFO] Observation space:', train_env.observation_space)

    #### Train the model #######################################
    model = PPO(
        'MlpPolicy',
        train_env,
        # tensorboard_log=filename+'/tb/',
        verbose=1
    )
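    # 'MlpPolicy' is stable-baselines3's default fully-connected actor-critic network,
    # which matches the flat kinematic observation vector used here.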

    #### Target cumulative rewards (problem-dependent) ##########
    target_reward = 474.15
    callback_on_best = StopTrainingOnRewardThreshold(
        reward_threshold=target_reward,
        verbose=1
    )
    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        verbose=1,
        best_model_save_path=filename+'/',
        log_path=filename+'/',
        eval_freq=int(1000),
        deterministic=True,
        render=False
    )
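    # EvalCallback periodically runs the current policy on eval_env, saves the best model
    # to best_model.zip, and (through callback_on_best) stops training early once the mean
    # evaluation reward exceeds target_reward.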
    model.learn(
        total_timesteps=int(1e7) if local else int(1e2), # shorter training in GitHub Actions pytest
        callback=eval_callback,
        log_interval=100
    )

    #### Save the model ########################################
    model.save(filename+'/final_model.zip')
    print(filename)

#### Print training progression ############################
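    # evaluations.npz is written by EvalCallback; each entry pairs a training timestep
    # with the rewards obtained during the corresponding evaluation run.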
    with np.load(filename+'/evaluations.npz') as data:
        for j in range(data['timesteps'].shape[0]):
            print(str(data['timesteps'][j])+","+str(data['results'][j][0]))

############################################################
############################################################
############################################################
############################################################
############################################################

    if local:
        input("Press Enter to continue...")

    # if os.path.isfile(filename+'/final_model.zip'):
    #     path = filename+'/final_model.zip'
    if os.path.isfile(filename+'/best_model.zip'):
        path = filename+'/best_model.zip'
    else:
        print("[ERROR]: no model under the specified path", filename)
        return  # nothing to demonstrate without a saved model
    model = PPO.load(path)

    #### Show (and record a video of) the model's performance ##
    # Demonstrate the policy in the same FollowerAviary environment it was trained on.
    test_env = FollowerAviary(
        gui=gui,
        obs=DEFAULT_OBS,
        act=DEFAULT_ACT,
        record=record_video
    )
    test_env_nogui = FollowerAviary(obs=DEFAULT_OBS, act=DEFAULT_ACT)
    logger = Logger(logging_freq_hz=int(test_env.CTRL_FREQ),
                    num_drones=1,
                    output_folder=output_folder,
                    colab=False
                    )

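    # evaluate_policy rolls out n_eval_episodes full episodes with the trained policy on
    # the non-GUI env and returns the mean and standard deviation of the episodic reward.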
    mean_reward, std_reward = evaluate_policy(
        model,
        test_env_nogui,
        n_eval_episodes=10
    )
    print("\n\n\nMean reward ", mean_reward, " +- ", std_reward, "\n\n")

    obs, info = test_env.reset(seed=42, options={})
    start = time.time()
    for i in range((test_env.EPISODE_LEN_SEC+2)*test_env.CTRL_FREQ):
        action, _states = model.predict(obs,
                                        deterministic=True
                                        )
        obs, reward, terminated, truncated, info = test_env.step(action)
        obs2 = obs.squeeze()
        act2 = action.squeeze()
        print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
        if DEFAULT_OBS == ObservationType.KIN:
            # Log the XYZ position, a zero quaternion placeholder, the remaining kinematic
            # observation, and the action taken at this step.
            logger.log(drone=0,
                       timestamp=i/test_env.CTRL_FREQ,
                       state=np.hstack(
                           [obs2[0:3],
                            np.zeros(4),
                            obs2[3:15],
                            act2
                            ]
                       ),
                       control=np.zeros(12)
                       )
        test_env.render()
        print(terminated)
        sync(i, start, test_env.CTRL_TIMESTEP)
        if terminated:
            obs, info = test_env.reset(seed=42, options={})
    test_env.close()

    if plot and DEFAULT_OBS == ObservationType.KIN:
        logger.plot()

if __name__ == '__main__':
    #### Define and parse (optional) arguments for the script ##
    parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
    parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
    parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
    parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
    ARGS = parser.parse_args()

    run(**vars(ARGS))
132 changes: 132 additions & 0 deletions aviaries/FollowerAviary.py
@@ -0,0 +1,132 @@
import numpy as np

from gym_pybullet_drones.envs.BaseRLAviary import BaseRLAviary
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ActionType, ObservationType

class FollowerAviary(BaseRLAviary):
    """Single agent RL problem: hover at position."""

    ################################################################################

    def __init__(self,
                 drone_model: DroneModel=DroneModel.CF2X,
                 initial_xyzs=None,
                 initial_rpys=None,
                 physics: Physics=Physics.PYB,
                 pyb_freq: int = 240,
                 ctrl_freq: int = 30,
                 gui=False,
                 record=False,
                 obs: ObservationType=ObservationType.KIN,
                 act: ActionType=ActionType.RPM
                 ):
"""Initialization of a single agent RL environment.
Using the generic single agent RL superclass.
Parameters
----------
drone_model : DroneModel, optional
The desired drone type (detailed in an .urdf file in folder `assets`).
initial_xyzs: ndarray | None, optional
(NUM_DRONES, 3)-shaped array containing the initial XYZ position of the drones.
initial_rpys: ndarray | None, optional
(NUM_DRONES, 3)-shaped array containing the initial orientations of the drones (in radians).
physics : Physics, optional
The desired implementation of PyBullet physics/custom dynamics.
pyb_freq : int, optional
The frequency at which PyBullet steps (a multiple of ctrl_freq).
ctrl_freq : int, optional
The frequency at which the environment steps.
gui : bool, optional
Whether to use PyBullet's GUI.
record : bool, optional
Whether to save a video of the simulation.
obs : ObservationType, optional
The type of observation space (kinematic information or vision)
act : ActionType, optional
The type of action space (1 or 3D; RPMS, thurst and torques, or waypoint with PID control)
"""
        self.TARGET_POS = np.array([0,0,1])
        self.EPISODE_LEN_SEC = 8
        super().__init__(drone_model=drone_model,
                         num_drones=1,
                         initial_xyzs=initial_xyzs,
                         initial_rpys=initial_rpys,
                         physics=physics,
                         pyb_freq=pyb_freq,
                         ctrl_freq=ctrl_freq,
                         gui=gui,
                         record=record,
                         obs=obs,
                         act=act
                         )
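        # num_drones is fixed to 1: FollowerAviary is a single-agent task, so the caller
        # only chooses the drone model, physics, frequencies, and observation/action types.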

################################################################################

    def _computeReward(self):
        """Computes the current reward value.

        Returns
        -------
        float
            The reward.

        """
        state = self._getDroneStateVector(0)
        ret = max(0, 2 - np.linalg.norm(self.TARGET_POS-state[0:3])**4)
        return ret

################################################################################

    def _computeTerminated(self):
        """Computes the current done value.

        Returns
        -------
        bool
            Whether the current episode is done.

        """
        state = self._getDroneStateVector(0)
        if np.linalg.norm(self.TARGET_POS-state[0:3]) < .0001:
            return True
        else:
            return False

################################################################################

    def _computeTruncated(self):
        """Computes the current truncated value.

        Returns
        -------
        bool
            Whether the current episode timed out.

        """
        state = self._getDroneStateVector(0)
        if (abs(state[0]) > 1.5 or abs(state[1]) > 1.5 or state[2] > 2.0 # Truncate when the drone is too far away
             or abs(state[7]) > .4 or abs(state[8]) > .4 # Truncate when the drone is too tilted
        ):
            return True
        if self.step_counter/self.PYB_FREQ > self.EPISODE_LEN_SEC:
            return True
        else:
            return False

################################################################################

    def _computeInfo(self):
        """Computes the current info dict(s).

        Unused.

        Returns
        -------
        dict[str, int]
            Dummy value.

        """
        return {"answer": 42} #### Calculated by the Deep Thought supercomputer in 7.5M years
