Skip to content

Commit 4d92d5c

Browse files
committed
fix import issue
1 parent 76a8f85 commit 4d92d5c

File tree

2 files changed

+70
-41
lines changed

2 files changed

+70
-41
lines changed

docs/notebooks/plot.ipynb

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,19 @@
1515
},
1616
"outputs": [],
1717
"source": [
18-
"!pip install -q sme_contrib\n",
19-
"import sme\n",
20-
"import sme_contrib.plot as smeplot\n",
21-
"from matplotlib import pyplot as plt\n",
18+
"# !pip install -q sme_contrib\n",
2219
"import pyvista as pv\n",
2320
"from pyvista import examples\n",
21+
"\n",
22+
"pv.set_jupyter_backend(\"static\")\n",
23+
"\n",
24+
"import sme\n",
25+
"import sme_contrib.plot as smeplot\n",
2426
"import numpy as np\n",
2527
"import tempfile\n",
2628
"from IPython.display import HTML\n",
27-
"from IPython.display import Video"
28-
]
29-
},
30-
{
31-
"cell_type": "code",
32-
"execution_count": null,
33-
"metadata": {},
34-
"outputs": [],
35-
"source": [
36-
"pv.set_jupyter_backend(\"static\")"
29+
"from IPython.display import Video\n",
30+
"from matplotlib import pyplot as plt"
3731
]
3832
},
3933
{
@@ -217,6 +211,15 @@
217211
"- `concentrations3D` and `concentrationsAnimate3D`: These high-level API directly uses `sme.SimulationResult` objects as data input, but only plots concentrations by default. These are wrappers around the low-level functions that provide default plotting functions for each pane and handle the data preparation for each pane automatically."
218212
]
219213
},
214+
{
215+
"cell_type": "markdown",
216+
"metadata": {},
217+
"source": [
218+
"**README:** Using `import pyvista as pv` in both a notebook and a module imported in the notebook will cause a premature initialization of `pyvista` when the imported module is initialized and in turn a configuration conflict that will prevent the notebook from working properly. You thus have to make sure that the `pyvista` module is only imported in one of the two. Since notebooks often have specific requirements for how the plots can be shown (statically or interactively) it is best to only import the needed objects from pyvista when developing a module and only importing the package-level module (`import pyvista...`) within notebooks that use this module. \n",
219+
"\n",
220+
"If a complete import of pyvista in a module is unavoidable, make sure to only import the called functions in the notebook and not the entire module, to prevent pyvista's `__init__` from running prematurely."
221+
]
222+
},
220223
{
221224
"cell_type": "code",
222225
"execution_count": null,
@@ -227,8 +230,7 @@
227230
"model = sme.open_sbml_file(model_file)\n",
228231
"results = model.simulate(500, 10)\n",
229232
"\n",
230-
"species = list(results[0].species_concentration.keys())\n",
231-
"species"
233+
"species = list(results[0].species_concentration.keys())"
232234
]
233235
},
234236
{
@@ -249,7 +251,15 @@
249251
" }\n",
250252
"\n",
251253
"\n",
252-
"datasets = exampledata()\n",
254+
"datasets = exampledata()"
255+
]
256+
},
257+
{
258+
"cell_type": "code",
259+
"execution_count": null,
260+
"metadata": {},
261+
"outputs": [],
262+
"source": [
253263
"tempdir = tempfile.TemporaryDirectory()"
254264
]
255265
},
@@ -322,7 +332,9 @@
322332
" \"brain\": plot_brain,\n",
323333
" },\n",
324334
" linked_views=False,\n",
325-
")"
335+
")\n",
336+
"\n",
337+
"facetgrid.show()"
326338
]
327339
},
328340
{
@@ -393,6 +405,15 @@
393405
")"
394406
]
395407
},
408+
{
409+
"cell_type": "code",
410+
"execution_count": null,
411+
"metadata": {},
412+
"outputs": [],
413+
"source": [
414+
"Video(vidpath, embed=True, width=800, height=600)"
415+
]
416+
},
396417
{
397418
"cell_type": "markdown",
398419
"metadata": {},
@@ -413,12 +434,12 @@
413434
"metadata": {},
414435
"outputs": [],
415436
"source": [
416-
"facetgrid = smeplot.concentrations3D(\n",
437+
"smeplot.concentrations3D(\n",
417438
" simulation_result=results[10],\n",
418439
" species=[\"A_nucl\"],\n",
419440
" cmap=\"tab10\",\n",
420441
" show_cmap=True,\n",
421-
")"
442+
").show()"
422443
]
423444
},
424445
{
@@ -442,6 +463,15 @@
442463
")"
443464
]
444465
},
466+
{
467+
"cell_type": "code",
468+
"execution_count": null,
469+
"metadata": {},
470+
"outputs": [],
471+
"source": [
472+
"Video(vidpath, embed=True, width=800, height=600)"
473+
]
474+
},
445475
{
446476
"cell_type": "code",
447477
"execution_count": null,
@@ -457,7 +487,7 @@
457487
"provenance": []
458488
},
459489
"kernelspec": {
460-
"display_name": ".venv39",
490+
"display_name": ".venv313",
461491
"language": "python",
462492
"name": "python3"
463493
},
@@ -471,7 +501,7 @@
471501
"name": "python",
472502
"nbconvert_exporter": "python",
473503
"pygments_lexer": "ipython3",
474-
"version": "3.9.21"
504+
"version": "3.13.2"
475505
}
476506
},
477507
"nbformat": 4,

