Skip to content

Commit

Permalink
Fix updating control context config
Browse files Browse the repository at this point in the history
  • Loading branch information
jackvial committed Dec 24, 2024
1 parent 3c54bf4 commit 9108649
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 59 deletions.
103 changes: 64 additions & 39 deletions lerobot/common/robot_devices/control_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,40 @@ class ControlContextConfig:
class ControlContext:
def __init__(self, config: ControlContextConfig):
self.config = config
self._initialize_display()
self._initialize_communication()
self._initialize_state()

if not self.config.robot:
raise ValueError("Robot object must be provided in ControlContextConfig")

def _initialize_display(self):
pygame.init()
if not self.config.display_cameras:
pygame.display.set_mode((1, 1), pygame.HIDDEN)

self.screen = None
self.image_positions = {}
self.padding = 20
self.title_height = 30
self.font = pygame.font.SysFont("courier", 24)
self.small_font = pygame.font.SysFont("courier", 18)

# Color theme
self.text_bg_color = (0, 0, 0)
self.text_color = (0, 255, 0)

def _initialize_state(self):
self.events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"next_reward": 0,
}

if self.config.assign_rewards:
self.events["next_reward"] = 0

self.pressed_keys = []
self.font = pygame.font.SysFont("courier", 24)
self.small_font = pygame.font.SysFont("courier", 18)
self.current_episode_index = 0

# Color theme
self.text_bg_color = (0, 0, 0)
self.text_color = (0, 255, 0)


# Define the control instructions
self.controls = [
("Right Arrow", "Exit Early"),
Expand All @@ -72,19 +75,51 @@ def __init__(self, config: ControlContextConfig):
("Space", "Toggle Reward"),
]

def _initialize_communication(self):
self.zmq_context = zmq.Context()

self.publisher_socket = self.zmq_context.socket(zmq.PUB)
self.publisher_socket.bind("tcp://127.0.0.1:5555")

self.command_sub_socket = self.zmq_context.socket(zmq.SUB)
self.command_sub_socket.connect("tcp://127.0.0.1:5556")

# Subscribe to all messages
self.command_sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")

def update_config(self, config: ControlContextConfig):
"""Update configuration and reinitialize UI components as needed"""
old_display_setting = self.config.display_cameras
self.config = config

# If display setting changed, reinitialize display
if old_display_setting != self.config.display_cameras:
pygame.quit()
self._initialize_display()

# Force screen recreation on next render
self.screen = None

# Update ZMQ message with new config
self._publish_config_update()

return self

def _publish_config_update(self):
"""Publish configuration update to ZMQ subscribers"""
config_data = {
"display_cameras": self.config.display_cameras,
"play_sounds": self.config.play_sounds,
"assign_rewards": self.config.assign_rewards,
"control_phase": self.config.control_phase,
"num_episodes": self.config.num_episodes,
"current_episode": self.current_episode_index,
}

message = {
"type": "config_update",
"timestamp": time.time(),
"config": config_data,
}

self.publisher_socket.send_json(message)

def calculate_window_size(self, images: Dict[str, np.ndarray]):
"""Calculate required window size based on images"""
Expand Down Expand Up @@ -187,26 +222,16 @@ def handle_browser_events(self):
print(f"Error while polling for commands: {e}")

def publish_observations(self, observation: Dict[str, np.ndarray], log_items: list):
"""
Encode and publish the full observation object via ZeroMQ PUB socket.
Includes observation data, events, and config information.
Args:
observation (Dict[str, np.ndarray]): Dictionary containing observation data,
including images and state information
"""
"""Encode and publish observation data with current configuration"""
processed_data = {}

