
Commit

merging
strakam committed Jan 27, 2024
2 parents a7634d9 + 542e92b commit a087daa
Showing 31 changed files with 1,476 additions and 32 deletions.
71 changes: 71 additions & 0 deletions agents/run_experiment.py
@@ -0,0 +1,71 @@
"""Script learns an agent to follow a target trajectory.
"""

import argparse
from gym_pybullet_drones.utils.utils import str2bool
from gym_pybullet_drones.utils.enums import ObservationType
from agents.utils.create_env import EnvFactorySimpleFollowerAviary
from agents.utils.parse_configuration import Configuration
from train_policy import run_train
from test_policy import run_test
from utils.parse_configuration import parse_config

# defaults for command line arguments
DEFAULT_OUTPUT_FOLDER = 'results'
DEFAULT_GUI = True
DEFAULT_TIMESTEPS = 3e5
DEFAULT_ACTION_TYPE = 'rpm' # 'rpm', 'one_d_rpm', 'attitude'
DEFAULT_TRAIN = True
DEFAULT_TEST = True

# more configurations
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'


def run(output_folder=DEFAULT_OUTPUT_FOLDER,
gui=DEFAULT_GUI,
timesteps=DEFAULT_TIMESTEPS,
action_type: str = DEFAULT_ACTION_TYPE,
train: bool=DEFAULT_TRAIN,
test: bool=DEFAULT_TEST
):

config: Configuration = parse_config(
t_waypoint=[0, 0.5, 0.5],
initial_waypoint=[0, 0, 0.1],
action_type=action_type,
output_folder=output_folder,
n_timesteps=timesteps,
local=False
)

env_factory = EnvFactorySimpleFollowerAviary(
config=config,
output_folder=output_folder,
observation_type=DEFAULT_OBS,
use_gui_for_test_env=gui,
n_env_training=20,
seed=0,
)

if train:
run_train(config=config,
env_factory=env_factory)

if test:
run_test(config=config,
env_factory=env_factory)


if __name__ == '__main__':
#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
parser.add_argument('--timesteps', default=DEFAULT_TIMESTEPS, type=int, help='number of train timesteps before stopping', metavar='')
parser.add_argument('--action_type', default=DEFAULT_ACTION_TYPE, type=str, help='Either "one_d_rpm", "rpm" or "attitude"', metavar='')
parser.add_argument('--train', default=DEFAULT_TRAIN, type=str2bool, help='Whether to train (default: True)', metavar='')
parser.add_argument('--test', default=DEFAULT_TEST, type=str2bool, help='Whether to test (default: True)', metavar='')
ARGS = parser.parse_args()

run(**vars(ARGS))
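
A minimal usage sketch (not part of this commit; the flag values are illustrative):

# run from the repository root, overriding a few CLI defaults:
#   python agents/run_experiment.py --gui False --timesteps 300000 --action_type one_d_rpm
# or call run() directly from Python:
from agents.run_experiment import run
run(output_folder='results', gui=False, timesteps=300000, action_type='one_d_rpm', train=True, test=True)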
112 changes: 112 additions & 0 deletions agents/run_experiment_A.py
@@ -0,0 +1,112 @@
"""
RUNS EXPERIMENT A.
5 MODES:
- DOWN (from [0,0,1] to [0,0,0.1])
- UP (from [0,0,0.1] to [0,0,1])
- SIDEWAYS (from [0,0,1] to [0,1,1])
- DIAGONAL_UP (from [0,0,0.1] to [0,1,1])
- DIAGONAL_DOWN (from [0,0,1] to [0,1,0.1])
- every mode 10 times
"""

import argparse
from gym_pybullet_drones.utils.utils import str2bool
from gym_pybullet_drones.utils.enums import ObservationType
from agents.utils.create_env import EnvFactorySimpleFollowerAviary
from agents.utils.parse_configuration import Configuration
from train_policy import run_train
from test_policy import run_test
from utils.parse_configuration import parse_config

# defaults for command line arguments
DEFAULT_OUTPUT_FOLDER = 'results'
DEFAULT_GUI = True
DEFAULT_TIMESTEPS = 3e5
DEFAULT_ACTION_TYPE = 'rpm' # 'rpm', 'one_d_rpm', 'attitude'
DEFAULT_TRAIN = True
DEFAULT_TEST = True
DEFAULT_MODE = "UP" # DOWN, UP, SIDEWAYS, DIAGONAL_UP, DIAGONAL_DOWN

# more configurations
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'

def parse_mode(mode:str):
init_wp = None
t_wp = None

if mode == "DOWN":
init_wp = [0,0,1]
t_wp = [0,0,0.1]
elif mode == "UP":
init_wp = [0,0,0.1]
t_wp = [0,0,1]
elif mode == "SIDEWAYS":
init_wp = [0,0,1]
t_wp = [0,1,1]
elif mode == "DIAGONAL_UP":
init_wp = [0,0,0.1]
t_wp = [0,1,1]
elif mode == "DIAGONAL_DOWN":
init_wp = [0,0,1]
t_wp = [0,1,0.1]
else:
raise ValueError(f'Invalid mode {mode}')

return t_wp, init_wp



def run(output_folder=DEFAULT_OUTPUT_FOLDER,
gui=DEFAULT_GUI,
timesteps=DEFAULT_TIMESTEPS,
action_type: str = DEFAULT_ACTION_TYPE,
train: bool=DEFAULT_TRAIN,
test: bool=DEFAULT_TEST,
mode: str = DEFAULT_MODE
):

t_wp, init_wp = parse_mode(mode)

config: Configuration = parse_config(
t_waypoint=t_wp,
initial_waypoint=init_wp,
action_type=action_type,
output_folder=output_folder,
n_timesteps=timesteps,
local=False
)

env_factory = EnvFactorySimpleFollowerAviary(
config=config,
output_folder=output_folder,
observation_type=DEFAULT_OBS,
use_gui_for_test_env=gui,
n_env_training=20,
seed=0,
)

if train:
run_train(config=config,
env_factory=env_factory)

if test:
run_test(config=config,
env_factory=env_factory)


if __name__ == '__main__':
#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
parser.add_argument('--timesteps', default=DEFAULT_TIMESTEPS, type=int, help='number of train timesteps before stopping', metavar='')
parser.add_argument('--action_type', default=DEFAULT_ACTION_TYPE, type=str, help='Either "one_d_rpm", "rpm" or "attitude"', metavar='')
parser.add_argument('--train', default=DEFAULT_TRAIN, type=str2bool, help='Whether to train (default: True)', metavar='')
parser.add_argument('--test', default=DEFAULT_TEST, type=str2bool, help='Whether to test (default: True)', metavar='')
parser.add_argument('--mode', default=DEFAULT_MODE, type=str, help='Experiment mode (default "UP")', metavar='')
ARGS = parser.parse_args()

run(**vars(ARGS))
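
The docstring above lists five modes run ten times each, while the script itself runs one mode per invocation. A sketch of how that sweep could be driven externally (hypothetical, not part of this commit):

# loop over all five modes, ten repetitions each, one output folder per run
from agents.run_experiment_A import run
for mode in ("DOWN", "UP", "SIDEWAYS", "DIAGONAL_UP", "DIAGONAL_DOWN"):
    for i in range(10):
        run(output_folder=f"results/{mode}_{i}", gui=False, mode=mode)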
24 changes: 21 additions & 3 deletions agents/test_simple_follower.py → agents/test_policy.py
@@ -1,3 +1,5 @@
from agents.utils.create_env import EnvFactorySimpleFollowerAviary
from agents.utils.parse_configuration import Configuration
import os
from stable_baselines3.common.evaluation import evaluate_policy
from aviaries.SimpleFollowerAviary import SimpleFollowerAviary
@@ -9,9 +11,8 @@
import numpy as np
from gym_pybullet_drones.utils.utils import sync


def test_simple_follower(local: bool, filename: str, test_env_nogui: SimpleFollowerAviary, test_env: SimpleFollowerAviary, output_folder: str):
if local:
input("Press Enter to continue...")

# load model
if os.path.isfile(filename+'/best_model.zip'):
@@ -61,4 +62,21 @@ def test_simple_follower(local: bool, filename: str, test_env_nogui: SimpleFollo

test_env.close()

logger.plot()
logger.plot()


def run_test(config: Configuration, env_factory: EnvFactorySimpleFollowerAviary):


test_env = env_factory.get_test_env_gui()
test_env_nogui = env_factory.get_test_env_no_gui()

test_simple_follower(
local=config.local,
filename=config.output_path_location,
test_env_nogui=test_env_nogui,
test_env=test_env,
output_folder=config.output_path_location
)


64 changes: 64 additions & 0 deletions agents/train_policy.py
@@ -0,0 +1,64 @@
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from agents.utils.create_env import EnvFactorySimpleFollowerAviary
from agents.utils.parse_configuration import Configuration


def run_train(config: Configuration, env_factory: EnvFactorySimpleFollowerAviary):

# CONFIG ##################################################

train_env = env_factory.get_train_env()
eval_env = env_factory.get_eval_env()

# #########################################################

# SETUP ###################################################

# model
model = PPO('MlpPolicy',
train_env,
tensorboard_log=config.output_path_location+'/tb/',
verbose=1)

# callbacks
callback_on_best = StopTrainingOnRewardThreshold(
reward_threshold=config.target_reward,
verbose=1
)
eval_callback = EvalCallback(eval_env,
callback_on_new_best=callback_on_best,
verbose=1,
best_model_save_path=config.output_path_location+'/',
log_path=config.output_path_location+'/',
eval_freq=int(1000),
deterministic=True,
render=False)

# ##########################################################

print('[INFO] Action space:', train_env.action_space)
print('[INFO] Observation space:', train_env.observation_space)
print('[INFO] Number of timesteps:', config.n_timesteps)

# TRAIN ####################################################

# fit
model.learn(total_timesteps=config.n_timesteps,
callback=eval_callback,
log_interval=100)

# save model
model.save(config.output_path_location+'/final_model.zip')
print(config.output_path_location)

# print training progression
with np.load(config.output_path_location+'/evaluations.npz') as data:
for j in range(data['timesteps'].shape[0]):
print(str(data['timesteps'][j])+","+str(data['results'][j][0]))

# ##########################################################
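
The evaluations.npz file read above is the standard log written by stable-baselines3's EvalCallback (arrays 'timesteps', 'results', 'ep_lengths'). A small plotting sketch (assumes matplotlib is available; the path is illustrative, not part of this commit):

import numpy as np
import matplotlib.pyplot as plt

with np.load('results/evaluations.npz') as data:
    timesteps = data['timesteps']                # one entry per evaluation
    mean_rewards = data['results'].mean(axis=1)  # average over eval episodes
plt.plot(timesteps, mean_rewards)
plt.xlabel('timesteps')
plt.ylabel('mean evaluation reward')
plt.show()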



Empty file added agents/utils/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions agents/utils/create_env.py
@@ -0,0 +1,86 @@

import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from aviaries.SimpleFollowerAviary import SimpleFollowerAviary
from gym_pybullet_drones.utils.enums import ObservationType, ActionType
from trajectories import DiscretizedTrajectory
from stable_baselines3.common.vec_env import VecEnv
from .parse_configuration import Configuration


class EnvFactorySimpleFollowerAviary():
action_type: ActionType
observation_type: ObservationType
t_traj: DiscretizedTrajectory
n_env_training: int
initial_xyzs: np.ndarray
seed: int
use_gui_for_test_env: bool
output_path_location: str

def __init__(self,
config: Configuration,
observation_type: ObservationType,
output_folder: str,
use_gui_for_test_env: bool = True,
n_env_training: int=20,
seed: int = 0,
) -> None:

initial_xyzs = config.initial_xyzs
action_type = config.action_type
t_traj = config.t_traj

self.initial_xyzs = initial_xyzs
self.observation_type = observation_type
self.action_type = action_type
self.t_traj = t_traj
self.n_env_training = n_env_training
self.seed = seed
self.use_gui_for_test_env = use_gui_for_test_env



def get_train_env(self) -> VecEnv:
train_env = make_vec_env(
SimpleFollowerAviary,
env_kwargs=dict(
target_trajectory=self.t_traj,
initial_xyzs=self.initial_xyzs,
obs=self.observation_type,
act=self.action_type
),
n_envs=self.n_env_training,
seed=self.seed
)
return train_env

def get_eval_env(self):
eval_env = SimpleFollowerAviary(
target_trajectory=self.t_traj,
initial_xyzs=self.initial_xyzs,
obs=self.observation_type,
act=self.action_type
)
return eval_env

def get_test_env_gui(self):
test_env = SimpleFollowerAviary(
target_trajectory=self.t_traj,
initial_xyzs=self.initial_xyzs,
gui=self.use_gui_for_test_env,
obs=self.observation_type,
act=self.action_type,
record=False
)
return test_env

def get_test_env_no_gui(self):
test_env_nogui = SimpleFollowerAviary(
target_trajectory=self.t_traj,
initial_xyzs=self.initial_xyzs,
obs=self.observation_type,
act=self.action_type
)
return test_env_nogui
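
For context, the four getters above are consumed as in the scripts earlier in this commit: the vectorized env feeds PPO, the single eval env feeds EvalCallback, and the GUI/no-GUI envs feed test_simple_follower. A condensed wiring sketch (hypothetical values; config comes from parse_config as in run_experiment.py):

factory = EnvFactorySimpleFollowerAviary(config=config,
                                         output_folder='results',
                                         observation_type=ObservationType('kin'),
                                         use_gui_for_test_env=False,
                                         n_env_training=20,
                                         seed=0)
train_env = factory.get_train_env()            # VecEnv with 20 parallel training envs
eval_env = factory.get_eval_env()              # single env for EvalCallback
test_env = factory.get_test_env_gui()          # rendered rollout after training
test_env_nogui = factory.get_test_env_no_gui() # headless evaluation env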

