Skip to content

Commit d6e6d31

Browse files
committed
reorganize tests
1 parent 2ab4fa6 commit d6e6d31

File tree

4 files changed

+274
-488
lines changed

4 files changed

+274
-488
lines changed

src/sme_contrib/plot.py

+145-190
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77
from itertools import cycle
88
import matplotlib.colors as mcolors
99
import pyvista as pv
10-
from typing import Any
10+
from typing import Any, Callable
1111
import sme
1212

1313
from .pyvista_utils import (
14-
facet_animate3D,
15-
facet_plot3D,
16-
rgb_to_scalar,
14+
find_layout,
1715
)
1816

1917

@@ -155,216 +153,173 @@ def concentration_heatmap_animation(
155153
return anim
156154

157155

158-
def plot_concentration_image_3D(
159-
simulation_result: sme.SimulationResult,
160-
show_cmap: bool = False,
161-
cmap: str | np.ndarray | pv.LookupTable = "viridis",
162-
plotter_kwargs: dict[str, Any] = {"border": False, "notebook": False},
163-
) -> pv.Plotter:
164-
"""Plot a single concentration image in 3D using pyvista.
165-
166-
Args:
167-
simulation_result (sme.SimulationResult): The simulation result object of a 3D simulation for a single timestep
168-
show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False.
169-
cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis".
170-
plotter_kwargs (dict[str, Any], optional): Addtitional kwargs for the pyvista.Plotter constructor Defaults to {"border": False, "notebook": False}.
171-
172-
Returns:
173-
pv.Plotter: pyvista Plotter object
174-
"""
175-
176-
def plot_single(
177-
title,
178-
data,
179-
plotter,
180-
panel,
181-
show_cmap=show_cmap,
182-
cmap=cmap,
183-
):
184-
_data = rgb_to_scalar(data)
185-
186-
plotter.subplot(*panel)
187-
188-
if title:
189-
plotter.add_text(title)
190-
191-
img_data = pv.ImageData(
192-
dimensions=_data.shape,
193-
)
194-
img_data.point_data["Data"] = _data.flatten()
195-
img_data = img_data.points_to_cells(scalars="Data")
196-
plotter.subplot(0, 0)
197-
plotter.add_mesh(
198-
img_data,
199-
show_edges=True,
200-
show_scalar_bar=show_cmap,
201-
cmap=cmap,
202-
)
203-
204-
return facet_plot3D(
205-
data={"concentrations": simulation_result.concentration_image},
206-
plotfuncs={"concentrations": plot_single},
207-
show_cmap=show_cmap,
208-
cmap=cmap,
209-
portrait=True,
210-
with_titles=True,
211-
plotter_kwargs=plotter_kwargs,
212-
)
213-
214-
215-
def plot_species_concentration_3D(
216-
simulation_result: sme.SimulationResult,
217-
species: list[str],
218-
thresholds: list[float] = [],
156+
def facet_grid_3D(
157+
data: dict[str, np.ndarray],
158+
plotfuncs: dict[str, Callable],
219159
show_cmap: bool = False,
220160
cmap: str | np.ndarray | pv.LookupTable = "viridis",
221161
portrait: bool = False,
222162
linked_views: bool = True,
223-
with_titles: bool = True,
224-
plotter_kwargs: dict[str, Any] = {"border": False, "notebook": False},
163+
plotter_kwargs: dict = {},
164+
plotfuncs_kwargs: dict[str, dict[str, Any]] = {},
225165
) -> pv.Plotter:
226-
"""Plot the concentration of a list of species in 3D using pyvista.
227-
This function creates a 3D plot of the concentration for each species in a separate subplot
228-
Args:
229-
simulation_result (sme.SimulationResult): Simulationresult object for a given timestep
230-
species (list[str]): list of species to plot
231-
thresholds (list[float], optional): Thresholds to exclude some values for each plot. If empty, it is set to 1e6 to effectively have no threshold. Defaults to [].
232-
show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False.
233-
cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis".
234-
portrait (bool, optional): Whether to organize plot grid (n,m) in potrait mode (smaller of (n,m) as columns) or landscape mode (smaller number as rows). Defaults to False.
235-
linked_views (bool, optional): If all the views should be linked togehter such that perspective changes affect all plots the same. Defaults to True.
236-
with_titles (bool, optional): Have a title for each plot or not. Defaults to True.
237-
plotter_kwargs (dict[str, Any], optional): Additional kwargs for pyvista.Plotter. Defaults to {"border": False, "notebook": False}.
166+
"""
167+
Create a 3D facet plot using PyVista. This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary.
168+
Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and
169+
may not work as expected.
170+
Parameters:
171+
-----------
172+
data : dict[str, np.ndarray]
173+
A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
174+
plotfuncs : dict[str, Callable]
175+
A dictionary where keys are labels and values are functions with signature f(
176+
label:str,
177+
data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid,
178+
plotter:pv.Plotter,
179+
panel:tuple[int, int],
180+
show_cmap:bool=show_cmap,
181+
cmap=cmap,
182+
**plotfuncs_kwargs
183+
) -> None
184+
show_cmap : bool, optional
185+
Whether to show the color map. Default is False.
186+
cmap : str | np.ndarray | pv.LookupTable, optional
187+
The color map to use. Default is "viridis".
188+
portrait : bool, optional
189+
Whether to use a portrait layout. Default is False.
190+
linked_views : bool, optional
191+
Whether to link the views of the subplots. Default is True.
192+
plotter_kwargs : dict, optional
193+
Additional keyword arguments to pass to the PyVista Plotter.
194+
plotfuncs_kwargs : dict[str, dict[str, Any]], optional
195+
Additional keyword arguments to pass to each plotting function.
238196
239197
Returns:
240-
pv.Plotter: pyvista plotter object
198+
--------
199+
pv.Plotter
200+
The PyVista Plotter object with the created facet plot.
241201
"""
202+
if data.keys() != plotfuncs.keys():
203+
raise ValueError(
204+
"The keys for the data and plotfuncs dictionaries must be the same."
205+
)
242206

243-
def plot_single(
244-
label: str,
245-
data: np.ndarray,
246-
plotter: pv.Plotter,
247-
panel: tuple[int, int],
248-
show_cmap,
249-
cmap=cmap,
250-
threshold_value=1e6,
251-
):
252-
_data = rgb_to_scalar(data) if len(data.shape) == 4 else data
207+
layout = find_layout(len(data), portrait=portrait)
253208

254-
plotter.subplot(*panel)
209+
plotter = pv.Plotter(shape=layout, **plotter_kwargs)
255210

256-
if with_titles:
257-
plotter.add_text(label)
211+
label = iter(plotfuncs.keys())
258212

259-
img_data = pv.ImageData(
260-
dimensions=_data.shape,
261-
)
262-
img_data.point_data["Data"] = _data.flatten()
263-
img_data = img_data.points_to_cells(scalars="Data")
264-
plotter.subplot(0, 0)
265-
plotter.add_mesh(
266-
img_data.threshold(threshold_value),
267-
show_edges=True,
268-
show_scalar_bar=show_cmap,
269-
cmap=cmap,
270-
)
213+
for i in range(layout[0]):
214+
for j in range(layout[1]):
215+
current_label = next(label)
216+
plotfuncs[current_label](
217+
current_label,
218+
data[current_label],
219+
plotter,
220+
panel=(i, j),
221+
show_cmap=show_cmap,
222+
cmap=cmap,
223+
**plotfuncs_kwargs.get(current_label, {}),
224+
)
271225

272-
return facet_plot3D(
273-
data={sp: simulation_result.species_concentration[sp] for sp in species},
274-
plotfuncs={sp: plot_single for sp in species},
275-
show_cmap=show_cmap,
276-
cmap=cmap,
277-
portrait=portrait,
278-
with_titles=with_titles,
279-
linked_views=linked_views,
280-
plotter_kwargs=plotter_kwargs,
281-
plotfuncs_kwargs={
282-
species[i]: {"threshold_value": thresholds[i]}
283-
for i in range(0, len(species))
284-
}
285-
if thresholds != []
286-
else {},
287-
)
226+
if linked_views:
227+
plotter.link_views()
228+
229+
return plotter
288230

289231

290-
def concentrations_animation_3D(
232+
def facet_grid_animate_3D(
291233
filename: str,
292-
simulation_results: sme.SimulationResultList,
293-
species: list[str],
294-
thresholds: list[float] = [],
234+
data: list[dict[str, np.ndarray]],
235+
plotfuncs: dict[str, Callable],
295236
show_cmap: bool = False,
296237
cmap: str | np.ndarray | pv.LookupTable = "viridis",
297238
portrait: bool = False,
298239
linked_views: bool = True,
299-
with_titles: bool = True,
300-
plotter_kwargs: dict = {"border": False, "notebook": False},
240+
titles: list[dict[str, str]] = [],
241+
plotter_kwargs: dict = {},
242+
plotfuncs_kwargs: dict[str, dict[str, Any]] = {},
301243
) -> str:
302-
"""Create an .mp4 video of the concentration of a list of species in 3D using pyvista, with one frame being one timestep for each species. In essence, this is an animated version of the plot_species_concentration_3D function.
303-
304-
Args:
305-
filename (str): filename to save the video
306-
simulation_results (sme.SimulationResultList): List of simulation results for each timestep
307-
species (list[str]): List of species to animate
308-
thresholds (list[float], optional): Thresholds to limit the plotted values for each species Values larger than the threshold will be cut. Defaults to [].
309-
show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False.
310-
cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis".
311-
portrait (bool, optional): Whether to organize plot grid (n,m) in potrait mode (smaller of (n,m) as columns) or landscape mode (smaller number as rows). Defaults to False.
312-
linked_views (bool, optional): If all the views should be linked togehter such that perspective changes affect all plots the same. Defaults to True.
313-
with_titles (bool, optional): Have a title for each plot or not. Defaults to True.
314-
plotter_kwargs (dict[str, Any], optional): Additional kwargs for pyvista.Plotter. Defaults to {"border": False, "notebook": False}.
315-
244+
"""
245+
Create a 3D animation from a series of data snapshots using PyVista.
246+
This series must be a list of dictionaries with the data for each frame keyed by a label used to title the panel it will be plotted into. The final plot will have as many subplots as there are labels in the data dictionaries. The keys for plotfuncs and data must be the same.
247+
Parameters:
248+
-----------
249+
filename : str
250+
The name of the output movie file.
251+
data : list[dict[str, np.ndarray]]
252+
A list of dictionaries containing the data for each timestep.
253+
plotfuncs : dict[str, Callable]
254+
A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same.
255+
show_cmap : bool, optional
256+
Whether to show the color map (default is False).
257+
cmap : str | np.ndarray | pv.LookupTable, optional
258+
The colormap to use (default is "viridis").
259+
portrait : bool, optional
260+
Whether to use portrait layout (default is False).
261+
linked_views : bool, optional
262+
Whether to link the views of the subplots (default is True).
263+
titles : list[dict[str, str]], optional
264+
A list of dictionaries containing titles for each subplot (default is an empty list).
265+
plotter_kwargs : dict, optional
266+
Additional keyword arguments to pass to the PyVista Plotter (default is an empty dictionary).
267+
plotfuncs_kwargs : dict[str, dict[str, Any]], optional
268+
Additional keyword arguments to pass to each plotting function (default is an empty dictionary).
316269
Returns:
317-
str: filename of the saved video
270+
--------
271+
str
272+
The filename of the created movie.
318273
"""
319274

320-
def plot_single(
321-
label: str,
322-
data: np.ndarray,
323-
plotter: pv.Plotter,
324-
panel: tuple[int, int],
325-
show_cmap,
326-
cmap=cmap,
327-
threshold_value=1e6,
328-
):
329-
_data = rgb_to_scalar(data) if len(data.shape) == 4 else data
275+
if len(titles) > 0 and len(titles) != len(data):
276+
raise ValueError(
277+
"The number of titles must be the same as the number of data dictionaries."
278+
)
330279

331-
img_data = pv.ImageData(
332-
dimensions=_data.shape,
280+
if data[0].keys() != plotfuncs.keys():
281+
raise ValueError(
282+
"The keys for the data and plotfuncs dictionaries must be the same."
333283
)
334-
img_data.point_data["Data"] = _data.flatten()
335-
img_data = img_data.points_to_cells(scalars="Data")
336-
337-
plotter.subplot(*panel)
338-
if with_titles:
339-
plotter.add_text(label, name=label + str(panel))
340-
341-
actor = plotter.add_mesh(
342-
img_data.threshold(threshold_value),
343-
show_edges=True,
344-
show_scalar_bar=show_cmap,
345-
cmap=cmap,
346-
name="mesh" + label + str(panel),
284+
285+
# main function, called for each frame in the movie
286+
def create_frame(
287+
data_dict: dict[str, np.ndarray], title: dict[str:str], layout=(1, 1)
288+
):
289+
label = iter(data_dict.keys())
290+
for i in range(layout[0]):
291+
for j in range(layout[1]):
292+
current_label = next(label)
293+
plotfuncs[current_label](
294+
title.get(current_label, current_label),
295+
data_dict[current_label],
296+
plotter,
297+
panel=(i, j),
298+
show_cmap=show_cmap,
299+
cmap=cmap,
300+
**plotfuncs_kwargs.get(current_label, {}),
301+
)
302+
303+
plotter.write_frame()
304+
305+
# preparations
306+
layout = find_layout(len(plotfuncs), portrait=portrait)
307+
308+
plotter = pv.Plotter(shape=layout, **plotter_kwargs)
309+
310+
plotter.open_movie(filename)
311+
312+
# add first frame here to set up the plotter
313+
create_frame(data[0], titles[0] if len(titles) > 0 else {}, layout)
314+
315+
if linked_views:
316+
plotter.link_views()
317+
318+
for i, single_timestep_data in enumerate(data[1::]):
319+
create_frame(
320+
single_timestep_data, titles[i] if len(titles) > 0 else {}, layout=layout
347321
)
348-
actor.mapper.scalar_range = (np.min(_data), np.max(_data))
349-
350-
return facet_animate3D(
351-
filename=filename,
352-
data=[s.species_concentration for s in simulation_results],
353-
titles=[
354-
{sp: f"Concentration of {sp} at t={s.time_point}" for sp in species}
355-
for s in simulation_results
356-
],
357-
plotfuncs={sp: plot_single for sp in species},
358-
show_cmap=show_cmap,
359-
cmap=cmap,
360-
portrait=portrait,
361-
linked_views=linked_views,
362-
with_titles=with_titles,
363-
plotter_kwargs=plotter_kwargs,
364-
plotfuncs_kwargs={
365-
species[i]: {"threshold_value": thresholds[i]}
366-
for i in range(0, len(species))
367-
}
368-
if len(thresholds) > 0
369-
else {},
370-
)
322+
323+
plotter.close()
324+
325+
return filename

0 commit comments

Comments
 (0)