Skip to content

Commit a6eeab8

Browse files
committed
add tests for pyvista plotting utils
1 parent 090f954 commit a6eeab8

File tree

3 files changed

+193
-37
lines changed

3 files changed

+193
-37
lines changed

src/sme_contrib/pyvista_utils.py

+63-37
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def rgb_to_scalar(img: np.ndarray) -> np.ndarray:
1111
"""
12-
Convert an RGB image 3D image to a scalar image where each unique RGB value is assigned a unique scalar.
12+
Convert an RGB 3D image represented as a 4D tensor to a 3D image tensor where each unique RGB value is assigned a unique scalar, i.e., it contracts the dimension with the RGB values into scalars. This is useful because PyVista doesn't work well with RGB values directly and expects fields defined on a grid, usually given by the tensor shape.
1313
1414
img (np.ndarray): A 3D numpy array representing an RGB image with shape (height, width, 3).
1515
@@ -34,21 +34,24 @@ def make_discrete_colormap(
3434
values (np.ndarray): An array of values to map to colors. Default is an empty array.
3535
3636
Returns:
37-
pv.LookupTable: A PyVista LookupTable object with the specified colormap and values.
37+
pv.LookupTable: A PyVista LookupTable object with the values drawn from the specified colormap in RGBA format.
3838
"""
39-
cm = [(0, 0, 0, 1)]
40-
i = 0
41-
42-
if values == []:
43-
values = np.arange(len(cm))
44-
for c in cycle(plt.get_cmap(cmap).colors):
45-
cm.append(mcolors.to_rgba(c))
46-
if len(cm) >= len(values):
47-
break
48-
i += 1
39+
cm = []
4940

