-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
1,476 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
"""Script learns an agent to follow a target trajectory. | ||
""" | ||
|
||
import argparse | ||
from gym_pybullet_drones.utils.utils import str2bool | ||
from gym_pybullet_drones.utils.enums import ObservationType | ||
from agents.utils.create_env import EnvFactorySimpleFollowerAviary | ||
from agents.utils.parse_configuration import Configuration | ||
from train_policy import run_train | ||
from test_policy import run_test | ||
from utils.parse_configuration import parse_config | ||
|
||
# defaults for command line arguments | ||
DEFAULT_OUTPUT_FOLDER = 'results' | ||
DEFAULT_GUI = True | ||
DEFAULT_TIMESTEPS = 3e5 | ||
DEFAULT_ACTION_TYPE = 'rpm' # 'rpm', 'one_d_rpm', 'attitude' | ||
DEFAULT_TRAIN = True | ||
DEFAULT_TEST = True | ||
|
||
# more configurations | ||
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb' | ||
|
||
|
||
def run(output_folder=DEFAULT_OUTPUT_FOLDER, | ||
gui=DEFAULT_GUI, | ||
timesteps=DEFAULT_TIMESTEPS, | ||
action_type: str='rpm', | ||
train: bool=DEFAULT_TRAIN, | ||
test: bool=DEFAULT_TEST | ||
): | ||
|
||
config: Configuration = parse_config( | ||
t_waypoint=[0, 0.5, 0.5], | ||
initial_waypoint=[0, 0, 0.1], | ||
action_type=action_type, | ||
output_folder=output_folder, | ||
n_timesteps=timesteps, | ||
local=False | ||
) | ||
|
||
env_factory = EnvFactorySimpleFollowerAviary( | ||
config=config, | ||
output_folder=output_folder, | ||
observation_type=DEFAULT_OBS, | ||
use_gui_for_test_env=gui, | ||
n_env_training=20, | ||
seed=0, | ||
) | ||
|
||
if train: | ||
run_train(config=config, | ||
env_factory=env_factory) | ||
|
||
if test: | ||
run_test(config=config, | ||
env_factory=env_factory) | ||
|
||
|
||
if __name__ == '__main__': | ||
#### Define and parse (optional) arguments for the script ## | ||
parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script') | ||
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='') | ||
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='') | ||
parser.add_argument('--timesteps', default=DEFAULT_TIMESTEPS, type=int, help='number of train timesteps before stopping', metavar='') | ||
parser.add_argument('--action_type', default=DEFAULT_TIMESTEPS, type=str, help='Either "one_d_rpm", "rpm" or "attitude"', metavar='') | ||
parser.add_argument('--train', default=DEFAULT_TRAIN, type=str2bool, help='Whether to train (default: True)', metavar='') | ||
parser.add_argument('--test', default=DEFAULT_TEST, type=str2bool, help='Whether to test (default: True)', metavar='') | ||
ARGS = parser.parse_args() | ||
|
||
run(**vars(ARGS)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" | ||
RUNS EXPERIMENT A. | ||
3 MODES: | ||
- DOWN (from [0,0,1] to [0,0,0.1]) | ||
- UP (from [0,0,0.1] to [0,0,1]) | ||
- SIDEWAYS (from [0,0,1] to [0,1,1]) | ||
- DIAGONAL_UP (from [0,0,0.1] to [0,1,1]) | ||
- DIAGONAL_DOWN (from [0,0,1] to [0,1,0.1]) | ||
- every mode 10 times | ||
""" | ||
|
||
import argparse | ||
from gym_pybullet_drones.utils.utils import str2bool | ||
from gym_pybullet_drones.utils.enums import ObservationType | ||
from agents.utils.create_env import EnvFactorySimpleFollowerAviary | ||
from agents.utils.parse_configuration import Configuration | ||
from train_policy import run_train | ||
from test_policy import run_test | ||
from utils.parse_configuration import parse_config | ||
|
||
# defaults for command line arguments | ||
DEFAULT_OUTPUT_FOLDER = 'results' | ||
DEFAULT_GUI = True | ||
DEFAULT_TIMESTEPS = 3e5 | ||
DEFAULT_ACTION_TYPE = 'rpm' # 'rpm', 'one_d_rpm', 'attitude' | ||
DEFAULT_TRAIN = True | ||
DEFAULT_TEST = True | ||
DEFAULT_MODE = "UP" # DOWN, UP, SIDEWAYS, DIAGONAL_UP, DIAGONAL_DOWN | ||
|
||
# more configurations | ||
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb' | ||
|
||
def parse_mode(mode:str): | ||
init_wp = None | ||
t_wp = None | ||
|
||
if mode == "DOWN": | ||
init_wp = [0,0,1] | ||
t_wp = [0,0,0.1] | ||
elif mode == "UP": | ||
init_wp = [0,0,0.1] | ||
t_wp = [0,0,1] | ||
elif mode == "SIDEWAYS": | ||
init_wp = [0,0,1] | ||
t_wp = [0,1,1] | ||
elif mode == "DIAGONAL_UP": | ||
init_wp = [0,0,0.1] | ||
t_wp = [0,1,1] | ||
elif mode == "DIAGONAL_DOWN": | ||
init_wp = [0,0,1] | ||
t_wp = [0,1,0.1] | ||
else: | ||
raise ValueError(f'Invalide mode {mode}') | ||
|
||
return t_wp, init_wp | ||
|
||
|
||
|
||
def run(output_folder=DEFAULT_OUTPUT_FOLDER, | ||
gui=DEFAULT_GUI, | ||
timesteps=DEFAULT_TIMESTEPS, | ||
action_type: str='rpm', | ||
train: bool=DEFAULT_TRAIN, | ||
test: bool=DEFAULT_TEST, | ||
mode: str = DEFAULT_MODE | ||
): | ||
|
||
t_wp, init_wp = parse_mode(mode) | ||
|
||
config: Configuration = parse_config( | ||
t_waypoint=t_wp, | ||
initial_waypoint=init_wp, | ||
action_type=action_type, | ||
output_folder=output_folder, | ||
n_timesteps=timesteps, | ||
local=False | ||
) | ||
|
||
env_factory = EnvFactorySimpleFollowerAviary( | ||
config=config, | ||
output_folder=output_folder, | ||
observation_type=DEFAULT_OBS, | ||
use_gui_for_test_env=gui, | ||
n_env_training=20, | ||
seed=0, | ||
) | ||
|
||
if train: | ||
run_train(config=config, | ||
env_factory=env_factory) | ||
|
||
if test: | ||
run_test(config=config, | ||
env_factory=env_factory) | ||
|
||
|
||
if __name__ == '__main__': | ||
#### Define and parse (optional) arguments for the script ## | ||
parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script') | ||
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='') | ||
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='') | ||
parser.add_argument('--timesteps', default=DEFAULT_TIMESTEPS, type=int, help='number of train timesteps before stopping', metavar='') | ||
parser.add_argument('--action_type', default=DEFAULT_TIMESTEPS, type=str, help='Either "one_d_rpm", "rpm" or "attitude"', metavar='') | ||
parser.add_argument('--train', default=DEFAULT_TRAIN, type=str2bool, help='Whether to train (default: True)', metavar='') | ||
parser.add_argument('--test', default=DEFAULT_TEST, type=str2bool, help='Whether to test (default: True)', metavar='') | ||
parser.add_argument('--mode', default=DEFAULT_MODE, type=str, help='Experiment mode (default "UP")', metavar='') | ||
ARGS = parser.parse_args() | ||
|
||
run(**vars(ARGS)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy as np | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold | ||
from agents.utils.create_env import EnvFactorySimpleFollowerAviary | ||
from agents.utils.parse_configuration import Configuration | ||
|
||
|
||
def run_train(config: Configuration, env_factory: EnvFactorySimpleFollowerAviary): | ||
|
||
# CONFIG ################################################## | ||
|
||
train_env = env_factory.get_train_env() | ||
eval_env = env_factory.get_eval_env() | ||
|
||
# ######################################################### | ||
|
||
# SETUP ################################################### | ||
|
||
# model | ||
model = PPO('MlpPolicy', | ||
train_env, | ||
tensorboard_log=config.output_path_location+'/tb/', | ||
verbose=1) | ||
|
||
# callbacks | ||
callback_on_best = StopTrainingOnRewardThreshold( | ||
reward_threshold=config.target_reward, | ||
verbose=1 | ||
) | ||
eval_callback = EvalCallback(eval_env, | ||
callback_on_new_best=callback_on_best, | ||
verbose=1, | ||
best_model_save_path=config.output_path_location+'/', | ||
log_path=config.output_path_location+'/', | ||
eval_freq=int(1000), | ||
deterministic=True, | ||
render=False) | ||
|
||
# ########################################################## | ||
|
||
print('[INFO] Action space:', train_env.action_space) | ||
print('[INFO] Observation space:', train_env.observation_space) | ||
print('[INFO] Number of timesteps:', config.n_timesteps) | ||
|
||
# TRAIN #################################################### | ||
|
||
# fit | ||
model.learn(total_timesteps=config.n_timesteps, | ||
callback=eval_callback, | ||
log_interval=100) | ||
|
||
# save model | ||
model.save(config.output_path_location+'/final_model.zip') | ||
print(config.output_path_location) | ||
|
||
# print training progression | ||
with np.load(config.output_path_location+'/evaluations.npz') as data: | ||
for j in range(data['timesteps'].shape[0]): | ||
print(str(data['timesteps'][j])+","+str(data['results'][j][0])) | ||
|
||
# ########################################################## | ||
|
||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
|
||
import numpy as np | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from aviaries.SimpleFollowerAviary import SimpleFollowerAviary | ||
from gym_pybullet_drones.utils.enums import ObservationType, ActionType | ||
from trajectories import DiscretizedTrajectory | ||
from stable_baselines3.common.vec_env import VecEnv | ||
from .parse_configuration import Configuration | ||
|
||
|
||
class EnvFactorySimpleFollowerAviary(): | ||
action_type: ActionType | ||
observation_type: ObservationType | ||
t_traj: DiscretizedTrajectory | ||
n_env_training: int | ||
initial_xyzs: np.ndarray | ||
seed: int | ||
use_gui_for_test_env: bool | ||
output_path_location: str | ||
|
||
def __init__(self, | ||
config: Configuration, | ||
observation_type: ObservationType, | ||
output_folder: str, | ||
use_gui_for_test_env: bool = True, | ||
n_env_training: int=20, | ||
seed: int = 0, | ||
) -> None: | ||
|
||
initial_xyzs = config.initial_xyzs | ||
action_type = config.action_type | ||
t_traj = config.t_traj | ||
|
||
self.initial_xyzs = initial_xyzs | ||
self.observation_type = observation_type | ||
self.action_type = action_type | ||
self.t_traj = t_traj | ||
self.n_env_training = n_env_training | ||
self.seed = seed | ||
self.use_gui_for_test_env = use_gui_for_test_env | ||
|
||
|
||
|
||
def get_train_env(self) -> VecEnv: | ||
train_env = make_vec_env( | ||
SimpleFollowerAviary, | ||
env_kwargs=dict( | ||
target_trajectory=self.t_traj, | ||
initial_xyzs=self.initial_xyzs, | ||
obs=self.observation_type, | ||
act=self.action_type | ||
), | ||
n_envs=self.n_env_training, | ||
seed=self.seed | ||
) | ||
return train_env | ||
|
||
def get_eval_env(self): | ||
eval_env = SimpleFollowerAviary( | ||
target_trajectory=self.t_traj, | ||
initial_xyzs=self.initial_xyzs, | ||
obs=self.observation_type, | ||
act=self.action_type | ||
) | ||
return eval_env | ||
|
||
def get_test_env_gui(self): | ||
test_env = SimpleFollowerAviary( | ||
target_trajectory=self.t_traj, | ||
initial_xyzs=self.initial_xyzs, | ||
gui=self.use_gui_for_test_env, | ||
obs=self.observation_type, | ||
act=self.action_type, | ||
record=False | ||
) | ||
return test_env | ||
|
||
def get_test_env_no_gui(self): | ||
test_env_nogui = SimpleFollowerAviary( | ||
target_trajectory=self.t_traj, | ||
initial_xyzs=self.initial_xyzs, | ||
obs=self.observation_type, | ||
act=self.action_type | ||
) | ||
return test_env_nogui | ||
|
Oops, something went wrong.