diff --git a/beluga_tutorial/package.xml b/beluga_tutorial/package.xml index 857770ee2..747a02cd0 100644 --- a/beluga_tutorial/package.xml +++ b/beluga_tutorial/package.xml @@ -18,6 +18,10 @@ beluga yaml_cpp_vendor + python3-matplotlib + python3-numpy + python3-yaml + cmake diff --git a/beluga_tutorial/scripts/visualize.py b/beluga_tutorial/scripts/visualize.py new file mode 100755 index 000000000..e89126b4b --- /dev/null +++ b/beluga_tutorial/scripts/visualize.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Ekumen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualization Scripts for Beluga Tutorial. + +This script uses the matplotlib to visualize the data generated +by beluga_tutorial example code. +""" + +import argparse +import matplotlib.pyplot as plt +import matplotlib.animation as animation +import numpy as np +import yaml + + +def plot_stages(yaml_data, axs, index): + particles = yaml_data['simulation_records'][index]["particles"] + for j, stage in enumerate(['current', 'prediction', 'update']): + states = particles[stage]['states'] + ax = axs[j] + ax.clear() + ax.hist( + states, + bins=np.arange(min(states), max(states) + 0.5, 0.5), + color='skyblue', + edgecolor='black', + ) + ax.set_title(f"{stage} - Sim Cycle {index+1}") + ax.set_xlabel("States") + ax.set_ylabel("NP") + ax.set_xlim(-5, 100) + ax.set_xticks(np.arange(-5, 101, 5)) + + landmark_map = yaml_data["landmark_map"] + ax_landmark = axs[3] + ax_landmark.clear() + + for x in range(101): + if x in landmark_map: + color = 'red' + else: + color = 'blue' + ax_landmark.bar(x, 1, color=color) + + ax_landmark.set_title("Landmark Map") + ax_landmark.set_xlabel("States") + ax_landmark.set_xlim(-5, 100) + ax_landmark.set_ylim(0, 2) + ax_landmark.set_xticks(np.arange(-5, 101, 5)) + + ground_truth = yaml_data['simulation_records'][index]["ground_truth"] + ax_ground_truth = axs[4] + ax_ground_truth.clear() + ax_ground_truth.bar(ground_truth, 1, color='green') + ax_ground_truth.set_title(f"Ground Truth: {ground_truth}") + ax_ground_truth.set_xlabel("States") + ax_ground_truth.set_xlim(-5, 100) + ax_ground_truth.set_ylim(0, 2) + ax_ground_truth.set_xticks(np.arange(-5, 101, 5)) + + mean = "{:.3f}".format(yaml_data['simulation_records'][index]["estimation"]["mean"]) + sd = "{:.3f}".format(yaml_data['simulation_records'][index]["estimation"]["sd"]) + plt.text( + 0.5, + 0.5, + f"Mean: {mean}\nSD: {sd}", + ha='center', + va='center', + transform=axs[4].transAxes, + bbox=dict(facecolor='white', alpha=0.5), + ) + + plt.tight_layout() + plt.draw() + + +def main(argv=None) -> int: + """Run the entry point of the program.""" + parser = argparse.ArgumentParser(description=globals()['__doc__']) + + parser.add_argument( + 'record_file_path', + type=str, + help='Absolute path to the record file generated by the beluga_tutorial example code', + ) + + parser.add_argument( + '-m', + '--manual-control', + action='store_true', + help='Manual controlling the time steps of the visualization', + ) + + parser.add_argument( + '-i', + '--interval-ms', + type=int, + help='Delay between frames in milliseconds.', + required=False, + default=250, + ) + + parser.add_argument( + '-r', + '--repeat-animation', + action='store_true', + help='Repeat the animation when it is finished', + ) + + args = parser.parse_args(argv) + + with open(args.record_file_path, 'r') as file: + yaml_data = yaml.safe_load(file) + + fig, axs = plt.subplots(5, 1) + num_frames = len(yaml_data['simulation_records']) + + def plot_stages_update(current_frame: int) -> None: + """Update plots (Helper function).""" + plot_stages(yaml_data, axs, current_frame) + + if args.manual_control: + current_frame = 0 + plot_stages_update(current_frame) + + def on_key(event): + """Handle key events (Callback function).""" + nonlocal current_frame + + if event.key == 'right': + current_frame += 1 + elif event.key == 'left': + current_frame -= 1 + + current_frame %= num_frames + plot_stages_update(current_frame) + + fig.canvas.mpl_connect('key_press_event', on_key) + + else: + anim = animation.FuncAnimation( # noqa: F841 + fig, + plot_stages_update, + frames=num_frames, + blit=False, + interval=args.interval_ms, + repeat=args.repeat_animation, + ) + + plt.show() + + +if __name__ == '__main__': + exit(main())