Skip to content

Commit

Permalink
Move to scripts folder and address the comments
Browse files Browse the repository at this point in the history
Signed-off-by: Alon Druck <alon.druck@ekumenlabs.com>
  • Loading branch information
Alondruck committed May 24, 2024
1 parent 5cece78 commit edc1a3b
Showing 1 changed file with 42 additions and 44 deletions.
86 changes: 42 additions & 44 deletions beluga_tutorial/src/visualization.py → beluga_tutorial/scripts/visualize.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright 2024 Ekumen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,19 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Visualization Scripts for Beluga Tutorial.
This script uses the matplotlib to visualize the data generated
by beluga_tutorial example code.
"""

import yaml
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import argparse


def read_yaml_file(file_path):
with open(file_path, 'r') as file:
data = yaml.safe_load(file)
return data


def plot_stages(yaml_data, axs, index):
particles = yaml_data['simulation_records'][index]["particles"]
for j, stage in enumerate(['current', 'propagate', 'reweight', 'resample']):
Expand Down Expand Up @@ -86,19 +88,20 @@ def plot_stages(yaml_data, axs, index):
plt.draw()


def on_key(event, yaml_data, axs, current_frame_container, num_frames):
def on_key(event, yaml_data, axs, num_frames):
if not hasattr(on_key, "current_frame"):
on_key.current_frame = 0

if event.key == 'right':
current_frame_container[0] = (current_frame_container[0] + 1) % num_frames
on_key.current_frame = (on_key.current_frame + 1) % num_frames
elif event.key == 'left':
current_frame_container[0] = (current_frame_container[0] - 1) % num_frames
plot_stages(yaml_data, axs, current_frame_container[0])
on_key.current_frame = (on_key.current_frame - 1) % num_frames

plot_stages(yaml_data, axs, on_key.current_frame)


def main(argv=None) -> int:
parser = argparse.ArgumentParser(
description="This script uses the matplotlib to visualize the data generated\
by beluga_tutorial example code."
)
parser = argparse.ArgumentParser(description=globals()['__doc__'])

parser.add_argument(
'-p',
Expand All @@ -111,10 +114,8 @@ def main(argv=None) -> int:
parser.add_argument(
'-m',
'--manual-control',
type=bool,
action='store_true',
help='Manual controlling the time steps of the visualization',
required=False,
default=False,
)

parser.add_argument(
Expand All @@ -128,36 +129,33 @@ def main(argv=None) -> int:

args = parser.parse_args(argv)

yaml_data = read_yaml_file(args.record_file_path)

if yaml_data:
fig, axs = plt.subplots(6, 1)
num_frames = len(yaml_data['simulation_records'])

if args.manual_control:
current_frame_container = [0]
plot_stages(yaml_data, axs, current_frame_container[0])
fig.canvas.mpl_connect(
'key_press_event',
lambda event: on_key(
event, yaml_data, axs, current_frame_container, num_frames
),
)
else:
with open(args.record_file_path, 'r') as file:
yaml_data = yaml.safe_load(file)

def update(current_frame):
plot_stages(yaml_data, axs, current_frame)
fig, axs = plt.subplots(6, 1)
num_frames = len(yaml_data['simulation_records'])

animation.FuncAnimation(
fig,
update,
frames=num_frames,
blit=False,
interval=args.interval_ms,
repeat=False,
)
if args.manual_control:
plot_stages(yaml_data, axs, 0)
fig.canvas.mpl_connect(
'key_press_event',
lambda event: on_key(event, yaml_data, axs, num_frames),
)
else:

def update(current_frame):
plot_stages(yaml_data, axs, current_frame)

anim = animation.FuncAnimation( # noqa: F841
fig,
update,
frames=num_frames,
blit=False,
interval=args.interval_ms,
repeat=False,
)

plt.show()
plt.show()


if __name__ == '__main__':
Expand Down

0 comments on commit edc1a3b

Please sign in to comment.