add visualization of waypoints
AlboAlby00 committed Jan 23, 2024
1 parent 6d45513 commit 04e0159
Showing 4 changed files with 34 additions and 26 deletions.
31 changes: 10 additions & 21 deletions agents/follower.py
@@ -41,8 +41,9 @@
DEFAULT_OBS = ObservationType('kin') # 'kin' or 'rgb'
DEFAULT_ACT = ActionType('vel') # 'rpm' or 'pid' or 'vel' or 'one_d_rpm' or 'one_d_pid'
DEFAULT_AGENTS = 1
DEFAULT_TIMESTEPS = 20000

def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_video=DEFAULT_RECORD_VIDEO, local=True):
def run(output_folder=DEFAULT_OUTPUT_FOLDER, timesteps=DEFAULT_TIMESTEPS, gui=DEFAULT_GUI, plot=True, record_video=DEFAULT_RECORD_VIDEO, local=True):

filename = os.path.join(output_folder, 'save-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S"))
if not os.path.exists(filename):
@@ -59,6 +60,7 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_
#### Check the environment's spaces ########################
print('[INFO] Action space:', train_env.action_space)
print('[INFO] Observation space:', train_env.observation_space)
print('[INFO] Number of timesteps:', timesteps)

#### Train the model #######################################
model = PPO(
@@ -82,10 +84,10 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_
log_path=filename+'/',
eval_freq=int(1000),
deterministic=True,
render=True
render=False
)
model.learn(
total_timesteps=int(5e5) if local else int(1e2), # shorter training in GitHub Actions pytest
total_timesteps=timesteps,
callback=eval_callback,
log_interval=100
)
@@ -144,23 +146,8 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_
deterministic=True
)
obs, reward, terminated, truncated, info = test_env.step(action)
obs2 = obs.squeeze()
act2 = action.squeeze()
print("Obs:", obs, "\tAction", action, "\tReward:", reward, "\tTerminated:", terminated, "\tTruncated:", truncated)
if DEFAULT_OBS == ObservationType.KIN:
logger.log(drone=0,
timestamp=i/test_env.CTRL_FREQ,
state=np.hstack(
[obs2[0:3],
np.zeros(4),
obs2[3:15],
act2
]
),
control=np.zeros(12)
)
test_env.render()
print(terminated)

# test_env.render()
sync(i, start, test_env.CTRL_TIMESTEP)
if terminated:
obs = test_env.reset(seed=42, options={})
@@ -169,12 +156,14 @@ def run(output_folder=DEFAULT_OUTPUT_FOLDER, gui=DEFAULT_GUI, plot=True, record_
if plot and DEFAULT_OBS == ObservationType.KIN:
logger.plot()


if __name__ == '__main__':
#### Define and parse (optional) arguments for the script ##
parser = argparse.ArgumentParser(description='Single agent reinforcement learning example script')
parser.add_argument('--gui', default=DEFAULT_GUI, type=str2bool, help='Whether to use PyBullet GUI (default: True)', metavar='')
parser.add_argument('--record_video', default=DEFAULT_RECORD_VIDEO, type=str2bool, help='Whether to record a video (default: False)', metavar='')
parser.add_argument('--output_folder', default=DEFAULT_OUTPUT_FOLDER, type=str, help='Folder where to save logs (default: "results")', metavar='')
parser.add_argument('--timesteps', default=DEFAULT_TIMESTEPS, type=int, help='number of train timesteps before stopping', metavar='')
ARGS = parser.parse_args()

print(ARGS)
run(**vars(ARGS))
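
The new timesteps argument can be exercised either through the --timesteps flag added above or by calling run() directly. A minimal usage sketch, assuming agents/follower.py is importable as a module; the short value below is purely illustrative:

# Hypothetical call site for the updated entry point; the CLI equivalent
# would be: python agents/follower.py --timesteps 1000 --gui False
from agents.follower import run

run(timesteps=1000, gui=False, plot=False, record_video=False, local=True)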
25 changes: 22 additions & 3 deletions aviaries/FollowerAviary.py
@@ -81,8 +81,6 @@ def __init__(self,
act=act
)



def _computeReward(self):
"""Computes the current reward value.
@@ -98,7 +96,7 @@ def _computeReward(self):
return ret

################################################################################

def _computeTerminated(self):
"""Computes the current done value.
@@ -194,6 +192,10 @@ def _observationSpace(self):

################################################################################

def step(self,action):

return super().step(action)

def update_waypoints(self):
drone_position = self._getDroneStateVector(0)[0:3]
current_waypoint = self.waypoint_buffer[self.current_waypoint_idx]
@@ -203,6 +205,23 @@ def update_waypoints(self):
self.waypoint_buffer[self.current_waypoint_idx] = next_waypoint
# set next waypoint
self.current_waypoint_idx = (self.current_waypoint_idx + 1) % self.WAYPOINT_BUFFER_SIZE

if self.GUI:
print('current waypoint:', current_waypoint)
sphere_visual = p.createVisualShape(shapeType=p.GEOM_SPHERE,
radius=0.03,
rgbaColor=[0, 1, 0, 1],
physicsClientId=self.CLIENT)
target = p.createMultiBody(baseMass=0.0,
baseCollisionShapeIndex=-1,
baseVisualShapeIndex=sphere_visual,
basePosition=current_waypoint,
useMaximalCoordinates=False,
physicsClientId=self.CLIENT)
p.changeVisualShape(target,
-1,
rgbaColor=[0.9, 0.3, 0.3, 1],
physicsClientId=self.CLIENT)


def _computeObs(self):
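
The GUI branch added above uses PyBullet's visual-shape API to drop a small sphere at the waypoint just reached. A self-contained sketch of the same pattern outside the aviary class (standalone GUI session; positions, radius, and colors are illustrative):

import time
import pybullet as p

# Visual-only marker: a mass-less multi-body with no collision shape,
# mirroring the calls made in FollowerAviary.update_waypoints().
def draw_waypoint_marker(position, client_id):
    sphere_visual = p.createVisualShape(shapeType=p.GEOM_SPHERE,
                                        radius=0.03,
                                        rgbaColor=[0, 1, 0, 1],
                                        physicsClientId=client_id)
    marker = p.createMultiBody(baseMass=0.0,
                               baseCollisionShapeIndex=-1,  # -1: no collision shape
                               baseVisualShapeIndex=sphere_visual,
                               basePosition=position,
                               useMaximalCoordinates=False,
                               physicsClientId=client_id)
    # Recolor after creation, as the diff does with changeVisualShape.
    p.changeVisualShape(marker, -1, rgbaColor=[0.9, 0.3, 0.3, 1],
                        physicsClientId=client_id)
    return marker

client = p.connect(p.GUI)
for k in range(3):
    draw_waypoint_marker([0.5 * k, 0.0, 1.0], client)
time.sleep(5)
p.disconnect(client)

Note that the markers are never deleted, so spheres accumulate over a long episode; if that becomes an issue, p.removeBody(marker, physicsClientId=client) removes one again.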
2 changes: 1 addition & 1 deletion gym_pybullet_drones/envs/BaseRLAviary.py
@@ -278,7 +278,7 @@ def _observationSpace(self):
############################################################
else:
print("[ERROR] in BaseRLAviary._observationSpace()")

################################################################################

def _computeObs(self):
2 changes: 1 addition & 1 deletion gym_pybullet_drones/envs/HoverAviary.py
@@ -79,7 +79,7 @@ def _computeReward(self):
return ret

################################################################################

def _computeTerminated(self):
"""Computes the current done value.