# Process observation data
for key, value in observation.items():
if "image" in key:
# Handle image data
image = value.numpy() if torch.is_tensor(value) else value
# Convert from RGB to BGR for JPEG encoding
bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
success, buffer = cv2.imencode(".jpg", bgr_image)
if success:
# Convert to base64
b64_jpeg = base64.b64encode(buffer).decode("utf-8")
processed_data[key] = {
"type": "image",
Expand All @@ -216,15 +241,13 @@ def publish_observations(self, observation: Dict[str, np.ndarray], log_items: li
}
else:
tensor_data = value.detach().cpu().numpy() if torch.is_tensor(value) else value

processed_data[key] = {
"type": "tensor",
"data": tensor_data.tolist(),
"shape": tensor_data.shape,
}

# Add events and config information
events_data = self.get_events()
# Include current configuration in observation update
config_data = {
"display_cameras": self.config.display_cameras,
"play_sounds": self.config.play_sounds,
Expand All @@ -238,12 +261,11 @@ def publish_observations(self, observation: Dict[str, np.ndarray], log_items: li
"type": "observation_update",
"timestamp": time.time(),
"data": processed_data,
"events": events_data,
"events": self.get_events(),
"config": config_data,
"log_items": log_items,
}

# Send JSON over ZeroMQ
self.publisher_socket.send_json(message)

def render_scene_from_observations(self, observation: Dict[str, np.ndarray]):
Expand Down Expand Up @@ -307,19 +329,22 @@ def log_control_info(self, start_loop_t):

return log_items

def cleanup(self):
self.config.robot.disconnect()
def cleanup(self, robot=None):
"""Clean up resources and connections"""
if robot:
robot.disconnect()

pygame.quit()

self.publisher_socket.close()
self.command_sub_socket.close()
self.zmq_context.term()


if __name__ == "__main__":
import torch
import cv2
import time
import numpy as np
import time
import torch

def read_image_from_camera(cap):
ret, frame = cap.read()
Expand Down
38 changes: 29 additions & 9 deletions lerobot/scripts/browser_ui_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,28 @@
app = Flask(__name__, template_folder=str(template_dir))
socketio = SocketIO(app, cors_allowed_origins="*")

# Global dictionary to hold the latest observation data from ZeroMQ
latest_observation = {}
# Global dictionary to hold the latest data from ZeroMQ
latest_data = {
"observation": {},
"config": {}
}

zmq_context = zmq.Context()

# For recieving observation (camera frames, state, events) from ControlContext
# so we can send them to the browser
# For receiving updates from ControlContext
subscriber_socket = zmq_context.socket(zmq.SUB)
subscriber_socket.connect("tcp://127.0.0.1:5555")
subscriber_socket.setsockopt_string(zmq.SUBSCRIBE, "")

# For sending keydown events from the browser to ControlContext
# For sending keydown events to ControlContext
command_publisher = zmq_context.socket(zmq.PUB)
command_publisher.bind("tcp://127.0.0.1:5556")

def zmq_consumer():
while True:
try:
message = subscriber_socket.recv_json()

if message.get("type") == "observation_update":
processed_data = {
"timestamp": message.get("timestamp"),
Expand All @@ -58,12 +61,24 @@ def zmq_consumer():
"shape": value["shape"]
}

# Update latest observation
latest_observation.update(processed_data)
# Update latest observation and config
latest_data["observation"].update(processed_data)
latest_data["config"].update(processed_data.get("config", {}))

# Emit the observation data to the browser
socketio.emit("observation_update", processed_data)

elif message.get("type") == "config_update":
# Handle dedicated config updates
config_data = message.get("config", {})
latest_data["config"].update(config_data)

# Emit configuration update to browser
socketio.emit("config_update", {
"timestamp": message.get("timestamp"),
"config": config_data
})

except Exception as e:
print(f"ZMQ consumer error: {e}")
time.sleep(1)
Expand Down Expand Up @@ -96,8 +111,13 @@ def handle_connect():
"""Handle client connection."""
print("Client connected")
# Send current state if available
if latest_observation:
socketio.emit("observation_update", latest_observation)
if latest_data["observation"]:
socketio.emit("observation_update", latest_data["observation"])
if latest_data["config"]:
socketio.emit("config_update", {
"timestamp": time.time(),
"config": latest_data["config"]
})

@socketio.on("disconnect")
def handle_disconnect():
Expand Down
5 changes: 3 additions & 2 deletions lerobot/scripts/control_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,19 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()

control_context.update_config(
control_context = control_context.update_config(
ControlContextConfig(
robot=robot,
control_phase=ControlPhase.RECORD,
play_sounds=play_sounds,
assign_rewards=False,
num_episodes=num_episodes,
display_cameras=display_cameras,
fps=fps,
)
)

recorded_episodes = 0
recorded_episodes = 1
while True:
if recorded_episodes >= num_episodes:
break
Expand Down
Loading

0 comments on commit 9108649

Please sign in to comment.