Skip to content

Commit 199a2eb

Browse files
committed
add comments and data
1 parent 946c456 commit 199a2eb

File tree

3 files changed

+214
-3
lines changed

3 files changed

+214
-3
lines changed

pyproject.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ classifiers = [
1919
"Operating System :: Microsoft :: Windows",
2020
"Operating System :: POSIX :: Linux",
2121
"Programming Language :: Python :: 3 :: Only",
22-
"Programming Language :: Python :: 3.7",
23-
"Programming Language :: Python :: 3.8",
2422
"Programming Language :: Python :: 3.9",
2523
"Programming Language :: Python :: 3.10",
2624
"Programming Language :: Python :: 3.11",
25+
"Programming Language :: Python :: 3.13",
26+
"Programming Language :: Python :: 3.14",
27+
2728
]
28-
dependencies = ["matplotlib", "numpy", "pillow", "pyswarms", "sme>=1.4.0"]
29+
dependencies = ["matplotlib", "numpy", "pillow", "pyswarms", "sme>=1.4.0","pyvista[all]"]
30+
2931
dynamic = ["version"]
3032

3133
[project.urls]

src/sme_contrib/plot.py

+26
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from matplotlib import pyplot as plt
55
from matplotlib.colors import LinearSegmentedColormap as lscmap
66
from matplotlib import animation
7+
from itertools import cycle
8+
import matplotlib.colors as mcolors
9+
10+
from .pyvista_utils import _plot3D, _animate3D
711

812

913
def colormap(color, name="my colormap"):
@@ -28,6 +32,28 @@ def colormap(color, name="my colormap"):
2832
return lscmap.from_list(name, [(0, 0, 0), color], 256)
2933

3034

35+
def make_circular_colormap(
36+
cmap: str = "tab10", values: np.ndarray = np.array([])
37+
) -> list[tuple]:
38+
"""Create a discrete colormap of potentially repeating colors of the same size as the `values` array.
39+
40+
Args:
41+
cmap (str, optional): matplotlib colormap name. Defaults to "tab10".
42+
values (np.array, optional): values to be mapped to colors. Defaults to [].
43+
44+
Returns:
45+
list[tuple]: list of color in rgba format.
46+
"""
47+
cm = [(0.0, 0.0, 0.0, 1.0)]
48+
i = 0
49+
for c in cycle(plt.get_cmap(cmap).colors):
50+
cm.append(mcolors.to_rgba(c))
51+
if len(cm) >= len(values):
52+
break
53+
i += 1
54+
return cm
55+
56+
3157
def concentration_heatmap(
3258
simulation_result, species, z_slice: int = 0, title=None, ax=None, cmap=None
3359
):

