|
7 | 7 | from itertools import cycle
|
8 | 8 | import matplotlib.colors as mcolors
|
9 | 9 | import pyvista as pv
|
10 |
| -from typing import Any |
| 10 | +from typing import Any, Callable |
11 | 11 | import sme
|
12 | 12 |
|
13 | 13 | from .pyvista_utils import (
|
14 |
| - facet_animate3D, |
15 |
| - facet_plot3D, |
16 |
| - rgb_to_scalar, |
| 14 | + find_layout, |
17 | 15 | )
|
18 | 16 |
|
19 | 17 |
|
@@ -155,216 +153,173 @@ def concentration_heatmap_animation(
|
155 | 153 | return anim
|
156 | 154 |
|
157 | 155 |
|
158 |
| -def plot_concentration_image_3D( |
159 |
| - simulation_result: sme.SimulationResult, |
160 |
| - show_cmap: bool = False, |
161 |
| - cmap: str | np.ndarray | pv.LookupTable = "viridis", |
162 |
| - plotter_kwargs: dict[str, Any] = {"border": False, "notebook": False}, |
163 |
| -) -> pv.Plotter: |
164 |
| - """Plot a single concentration image in 3D using pyvista. |
165 |
| -
|
166 |
| - Args: |
167 |
| - simulation_result (sme.SimulationResult): The simulation result object of a 3D simulation for a single timestep |
168 |
| - show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False. |
169 |
| - cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis". |
170 |
| - plotter_kwargs (dict[str, Any], optional): Addtitional kwargs for the pyvista.Plotter constructor Defaults to {"border": False, "notebook": False}. |
171 |
| -
|
172 |
| - Returns: |
173 |
| - pv.Plotter: pyvista Plotter object |
174 |
| - """ |
175 |
| - |
176 |
| - def plot_single( |
177 |
| - title, |
178 |
| - data, |
179 |
| - plotter, |
180 |
| - panel, |
181 |
| - show_cmap=show_cmap, |
182 |
| - cmap=cmap, |
183 |
| - ): |
184 |
| - _data = rgb_to_scalar(data) |
185 |
| - |
186 |
| - plotter.subplot(*panel) |
187 |
| - |
188 |
| - if title: |
189 |
| - plotter.add_text(title) |
190 |
| - |
191 |
| - img_data = pv.ImageData( |
192 |
| - dimensions=_data.shape, |
193 |
| - ) |
194 |
| - img_data.point_data["Data"] = _data.flatten() |
195 |
| - img_data = img_data.points_to_cells(scalars="Data") |
196 |
| - plotter.subplot(0, 0) |
197 |
| - plotter.add_mesh( |
198 |
| - img_data, |
199 |
| - show_edges=True, |
200 |
| - show_scalar_bar=show_cmap, |
201 |
| - cmap=cmap, |
202 |
| - ) |
203 |
| - |
204 |
| - return facet_plot3D( |
205 |
| - data={"concentrations": simulation_result.concentration_image}, |
206 |
| - plotfuncs={"concentrations": plot_single}, |
207 |
| - show_cmap=show_cmap, |
208 |
| - cmap=cmap, |
209 |
| - portrait=True, |
210 |
| - with_titles=True, |
211 |
| - plotter_kwargs=plotter_kwargs, |
212 |
| - ) |
213 |
| - |
214 |
| - |
215 |
| -def plot_species_concentration_3D( |
216 |
| - simulation_result: sme.SimulationResult, |
217 |
| - species: list[str], |
218 |
| - thresholds: list[float] = [], |
| 156 | +def facet_grid_3D( |
| 157 | + data: dict[str, np.ndarray], |
| 158 | + plotfuncs: dict[str, Callable], |
219 | 159 | show_cmap: bool = False,
|
220 | 160 | cmap: str | np.ndarray | pv.LookupTable = "viridis",
|
221 | 161 | portrait: bool = False,
|
222 | 162 | linked_views: bool = True,
|
223 |
| - with_titles: bool = True, |
224 |
| - plotter_kwargs: dict[str, Any] = {"border": False, "notebook": False}, |
| 163 | + plotter_kwargs: dict = {}, |
| 164 | + plotfuncs_kwargs: dict[str, dict[str, Any]] = {}, |
225 | 165 | ) -> pv.Plotter:
|
226 |
| - """Plot the concentration of a list of species in 3D using pyvista. |
227 |
| - This function creates a 3D plot of the concentration for each species in a separate subplot |
228 |
| - Args: |
229 |
| - simulation_result (sme.SimulationResult): Simulationresult object for a given timestep |
230 |
| - species (list[str]): list of species to plot |
231 |
| - thresholds (list[float], optional): Thresholds to exclude some values for each plot. If empty, it is set to 1e6 to effectively have no threshold. Defaults to []. |
232 |
| - show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False. |
233 |
| - cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis". |
234 |
| - portrait (bool, optional): Whether to organize plot grid (n,m) in potrait mode (smaller of (n,m) as columns) or landscape mode (smaller number as rows). Defaults to False. |
235 |
| - linked_views (bool, optional): If all the views should be linked togehter such that perspective changes affect all plots the same. Defaults to True. |
236 |
| - with_titles (bool, optional): Have a title for each plot or not. Defaults to True. |
237 |
| - plotter_kwargs (dict[str, Any], optional): Additional kwargs for pyvista.Plotter. Defaults to {"border": False, "notebook": False}. |
| 166 | + """ |
| 167 | + Create a 3D facet plot using PyVista. This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary. |
| 168 | + Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and |
| 169 | + may not work as expected. |
| 170 | + Parameters: |
| 171 | + ----------- |
| 172 | + data : dict[str, np.ndarray] |
| 173 | + A dictionary where keys are labels and values are numpy arrays containing the data to be plotted. |
| 174 | + plotfuncs : dict[str, Callable] |
| 175 | + A dictionary where keys are labels and values are functions with signature f( |
| 176 | + label:str, |
| 177 | + data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, |
| 178 | + plotter:pv.Plotter, |
| 179 | + panel:tuple[int, int], |
| 180 | + show_cmap:bool=show_cmap, |
| 181 | + cmap=cmap, |
| 182 | + **plotfuncs_kwargs |
| 183 | + ) -> None |
| 184 | + show_cmap : bool, optional |
| 185 | + Whether to show the color map. Default is False. |
| 186 | + cmap : str | np.ndarray | pv.LookupTable, optional |
| 187 | + The color map to use. Default is "viridis". |
| 188 | + portrait : bool, optional |
| 189 | + Whether to use a portrait layout. Default is False. |
| 190 | + linked_views : bool, optional |
| 191 | + Whether to link the views of the subplots. Default is True. |
| 192 | + plotter_kwargs : dict, optional |
| 193 | + Additional keyword arguments to pass to the PyVista Plotter. |
| 194 | + plotfuncs_kwargs : dict[str, dict[str, Any]], optional |
| 195 | + Additional keyword arguments to pass to each plotting function. |
238 | 196 |
|
239 | 197 | Returns:
|
240 |
| - pv.Plotter: pyvista plotter object |
| 198 | + -------- |
| 199 | + pv.Plotter |
| 200 | + The PyVista Plotter object with the created facet plot. |
241 | 201 | """
|
| 202 | + if data.keys() != plotfuncs.keys(): |
| 203 | + raise ValueError( |
| 204 | + "The keys for the data and plotfuncs dictionaries must be the same." |
| 205 | + ) |
242 | 206 |
|
243 |
| - def plot_single( |
244 |
| - label: str, |
245 |
| - data: np.ndarray, |
246 |
| - plotter: pv.Plotter, |
247 |
| - panel: tuple[int, int], |
248 |
| - show_cmap, |
249 |
| - cmap=cmap, |
250 |
| - threshold_value=1e6, |
251 |
| - ): |
252 |
| - _data = rgb_to_scalar(data) if len(data.shape) == 4 else data |
| 207 | + layout = find_layout(len(data), portrait=portrait) |
253 | 208 |
|
254 |
| - plotter.subplot(*panel) |
| 209 | + plotter = pv.Plotter(shape=layout, **plotter_kwargs) |
255 | 210 |
|
256 |
| - if with_titles: |
257 |
| - plotter.add_text(label) |
| 211 | + label = iter(plotfuncs.keys()) |
258 | 212 |
|
259 |
| - img_data = pv.ImageData( |
260 |
| - dimensions=_data.shape, |
261 |
| - ) |
262 |
| - img_data.point_data["Data"] = _data.flatten() |
263 |
| - img_data = img_data.points_to_cells(scalars="Data") |
264 |
| - plotter.subplot(0, 0) |
265 |
| - plotter.add_mesh( |
266 |
| - img_data.threshold(threshold_value), |
267 |
| - show_edges=True, |
268 |
| - show_scalar_bar=show_cmap, |
269 |
| - cmap=cmap, |
270 |
| - ) |
| 213 | + for i in range(layout[0]): |
| 214 | + for j in range(layout[1]): |
| 215 | + current_label = next(label) |
| 216 | + plotfuncs[current_label]( |
| 217 | + current_label, |
| 218 | + data[current_label], |
| 219 | + plotter, |
| 220 | + panel=(i, j), |
| 221 | + show_cmap=show_cmap, |
| 222 | + cmap=cmap, |
| 223 | + **plotfuncs_kwargs.get(current_label, {}), |
| 224 | + ) |
271 | 225 |
|
272 |
| - return facet_plot3D( |
273 |
| - data={sp: simulation_result.species_concentration[sp] for sp in species}, |
274 |
| - plotfuncs={sp: plot_single for sp in species}, |
275 |
| - show_cmap=show_cmap, |
276 |
| - cmap=cmap, |
277 |
| - portrait=portrait, |
278 |
| - with_titles=with_titles, |
279 |
| - linked_views=linked_views, |
280 |
| - plotter_kwargs=plotter_kwargs, |
281 |
| - plotfuncs_kwargs={ |
282 |
| - species[i]: {"threshold_value": thresholds[i]} |
283 |
| - for i in range(0, len(species)) |
284 |
| - } |
285 |
| - if thresholds != [] |
286 |
| - else {}, |
287 |
| - ) |
| 226 | + if linked_views: |
| 227 | + plotter.link_views() |
| 228 | + |
| 229 | + return plotter |
288 | 230 |
|
289 | 231 |
|
290 |
| -def concentrations_animation_3D( |
| 232 | +def facet_grid_animate_3D( |
291 | 233 | filename: str,
|
292 |
| - simulation_results: sme.SimulationResultList, |
293 |
| - species: list[str], |
294 |
| - thresholds: list[float] = [], |
| 234 | + data: list[dict[str, np.ndarray]], |
| 235 | + plotfuncs: dict[str, Callable], |
295 | 236 | show_cmap: bool = False,
|
296 | 237 | cmap: str | np.ndarray | pv.LookupTable = "viridis",
|
297 | 238 | portrait: bool = False,
|
298 | 239 | linked_views: bool = True,
|
299 |
| - with_titles: bool = True, |
300 |
| - plotter_kwargs: dict = {"border": False, "notebook": False}, |
| 240 | + titles: list[dict[str, str]] = [], |
| 241 | + plotter_kwargs: dict = {}, |
| 242 | + plotfuncs_kwargs: dict[str, dict[str, Any]] = {}, |
301 | 243 | ) -> str:
|
302 |
| - """Create an .mp4 video of the concentration of a list of species in 3D using pyvista, with one frame being one timestep for each species. In essence, this is an animated version of the plot_species_concentration_3D function. |
303 |
| -
|
304 |
| - Args: |
305 |
| - filename (str): filename to save the video |
306 |
| - simulation_results (sme.SimulationResultList): List of simulation results for each timestep |
307 |
| - species (list[str]): List of species to animate |
308 |
| - thresholds (list[float], optional): Thresholds to limit the plotted values for each species Values larger than the threshold will be cut. Defaults to []. |
309 |
| - show_cmap (bool, optional): Whether the colormap should be shown or not. Defaults to False. |
310 |
| - cmap (str | np.ndarray | pv.LookupTable, optional): colormap to use. Either a string naming a matplotlib colormap, a numpy array of values or a pyvista lookup table mapping values ot rgba colors. Defaults to "viridis". |
311 |
| - portrait (bool, optional): Whether to organize plot grid (n,m) in potrait mode (smaller of (n,m) as columns) or landscape mode (smaller number as rows). Defaults to False. |
312 |
| - linked_views (bool, optional): If all the views should be linked togehter such that perspective changes affect all plots the same. Defaults to True. |
313 |
| - with_titles (bool, optional): Have a title for each plot or not. Defaults to True. |
314 |
| - plotter_kwargs (dict[str, Any], optional): Additional kwargs for pyvista.Plotter. Defaults to {"border": False, "notebook": False}. |
315 |
| -
|
| 244 | + """ |
| 245 | + Create a 3D animation from a series of data snapshots using PyVista. |
| 246 | + This series must be a list of dictionaries with the data for each frame keyed by a label used to title the panel it will be plotted into. The final plot will have as many subplots as there are labels in the data dictionaries. The keys for plotfuncs and data must be the same. |
| 247 | + Parameters: |
| 248 | + ----------- |
| 249 | + filename : str |
| 250 | + The name of the output movie file. |
| 251 | + data : list[dict[str, np.ndarray]] |
| 252 | + A list of dictionaries containing the data for each timestep. |
| 253 | + plotfuncs : dict[str, Callable] |
| 254 | + A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same. |
| 255 | + show_cmap : bool, optional |
| 256 | + Whether to show the color map (default is False). |
| 257 | + cmap : str | np.ndarray | pv.LookupTable, optional |
| 258 | + The colormap to use (default is "viridis"). |
| 259 | + portrait : bool, optional |
| 260 | + Whether to use portrait layout (default is False). |
| 261 | + linked_views : bool, optional |
| 262 | + Whether to link the views of the subplots (default is True). |
| 263 | + titles : list[dict[str, str]], optional |
| 264 | + A list of dictionaries containing titles for each subplot (default is an empty list). |
| 265 | + plotter_kwargs : dict, optional |
| 266 | + Additional keyword arguments to pass to the PyVista Plotter (default is an empty dictionary). |
| 267 | + plotfuncs_kwargs : dict[str, dict[str, Any]], optional |
| 268 | + Additional keyword arguments to pass to each plotting function (default is an empty dictionary). |
316 | 269 | Returns:
|
317 |
| - str: filename of the saved video |
| 270 | + -------- |
| 271 | + str |
| 272 | + The filename of the created movie. |
318 | 273 | """
|
319 | 274 |
|
320 |
| - def plot_single( |
321 |
| - label: str, |
322 |
| - data: np.ndarray, |
323 |
| - plotter: pv.Plotter, |
324 |
| - panel: tuple[int, int], |
325 |
| - show_cmap, |
326 |
| - cmap=cmap, |
327 |
| - threshold_value=1e6, |
328 |
| - ): |
329 |
| - _data = rgb_to_scalar(data) if len(data.shape) == 4 else data |
| 275 | + if len(titles) > 0 and len(titles) != len(data): |
| 276 | + raise ValueError( |
| 277 | + "The number of titles must be the same as the number of data dictionaries." |
| 278 | + ) |
330 | 279 |
|
331 |
| - img_data = pv.ImageData( |
332 |
| - dimensions=_data.shape, |
| 280 | + if data[0].keys() != plotfuncs.keys(): |
| 281 | + raise ValueError( |
| 282 | + "The keys for the data and plotfuncs dictionaries must be the same." |
333 | 283 | )
|
334 |
| - img_data.point_data["Data"] = _data.flatten() |
335 |
| - img_data = img_data.points_to_cells(scalars="Data") |
336 |
| - |
337 |
| - plotter.subplot(*panel) |
338 |
| - if with_titles: |
339 |
| - plotter.add_text(label, name=label + str(panel)) |
340 |
| - |
341 |
| - actor = plotter.add_mesh( |
342 |
| - img_data.threshold(threshold_value), |
343 |
| - show_edges=True, |
344 |
| - show_scalar_bar=show_cmap, |
345 |
| - cmap=cmap, |
346 |
| - name="mesh" + label + str(panel), |
| 284 | + |
| 285 | + # main function, called for each frame in the movie |
| 286 | + def create_frame( |
| 287 | + data_dict: dict[str, np.ndarray], title: dict[str:str], layout=(1, 1) |
| 288 | + ): |
| 289 | + label = iter(data_dict.keys()) |
| 290 | + for i in range(layout[0]): |
| 291 | + for j in range(layout[1]): |
| 292 | + current_label = next(label) |
| 293 | + plotfuncs[current_label]( |
| 294 | + title.get(current_label, current_label), |
| 295 | + data_dict[current_label], |
| 296 | + plotter, |
| 297 | + panel=(i, j), |
| 298 | + show_cmap=show_cmap, |
| 299 | + cmap=cmap, |
| 300 | + **plotfuncs_kwargs.get(current_label, {}), |
| 301 | + ) |
| 302 | + |
| 303 | + plotter.write_frame() |
| 304 | + |
| 305 | + # preparations |
| 306 | + layout = find_layout(len(plotfuncs), portrait=portrait) |
| 307 | + |
| 308 | + plotter = pv.Plotter(shape=layout, **plotter_kwargs) |
| 309 | + |
| 310 | + plotter.open_movie(filename) |
| 311 | + |
| 312 | + # add first frame here to set up the plotter |
| 313 | + create_frame(data[0], titles[0] if len(titles) > 0 else {}, layout) |
| 314 | + |
| 315 | + if linked_views: |
| 316 | + plotter.link_views() |
| 317 | + |
| 318 | + for i, single_timestep_data in enumerate(data[1::]): |
| 319 | + create_frame( |
| 320 | + single_timestep_data, titles[i] if len(titles) > 0 else {}, layout=layout |
347 | 321 | )
|
348 |
| - actor.mapper.scalar_range = (np.min(_data), np.max(_data)) |
349 |
| - |
350 |
| - return facet_animate3D( |
351 |
| - filename=filename, |
352 |
| - data=[s.species_concentration for s in simulation_results], |
353 |
| - titles=[ |
354 |
| - {sp: f"Concentration of {sp} at t={s.time_point}" for sp in species} |
355 |
| - for s in simulation_results |
356 |
| - ], |
357 |
| - plotfuncs={sp: plot_single for sp in species}, |
358 |
| - show_cmap=show_cmap, |
359 |
| - cmap=cmap, |
360 |
| - portrait=portrait, |
361 |
| - linked_views=linked_views, |
362 |
| - with_titles=with_titles, |
363 |
| - plotter_kwargs=plotter_kwargs, |
364 |
| - plotfuncs_kwargs={ |
365 |
| - species[i]: {"threshold_value": thresholds[i]} |
366 |
| - for i in range(0, len(species)) |
367 |
| - } |
368 |
| - if len(thresholds) > 0 |
369 |
| - else {}, |
370 |
| - ) |
| 322 | + |
| 323 | + plotter.close() |
| 324 | + |
| 325 | + return filename |
0 commit comments