src/sme_contrib/plot.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
from matplotlib import pyplot as plt
55
from matplotlib.colors import LinearSegmentedColormap as lscmap
66
from matplotlib import animation
7-
import pyvista as pv
7+
from pyvista import Plotter, LookupTable
88
from typing import Any, Callable, Union
99
import sme
1010
from pathlib import Path
1111

1212
from .pyvista_utils import (
1313
find_layout,
14-
make_discrete_colormap,
1514
)
1615

1716

@@ -135,29 +134,29 @@ def facet_grid_3D(
135134
data: dict[str, np.ndarray],
136135
plotfuncs: dict[str, Callable],
137136
show_cmap: bool = False,
138-
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
137+
cmap: Union[str, np.ndarray, LookupTable] = "viridis",
139138
portrait: bool = False,
140139
linked_views: bool = True,
141140
plotter_kwargs: Union[dict, None] = None,
142141
plotfuncs_kwargs: Union[dict[str, dict[str, Any]], None] = None,
143-
) -> pv.Plotter:
142+
) -> Plotter:
144143
"""
145144
Create a 3D facet plot using PyVista.
146145
147146
This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary. Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and may not work as expected.
148147
149148
Args:
150149
data : (dict[str, np.ndarray]) A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
151-
plotfuncs : (dict[str, Callable]) A dictionary where keys are labels and values are functions with signature ``f(label:str, data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, plotter:pv.Plotter, panel:tuple[int, int], show_cmap:bool=show_cmap, cmap=cmap, **plotfuncs_kwargs )`` -> None
150+
plotfuncs : (dict[str, Callable]) A dictionary where keys are labels and values are functions with signature ``f(label:str, data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid, plotter:Plotter, panel:tuple[int, int], show_cmap:bool=show_cmap, cmap=cmap, **plotfuncs_kwargs )`` -> None
152151
show_cmap : bool, optional Whether to show the color map. Default is False.
153-
cmap : (str | np.ndarray | pv.LookupTable), optional The color map to use. Default is "viridis".
152+
cmap : (str | np.ndarray | LookupTable), optional The color map to use. Default is "viridis".
154153
portrait : (bool), optional Whether to use a portrait layout. Default is False.
155154
linked_views : (bool), optional Whether to link the views of the subplots. Default is True.
156155
plotter_kwargs : (dict, optional) Additional keyword arguments to pass to the PyVista Plotter.
157156
plotfuncs_kwargs : (dict[str, dict[str, Any]]), optional Additional keyword arguments to pass to each plotting function.
158157
159158
Returns:
160-
pv.Plotter The PyVista Plotter object with the created facet plot.
159+
Plotter The PyVista Plotter object with the created facet plot.
161160
"""
162161
if data.keys() != plotfuncs.keys():
163162
raise ValueError(
@@ -166,7 +165,7 @@ def facet_grid_3D(
166165

167166
layout = find_layout(len(data), portrait=portrait)
168167

169-
plotter = pv.Plotter(
168+
plotter = Plotter(
170169
shape=layout, **(plotter_kwargs if plotter_kwargs is not None else {})
171170
)
172171

@@ -200,7 +199,7 @@ def facet_grid_animate_3D(
200199
data: list[dict[str, np.ndarray]],
201200
plotfuncs: dict[str, Callable],
202201
show_cmap: bool = False,
203-
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
202+
cmap: Union[str, np.ndarray, LookupTable] = "viridis",
204203
portrait: bool = False,
205204
linked_views: bool = True,
206205
titles: Union[list[dict[str, str]], None] = None,
@@ -217,7 +216,7 @@ def facet_grid_animate_3D(
217216
data : (list[dict[str, np.ndarray]]) A list of dictionaries containing the data for each timestep.
218217
plotfuncs : (dict[str, Callable]) A dictionary of plotting functions keyed by data label. The keys for plotfuncs and data must be the same.
219218
show_cmap : (bool), optional Whether to show the color map (default is False).
220-
cmap : (str | np.ndarray | pv.LookupTable, optional) The colormap to use (default is "viridis").
219+
cmap : (str | np.ndarray | LookupTable, optional) The colormap to use (default is "viridis").
221220
portrait : (bool), optional Whether to use portrait layout (default is False).
222221
linked_views : (bool), optional Whether to link the views of the subplots (default is True).
223222
titles : (list[dict[str, str]]), optional A list of dictionaries containing titles for each subplot (default is an empty list).
@@ -265,7 +264,7 @@ def create_frame(
265264
# preparations
266265
layout = find_layout(len(plotfuncs), portrait=portrait)
267266

268-
plotter = pv.Plotter(
267+
plotter = Plotter(
269268
shape=layout, **plotter_kwargs if plotter_kwargs is not None else {}
270269
)
271270

@@ -290,19 +289,19 @@ def create_frame(
290289
def concentrations3D(
291290
simulation_result: sme.SimulationResult,
292291
species: list[str],
293-
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
292+
cmap: Union[str, np.ndarray, LookupTable] = "viridis",
294293
show_cmap: bool = False,
295294
plotter_kwargs: Union[None, dict[str, Any]] = None,
296295
plotfunc_kwargs: Union[None, dict[str, Any]] = None,
297-
) -> pv.Plotter:
296+
) -> Plotter:
298297
"""Plot a 3D facet grid of species concentrations.
299298
This function creates a 3D facet grid of species concentrations. Each panel will be a 3D plot of the concentration of a single species.
300299
This function is a wrapper around the facet_grid_3D function.
301300
302301
Args:
303302
simulation_result (sme.SimulationResult): a single simulation result object, i.e., a single recorded frame of the simulations
304303
species (list[str]): A list of species strings
305-
cmap (str | np.ndarray | pv.LookupTable, optional): Name of a matplotlib colorbar. Defaults to "viridis".
304+
cmap (str | np.ndarray | LookupTable, optional): Name of a matplotlib colorbar. Defaults to "viridis".
306305
show_cmap (bool, optional): Whether or not to show the colorbar on the plot. Defaults to False.
307306
plotter_kwargs (dict[str, Any], optional): Additional keyword arguments for the used pyVista.Plotter. Defaults to None.
308307
plotfunc_kwargs (dict[str, Any], optional): Additional keyword arguments passed to plotter.add_mesh. Defaults to None.
@@ -311,7 +310,7 @@ def concentrations3D(
311310
ValueError: if a given species is not found in the simulation result
312311
313312
Returns:
314-
pv.Plotter: pyvista.Plotter object the data has been plotted into
313+
Plotter: pyvista.Plotter object the data has been plotted into
315314
"""
316315
# turn the simulation result into numpy ndarray
317316
datadict = {}
@@ -328,10 +327,10 @@ def concentrations3D(
328327
def plotfunc(
329328
label: str,
330329
data: np.ndarray,
331-
plotter: pv.Plotter,
330+
plotter: Plotter,
332331
panel: tuple[int, int],
333332
show_cmap: bool,
334-
cmap: Union[str, np.ndarray, pv.LookupTable],
333+
cmap: Union[str, np.ndarray, LookupTable],
335334
**kwargs: dict[str, Any],
336335
):
337336
# create a pyvista grid
@@ -369,7 +368,7 @@ def concentrationsAnimate3D(
369368
simulation_results: sme.SimulationResultList,
370369
species: list[str],
371370
show_cmap: bool = False,
372-
cmap: Union[str, np.ndarray, pv.LookupTable] = "viridis",
371+
cmap: Union[str, np.ndarray, LookupTable] = "viridis",
373372
portrait: bool = False,
374373
titles: Union[list[dict[str, str]], None] = None,
375374
linked_views: bool = True,
@@ -386,7 +385,7 @@ def concentrationsAnimate3D(
386385
simulation_results (sme.SimulationResultList): a list of `SimulationResult` objects, i.e., a list of recorded frames of the simulations
387386
species (list[str]): list of species to plot
388387
show_cmap (bool, optional): Whether to show the colorbar on theplots or not. Defaults to False.
389-
cmap (Union[str, np.ndarray, pv.LookupTable], optional): name of matplotlib colormap or custom colormap that maps scalar values to rbp. Defaults to "viridis".
388+
cmap (Union[str, np.ndarray, LookupTable], optional): name of matplotlib colormap or custom colormap that maps scalar values to rbp. Defaults to "viridis".
390389
portrait (bool, optional): Whether to use the smaller or larger number of plots as rows. Defaults to False.
391390
titles (Union[list[dict[str, str]], None], optional): Titles of the different plots if not just the species name is desired. Defaults to None.
392391
linked_views (bool, optional): link the view cameras. Defaults to True.

0 commit comments

Comments
 (0)