From ccf53e4fe11cfd8f331b74c5610d6b259b08e0a8 Mon Sep 17 00:00:00 2001 From: Jack Vial Date: Tue, 24 Dec 2024 16:40:03 -0500 Subject: [PATCH] Show reset and saving episode --- .../common/robot_devices/control_context.py | 1 + lerobot/scripts/control_robot.py | 46 ++++++++++++++----- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/lerobot/common/robot_devices/control_context.py b/lerobot/common/robot_devices/control_context.py index d9409fcad..f86ed3a8d 100644 --- a/lerobot/common/robot_devices/control_context.py +++ b/lerobot/common/robot_devices/control_context.py @@ -17,6 +17,7 @@ class ControlPhase: WARMUP = "Warmup" RECORD = "Record" RESET = "Reset" + SAVING = "Saving" @dataclass diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 6c9a62930..cffe705b8 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -295,18 +295,6 @@ def record( if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() - control_context = control_context.update_config( - ControlContextConfig( - robot=robot, - control_phase=ControlPhase.RECORD, - play_sounds=play_sounds, - assign_rewards=False, - num_episodes=num_episodes, - display_cameras=display_cameras, - fps=fps, - ) - ) - recorded_episodes = 1 while True: if recorded_episodes >= num_episodes: @@ -317,6 +305,18 @@ def record( # if multi_task: # task = input("Enter your task description: ") + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.RECORD, + play_sounds=play_sounds, + assign_rewards=False, + num_episodes=num_episodes, + display_cameras=display_cameras, + fps=fps, + ) + ) + log_say(f"Recording episode {dataset.num_episodes}", play_sounds) record_episode( dataset=dataset, @@ -339,6 +339,17 @@ def record( if not events["stop_recording"] and ( (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"] ): + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.RESET, + play_sounds=play_sounds, + assign_rewards=False, + num_episodes=num_episodes, + display_cameras=display_cameras, + fps=fps, + ) + ) log_say("Reset the environment", play_sounds) reset_environment(robot, events, reset_time_s) @@ -349,6 +360,17 @@ def record( dataset.clear_episode_buffer() continue + control_context = control_context.update_config( + ControlContextConfig( + robot=robot, + control_phase=ControlPhase.SAVING, + play_sounds=play_sounds, + assign_rewards=False, + num_episodes=num_episodes, + display_cameras=display_cameras, + fps=fps, + ) + ) dataset.save_episode(task) recorded_episodes += 1 control_context.update_current_episode(recorded_episodes)