Skip to content

Commit

Permalink
Fix countdown_time inf bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jackvial committed Dec 27, 2024
1 parent b636718 commit 24b43e1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 23 deletions.
14 changes: 9 additions & 5 deletions lerobot/common/robot_devices/control_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _initialize_state(self):
if self.config.assign_rewards:
self.events["next_reward"] = 0

self.pressed_keys = []
self.current_episode_index = 0

# Define the control instructions
Expand Down Expand Up @@ -182,6 +181,11 @@ def _publish_observations(self, observation: Dict[str, np.ndarray], log_items: l
"current_episode": self.current_episode_index,
}

# Sanitize countdown time. if inf set to max 32-bit int
countdown_time = int(countdown_time) if countdown_time != float("inf") else 2 ** 31 - 1
if self.config.control_phase == ControlPhase.TELEOPERATE:
countdown_time = 0

message = {
"type": "observation_update",
"timestamp": time.time(),
Expand Down Expand Up @@ -214,7 +218,8 @@ def log_control_info(self, start_loop_t):
return log_items

def log_say(self, message):
self._publish_log_say(message)
# self._publish_log_say(message)
pass

def _publish_log_say(self, message):
message = {
Expand Down Expand Up @@ -250,7 +255,6 @@ def read_image_from_camera(cap):
return torch.tensor(frame_rgb).float()

config = ControlContextConfig(
display_cameras=False,
assign_rewards=True,
control_phase=ControlPhase.RECORD,
num_episodes=200,
Expand All @@ -259,7 +263,7 @@ def read_image_from_camera(cap):
context = ControlContext(config)
context.update_current_episode(199)

cameras = {"main": cv2.VideoCapture(0), "top": cv2.VideoCapture(4), "web": cv2.VideoCapture(6)}
cameras = {"main": cv2.VideoCapture(0), "top": cv2.VideoCapture(4)}

for name, cap in cameras.items():
if not cap.isOpened():
Expand All @@ -282,7 +286,7 @@ def read_image_from_camera(cap):
obs_dict[f"observation.images.{name}"] = images[name]

# Update context with observations
context.update_with_observations(obs_dict)
context.update_with_observations(obs_dict, time.perf_counter(), countdown_time=10)
events = context.get_events()

if events["exit_early"]:
Expand Down
17 changes: 4 additions & 13 deletions lerobot/scripts/browser_ui_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from flask import Flask, render_template
from flask_socketio import SocketIO
import cv2
import numpy as np
import base64
import json
import threading
import time
import zmq
Expand Down Expand Up @@ -37,9 +33,6 @@ def zmq_consumer():
while True:
try:
message = subscriber_socket.recv_json()

message_type = message.get("type")
print(f"Received message: {message_type}")

if message.get("type") == "observation_update":
processed_data = {
Expand All @@ -48,8 +41,8 @@ def zmq_consumer():
"state": {},
"events": message.get("events", {}),
"config": message.get("config", {}),
"log_items": message.get("log_items", []),
"countdown_time": message.get("countdown_time"),
# "log_items": message.get("log_items", []),
"countdown_time": message.get("countdown_time")
}

# Process observation data
Expand All @@ -69,8 +62,9 @@ def zmq_consumer():
latest_data["observation"].update(processed_data)
latest_data["config"].update(processed_data.get("config", {}))

# Emit the observation data to the browser
# # Emit the observation data to the browser
socketio.emit("observation_update", processed_data)


elif message.get("type") == "config_update":
# Handle dedicated config updates
Expand All @@ -91,12 +85,9 @@ def zmq_consumer():
"message": data
})


except Exception as e:
print(f"ZMQ consumer error: {e}")
time.sleep(1)

time.sleep(0.001) # Small sleep to prevent busy-waiting


@socketio.on("keydown_event")
Expand Down
16 changes: 11 additions & 5 deletions lerobot/templates/browser_ui.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<title>Robot Observation Stream (Alpine.js Refactor - Dark Theme)</title>

<!-- Socket.IO -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.8.1/socket.io.js"></script>
<!-- Tailwind CSS -->
<script src="https://cdn.tailwindcss.com"></script>
<!-- Alpine.js -->
Expand Down Expand Up @@ -82,7 +82,6 @@
<div
id="cameras"
class="grid grid-cols-1 md:grid-cols-2 gap-4"
x-show="configData.display_cameras"
>
<!-- Loop over images object -->
<template x-for="(imageData, name) in images" :key="name">
Expand Down Expand Up @@ -215,7 +214,7 @@ <h3 class="text-lg font-semibold mb-4 text-gray-300">Logs</h3>

<script>
document.addEventListener('alpine:init', () => {
Alpine.data('LeRobot', () => ({
Alpine.data('robotApp', () => ({
// Reactive state
configData: {},
stateData: {},
Expand All @@ -233,7 +232,14 @@ <h3 class="text-lg font-semibold mb-4 text-gray-300">Logs</h3>
},

initSocket() {
this.socket = io();
this.socket = io({
reconnection: true,
reconnectionAttempts: Infinity, // Unlimited reconnection attempts
reconnectionDelay: 1000, // Start with 1 second delay
reconnectionDelayMax: 5000, // Max delay of 5 seconds
randomizationFactor: 0.5
});

this.socket.on('connect', () => {
console.log('Connected to server');
});
Expand Down Expand Up @@ -261,7 +267,7 @@ <h3 class="text-lg font-semibold mb-4 text-gray-300">Logs</h3>
if (data.log_items) {
this.logs = data.log_items;
}
if (data.images && this.configData.display_cameras) {
if (data.images) {
this.images = data.images;
}
if (data.countdown_time !== undefined) {
Expand Down

0 comments on commit 24b43e1

Please sign in to comment.