src/sme_contrib/pyvista_utils.py

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import pyvista as pv
2+
import numpy as np
3+
from typing import Callable
4+
5+
6+
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
7+
"""Convert an array of RGB values to scalar values.
8+
This function is necessary because pyvista does not support RGB values directly as mesh data
9+
10+
Args:
11+
img (np.ndarray): data to be converted, of shape (n, m, 3)
12+
13+
Returns:
14+
np.ndarray: data converted to scalar values, of shape (n, m)
15+
"""
16+
reshaped = img.reshape(-1, 3, copy=True)
17+
unique_rgb, ridx = np.unique(reshaped, axis=0, return_inverse=True)
18+
19+
values = np.arange(len(unique_rgb))
20+
return values[ridx].reshape(img.shape[:-1])
21+
22+
23+
def _find_good_shape(num_plots: int, portrait: bool = False) -> tuple[int, int]:
24+
"""Find a good shape (rows, columns) for a grid of plots which should be such that
25+
rows*columns >= num_plots and rows is as close to columns as possible and rows*columns is minimal.
26+
There are sophisticated ways to do this, which are way beyond what is needed here, so a simple heuristic based on sqrt(num_plots) is used.
27+
Args:
28+
num_plots (int): number of plots to distribute
29+
portrait (bool, optional): whether the plots should be in portrait mode. If yes, the rows will become the larger number and cols the smaller, otherwise, it will be the other way round. Defaults to False
30+
31+
Returns:
32+
tuple[int, int]: shape of the grid (rows, columns)
33+
"""
34+
root = np.sqrt(num_plots)
35+
root_int = np.rint(root)
36+
37+
a = int(np.floor(root))
38+
b = int(np.ceil(root))
39+
40+
a_1 = int(a - 1)
41+
b_1 = int(b + 1)
42+
guesses = [
43+
(x, y)
44+
for x, y in [
45+
(a, b),
46+
(a_1, b_1),
47+
(a, b_1),
48+
(a_1, b),
49+
]
50+
if x * y >= num_plots
51+
]
52+
best_guess = guesses[np.argmin([x * y for x, y in guesses])]
53+
54+
if np.isclose(root, root_int):
55+
return int(root_int), int(root_int)
56+
elif best_guess[0] * best_guess[1] >= num_plots:
57+
return (
58+
(np.min(best_guess), np.max(best_guess))
59+
if not portrait
60+
else (np.max(best_guess), np.min(best_guess))
61+
)
62+
else:
63+
return b, b
64+
65+
66+
def _animate_with_func3D_facet(
67+
filename: str,
68+
data: dict[str, list[np.ndarray]],
69+
titles: dict[str, list[str]],
70+
plotfuncs: dict[str, Callable],
71+
show_cmap: bool = False,
72+
cmap: str | np.ndarray = "viridis",
73+
portrait: bool = False,
74+
linked_views: bool = True,
75+
with_titles: bool = True,
76+
plotter_kwargs: dict = {},
77+
) -> str:
78+
"""Animate a set of data with a corresponding set of plot functions.
79+
The The list of data for each plot must be of equal length. Each element
80+
of these lists will be one frame in the animation.
81+
The result is stored as an .mp4 file.
82+
Args:
83+
filename (str): A filename or path to store the animation. This function automatically adds an .mp4 extension to the given filename.
84+
data (dict[str, list[np.ndarray]]): The data to animate in each suplot. The keys are the labels for the plots and the values are lists of data to plot. Each element of these lists will be one frame in the animation. The lists must be of equal length.
85+
titles (dict[str, list[str]]): The titles for each subplot. The keys are the labels for the plots and the values are lists of titles to display. Each element of these lists will be the title for the corresponding frame in the animation. Useful to have timestep labels for instance.
86+
plotfuncs (dict[str, Callable]): Function to write the data into a frame in the plotter. Signature: func(data, plotter, show_cmap:bool = show_cmap, cmap=cmap)
87+
show_cmap (bool, optional): whether to show the colormap in each plot or not. Defaults to False.
88+
cmap (str | np.ndarray, optional): matplotlib colormap name. Defaults to "viridis".
89+
portrait (bool, optional): Aspect ratio mode. if this is true, the larger dimension of the plot table will be the rows. Otherwise, the larger dimension will be the columns. Defaults to False.
90+
linked_views (bool, optional): Whether to link all the views together for interactive plotting. If true, they will move in unison if one is moved. Defaults to True.
91+
with_titles (bool, optional): If the labels should be used as titles. Defaults to True.
92+
plotter_kwargs (dict, optional): Other keyword arguments passed to the plotter constructor. Defaults to {}.
93+
94+
Returns:
95+
str: path to where the given animation is stored.
96+
"""
97+
98+
def create_frame(label):
99+
for i in range(layout[0]):
100+
for j in range(layout[1]):
101+
plotter.subplot(i, j)
102+
103+
current_label = next(label)
104+
105+
if with_titles:
106+
plotter.add_text(titles[current_label][0])
107+
108+
plotfuncs[current_label](
109+
data[current_label][0], plotter, show_cmap=show_cmap, cmap=cmap
110+
)
111+
112+
plotter.write_frame()
113+
114+
layout = _find_good_shape(len(data), portrait=portrait)
115+
116+
plotter = pv.Plotter(shape=layout, **plotter_kwargs)
117+
118+
plotter.open_movie(filename)
119+
120+
label = iter(data.keys())
121+
122+
create_frame(label)
123+
124+
if linked_views:
125+
plotter.link_views()
126+
127+
current_label = next(iter(data.keys()))
128+
129+
for i in range(1, len(data[current_label])):
130+
label = iter(data.keys())
131+
create_frame(label)
132+
133+
plotter.close()
134+
135+
return filename
136+
137+
138+
def _plot3Dfacet(
139+
data: dict[str, np.ndarray],
140+
plotfuncs: dict[str, Callable],
141+
show_cmap: bool = False,
142+
cmap: str | np.ndarray = "viridis",
143+
portrait: bool = False,
144+
linked_views: bool = True,
145+
with_titles: bool = True,
146+
plotter_kwargs: dict = {},
147+
) -> pv.Plotter:
148+
"""Plot a set of data with a corresponding set of plot functions into a grid of subplots, similar to seaborn facetplots.
149+
150+
Args:
151+
data (dict[str, np.ndarray]): Dictionary of data to plot. The keys are the labels for the plots and the values are the data to plot.
152+
plotfuncs (dict[str, Callable]): Functions that take the data and a pyvista plotter object and plot the data into the each subplot.
153+
show_cmap (bool, optional): whether to show the colormap in each plot or not. Defaults to False.
154+
cmap (str | np.ndarray, optional): matplotlib colormap name. Defaults to "viridis".
155+
portrait (bool, optional): Aspect ratio mode. if this is true, the larger dimension of the plot table will be the rows. Otherwise, the larger dimension will be the columns. Defaults to False.
156+
linked_views (bool, optional): Whether to link all the views together for interactive plotting. If true, they will move in unison if one is moved. Defaults to True.
157+
with_titles (bool, optional): If the labels should be used as titles. Defaults to True.
158+
plotter_kwargs (dict, optional): Other keyword arguments passed to the plotter constructor. Defaults to {}.
159+
160+
Returns:
161+
pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot.
162+
"""
163+
164+
layout = _find_good_shape(len(data), portrait=portrait)
165+
166+
plotter = pv.Plotter(shape=layout, **plotter_kwargs)
167+
168+
label = iter(data.keys())
169+
170+
for i in range(layout[0]):
171+
for j in range(layout[1]):
172+
plotter.subplot(i, j)
173+
current_label = next(label)
174+
if with_titles:
175+
plotter.add_text(current_label)
176+
plotfuncs[current_label](
177+
data[current_label], plotter, show_cmap=show_cmap, cmap=cmap
178+
)
179+
180+
if linked_views:
181+
plotter.link_views()
182+
183+
return plotter

0 commit comments

Comments
 (0)