41+
if values.size == 0:
42+
values = np.arange(0, 1, 1)
43+
cm = [
44+
mcolors.to_rgba(plt.get_cmap(cmap).colors[0]),
45+
]
46+
else:
47+
i = 0
48+
for c in cycle(plt.get_cmap(cmap).colors):
49+
cm.append(mcolors.to_rgba(c))
50+
if len(cm) >= len(values):
51+
break
52+
i += 1
5053
lt = pv.LookupTable(
51-
values=np.array(cm * 255),
54+
values=np.array(cm) * 255,
5255
scalar_range=(0, len(values)),
5356
n_values=len(values),
5457
)
@@ -66,39 +69,48 @@ def find_layout(num_plots: int, portrait: bool = False) -> tuple[int, int]:
6669
Returns:
6770
tuple[int, int]: Tuple describing (n_rows, n_cols) of the grid
6871
"""
72+
73+
# for checking approximation accuracy with ints. if root > root_int, then
74+
# we need to adjust n_row, n_cols sucht that n_row * n_cols >= root^2
6975
root = np.sqrt(num_plots)
7076
root_int = np.rint(root)
7177

72-
a = int(np.floor(root))
73-
b = int(np.ceil(root))
74-
75-
a_1 = int(a - 1)
76-
b_1 = int(b + 1)
77-
guesses = [
78-
(x, y)
79-
for x, y in [
80-
(a, b),
81-
(a_1, b_1),
82-
(a, b_1),
83-
(a_1, b),
78+
if np.isclose(root, root_int):
79+
return int(root_int), int(root_int) # perfect square because root is an integer
80+
else:
81+
# approximation by integer root is inexact
82+
83+
# find an approximation that is close to square such that n_row * n_cols - num_plots is
84+
# as small as possible
85+
a = int(np.floor(root))
86+
b = int(np.ceil(root))
87+
88+
a_1 = int(a - 1)
89+
b_1 = int(b + 1)
90+
91+
guesses = [
92+
(x, y)
93+
for x, y in [
94+
(a, b),
95+
(a_1, b_1),
96+
(a, b_1),
97+
(a_1, b),
98+
]
99+
if x * y >= num_plots
84100
]
85-
if x * y >= num_plots
86-
]
87-
best_guess = guesses[np.argmin([x * y for x, y in guesses])]
101+
best_guess = guesses[
102+
np.argmin([x * y for x, y in guesses])
103+
] # smallest possible approximation
88104

89-
if np.isclose(root, root_int):
90-
return int(root_int), int(root_int)
91-
elif best_guess[0] * best_guess[1] >= num_plots:
105+
# handle orientation of the grid. min => rows for landscape, min=> cols for portrait
92106
return (
93107
(np.min(best_guess), np.max(best_guess))
94108
if not portrait
95109
else (np.max(best_guess), np.min(best_guess))
96110
)
97-
else:
98-
return b, b
99111

100112

101-
def facet_plot3D(
113+
def facet_grid(
102114
data: dict[str, np.ndarray],
103115
plotfuncs: dict[str, Callable],
104116
show_cmap: bool = False,
@@ -110,14 +122,23 @@ def facet_plot3D(
110122
plotfuncs_kwargs: dict[str, dict[str, Any]] = {},
111123
) -> pv.Plotter:
112124
"""
113-
Create a 3D facet plot using PyVista. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary.
114-
125+
Create a 3D facet plot using PyVista. This follows the seaborn.FacetGrid concept. This function creates a grid of subplots where each subplot is filled by a function in the plotfuncs argument. The keys for plotfuncs and data must be the same, such that plotfuncs can be unambiguously mapped over the data dictionary.
126+
Do not attempt to plot 2D images and 3D images into the same facet grid, as this will create odd artifacts and
127+
may not work as expected.
115128
Parameters:
116129
-----------
117130
data : dict[str, np.ndarray]
118131
A dictionary where keys are labels and values are numpy arrays containing the data to be plotted.
119132
plotfuncs : dict[str, Callable]
120-
A dictionary where keys are labels and values are functions that take the label, data, plotter, and other optional arguments to create the plot.
133+
A dictionary where keys are labels and values are functions with signature f(
134+
label:str,
135+
data:np.ndarray | pyvista.ImageData | pyvista.UniformGrid,
136+
plotter:pv.Plotter,
137+
panel:tuple[int, int],
138+
show_cmap:bool=show_cmap,
139+
cmap=cmap,
140+
**plotfuncs_kwargs
141+
) -> None
121142
show_cmap : bool, optional
122143
Whether to show the color map. Default is False.
123144
cmap : str | np.ndarray | pv.LookupTable, optional
@@ -138,6 +159,11 @@ def facet_plot3D(
138159
pv.Plotter
139160
The PyVista Plotter object with the created facet plot.
140161
"""
162+
if data.keys() != plotfuncs.keys():
163+
raise ValueError(
164+
"The keys for the data and plotfuncs dictionaries must be the same."
165+
)
166+
141167
layout = find_layout(len(data), portrait=portrait)
142168

143169
plotter = pv.Plotter(shape=layout, **plotter_kwargs)

tests/conftest.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
from pyvista import examples
3+
4+
5+
@pytest.fixture(scope="session")
6+
def exampledata():
7+
armadillo = examples.download_armadillo()
8+
bloodvessel = examples.download_blood_vessels()
9+
brain = examples.download_brain()
10+
11+
return {
12+
"armadillo": armadillo,
13+
"bloodvessel": bloodvessel,
14+
"brain": brain,
15+
}

tests/test_pyvista_utils.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import sme_contrib.pyvista_utils as pvu
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
import matplotlib.colors as mcolors
5+
import pytest
6+
7+
8+
def test_rgb_to_scalar():
9+
img = np.array(
10+
[
11+
[[[0, 0, 0], [255, 255, 255]], [[0, 0, 0], [255, 255, 255]]],
12+
[[[0, 0, 0], [255, 255, 255]], [[0, 0, 0], [255, 255, 255]]],
13+
[[[0, 0, 0], [255, 255, 255]], [[0, 0, 0], [255, 255, 255]]],
14+
]
15+
)
16+
scalar_img = pvu.rgb_to_scalar(img)
17+
assert scalar_img.shape == (3, 2, 2)
18+
assert np.all(
19+
scalar_img == np.array([[[0, 1], [0, 1]], [[0, 1], [0, 1]], [[0, 1], [0, 1]]])
20+
)
21+
22+
23+
def test_make_discrete_colormap():
24+
lt = pvu.make_discrete_colormap()
25+
cm = plt.get_cmap("tab10").colors
26+
should = (np.array([mcolors.to_rgba(cm[0])]) * 255).astype(np.int32)
27+
assert lt.n_values == 1
28+
assert lt.scalar_range == (0, 1)
29+
assert np.all(lt.values == should)
30+
31+
lt = pvu.make_discrete_colormap("tab20", np.array([0, 1, 2, 3]))
32+
assert lt.n_values == 4
33+
assert lt.scalar_range == (0, 4)
34+
cm = plt.get_cmap("tab20").colors
35+
should = (
36+
np.array(
37+
[
38+
mcolors.to_rgba(cm[0]),
39+
mcolors.to_rgba(cm[1]),
40+
mcolors.to_rgba(cm[2]),
41+
mcolors.to_rgba(cm[3]),
42+
]
43+
)
44+
* 255
45+
).astype(np.int32)
46+
assert np.all(lt.values == should)
47+
48+
49+
def test_find_layout():
50+
assert pvu.find_layout(1) == (1, 1)
51+
assert pvu.find_layout(3) == (1, 3)
52+
assert pvu.find_layout(5) == (2, 3)
53+
assert pvu.find_layout(8) == (2, 4)
54+
assert pvu.find_layout(10) == (2, 5)
55+
assert pvu.find_layout(15) == (3, 5)
56+
assert pvu.find_layout(15, portrait=True) == (5, 3)
57+
assert pvu.find_layout(16) == (4, 4)
58+
assert pvu.find_layout(19) == (4, 5)
59+
assert pvu.find_layout(23) == (4, 6)
60+
assert pvu.find_layout(25) == (5, 5)
61+
assert pvu.find_layout(27) == (4, 7)
62+
assert pvu.find_layout(29) == (5, 6)
63+
assert pvu.find_layout(31) == (5, 7)
64+
assert pvu.find_layout(31, portrait=True) == (7, 5)
65+
66+
67+
def test_facet_grid_3D(exampledata):
68+
def plot_bloodvessel(label, data, plotter, panel, **kwargs):
69+
plotter.subplot(*panel)
70+
plotter.add_mesh(data)
71+
72+
def plot_brain(label, data, plotter, panel, **kwargs):
73+
plotter.subplot(*panel)
74+
plotter.add_volume(
75+
data,
76+
cmap="viridis",
77+
opacity="sigmoid", # Common opacity mapping for volume rendering
78+
shade=True,
79+
ambient=0.3,
80+
diffuse=0.6,
81+
specular=0.5,
82+
)
83+
84+
def plot_armadillo(label, data, plotter, panel, **kwargs):
85+
plotter.subplot(*panel)
86+
plotter.add_mesh(data)
87+
88+
facetgrid = pvu.facet_grid(
89+
data={
90+
"armadillo": exampledata["armadillo"],
91+
"bloodvessel": exampledata["bloodvessel"],
92+
"brain": exampledata["brain"],
93+
},
94+
plotfuncs={
95+
"armadillo": plot_armadillo,
96+
"bloodvessel": plot_bloodvessel,
97+
"brain": plot_brain,
98+
},
99+
)
100+
101+
assert facetgrid.shape == (1, 3)
102+
103+
with pytest.raises(ValueError):
104+
pvu.facet_grid(
105+
data={
106+
"armadillo": exampledata["armadillo"],
107+
"bloodvessel": exampledata["bloodvessel"],
108+
"brain": exampledata["brain"],
109+
},
110+
plotfuncs={
111+
"armadillo": plot_armadillo,
112+
"bloodvessel": plot_bloodvessel,
113+
"wrong_key": plot_brain,
114+
},
115+
)

0 commit comments

Comments
 (0)