Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pyvista 3 d visualization #48

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
fix documentation issue
  • Loading branch information
MaHaWo committed Mar 12, 2025
commit 9520ac8ff298c96745d7ea03f1f0d4221a60e85d
34 changes: 17 additions & 17 deletions src/sme_contrib/plot.py
Original file line number Diff line number Diff line change
@@ -144,14 +144,14 @@ def facet_grid_3D(
This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary. Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and may not work as expected.

Args:
data : dict[str, np.ndarray] A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
plotfuncs : dict[str, Callable] A dictionary where keys are labels and values are functions with signature f(label:str, data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, plotter:pv.Plotter, panel:tuple[int, int], show_cmap:bool=show_cmap, cmap=cmap, **plotfuncs_kwargs ) -> None
data : (dict[str, np.ndarray]) A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
plotfuncs : (dict[str, Callable]) A dictionary where keys are labels and values are functions with signature ``f(label:str, data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, plotter:pv.Plotter, panel:tuple[int, int], show_cmap:bool=show_cmap, cmap=cmap, **plotfuncs_kwargs )`` -> None
show_cmap : bool, optional Whether to show the color map. Default is False.
cmap : str | np.ndarray | pv.LookupTable, optional The color map to use. Default is "viridis".
portrait : bool, optional Whether to use a portrait layout. Default is False.
linked_views : bool, optional Whether to link the views of the subplots. Default is True.
plotter_kwargs : dict, optional Additional keyword arguments to pass to the PyVista Plotter.
plotfuncs_kwargs : dict[str, dict[str, Any]], optional Additional keyword arguments to pass to each plotting function.
cmap : (str | np.ndarray | pv.LookupTable), optional The color map to use. Default is "viridis".
portrait : (bool), optional Whether to use a portrait layout. Default is False.
linked_views : (bool), optional Whether to link the views of the subplots. Default is True.
plotter_kwargs : (dict, optional) Additional keyword arguments to pass to the PyVista Plotter.
plotfuncs_kwargs : (dict[str, dict[str, Any]]), optional Additional keyword arguments to pass to each plotting function.

Returns:
pv.Plotter The PyVista Plotter object with the created facet plot.
@@ -204,16 +204,16 @@ def facet_grid_animate_3D(
This series must be a list of dictionaries with the data for each frame keyed by a label used to title the panel it will be plotted into. The final plot will have as many subplots as there are labels in the data dictionaries. The keys for plotfuncs and data must be the same.

Args:
filename : str The name of the output movie file.
data : list[dict[str, np.ndarray]] A list of dictionaries containing the data for each timestep.
plotfuncs : dict[str, Callable] A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same.
show_cmap : bool, optional Whether to show the color map (default is False).
cmap : str | np.ndarray | pv.LookupTable, optional The colormap to use (default is "viridis").
portrait : bool, optional Whether to use portrait layout (default is False).
linked_views : bool, optional Whether to link the views of the subplots (default is True).
titles : list[dict[str, str]], optional A list of dictionaries containing titles for each subplot (default is an empty list).
plotter_kwargs : dict, optional Additional keyword arguments to pass to the PyVista Plotter (default is an empty dictionary).
plotfuncs_kwargs : dict[str, dict[str, Any]], optional Additional keyword arguments to pass to each plotting function (default is an empty dictionary).
filename : (str) The name of the output movie file.
data : (list[dict[str, np.ndarray]]) A list of dictionaries containing the data for each timestep.
plotfuncs : (dict[str, Callable]) A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same.
show_cmap : (bool), optional Whether to show the color map (default is False).
cmap : (str | np.ndarray | pv.LookupTable, optional) The colormap to use (default is "viridis").
portrait : (bool), optional Whether to use portrait layout (default is False).
linked_views : (bool), optional Whether to link the views of the subplots (default is True).
titles : (list[dict[str, str]]), optional A list of dictionaries containing titles for each subplot (default is an empty list).
plotter_kwargs : (dict), optional Additional keyword arguments to pass to the PyVista Plotter (default is an empty dictionary).
plotfuncs_kwargs : (dict[str, dict[str, Any]]), optional Additional keyword arguments to pass to each plotting function (default is an empty dictionary).

Returns:
str The filename of the created movie.
11 changes: 6 additions & 5 deletions src/sme_contrib/pyvista_utils.py
Original file line number Diff line number Diff line change
@@ -9,10 +9,11 @@ def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
"""
Convert an RGB 3D image represented as a 4D tensor to a 3D image tensor where each unique RGB value is assigned a unique scalar, i.e., it contracts the dimension with the RGB values into scalars in such a way that 2 different colors are mapped to 2 different scalars, too. This is needed because PyVista doesn't work with RGB values directly and expects fields defined on a grid.

img (np.ndarray): A 3D numpy array representing an RGB image with shape (height, width, 3).
Args:
img (np.ndarray): A 3D numpy array representing an RGB image with shape (height, width, 3).

np.ndarray: A 2D numpy array with the same height and width as the input image, where each pixel's value
corresponds to a unique scalar representing the original RGB value.
Retruns:
np.ndarray: A 2D numpy array with the same height and width as the input image, where each pixel's value corresponds to a unique scalar representing the original RGB value.
"""
reshaped = np.copy(img.reshape(-1, 3))
unique_rgb, ridx = np.unique(reshaped, axis=0, return_inverse=True)
@@ -61,11 +62,11 @@ def find_layout(num_plots: int, portrait: bool = False) -> tuple[int, int]:
"""Find a reasonable layout for a grid of subplots. This splits num_subplots into n x m subplots where n and m are as close as possible to each other. This can include a case where n x m > num_plots. Then, the superficial panels in the grid are ignored in the plotting process.

Args:
num_plots (int): Number of plots to arrange
num_plots (int): Number of plots to arrange
portrait (bool, optional): Whether the min or max of (n,m) should be the column number in the resulting grid. Defaults to False.

Returns:
tuple[int, int]: Tuple describing (n_rows, n_cols) of the grid
tuple[int, int]: Tuple describing (n_rows, n_cols) of the grid
"""

# for checking approximation accuracy with ints. if root > root_int, then
Loading
Oops, something went wrong.