|
| 1 | +import pyvista as pv |
| 2 | +import numpy as np |
| 3 | +from typing import Callable |
| 4 | + |
| 5 | + |
| 6 | +def rgb_to_scalar(img: np.ndarray) -> np.ndarray: |
| 7 | + """Convert an array of RGB values to scalar values. |
| 8 | + This function is necessary because pyvista does not support RGB values directly as mesh data |
| 9 | +
|
| 10 | + Args: |
| 11 | + img (np.ndarray): data to be converted, of shape (n, m, 3) |
| 12 | +
|
| 13 | + Returns: |
| 14 | + np.ndarray: data converted to scalar values, of shape (n, m) |
| 15 | + """ |
| 16 | + reshaped = img.reshape(-1, 3, copy=True) |
| 17 | + unique_rgb, ridx = np.unique(reshaped, axis=0, return_inverse=True) |
| 18 | + |
| 19 | + values = np.arange(len(unique_rgb)) |
| 20 | + return values[ridx].reshape(img.shape[:-1]) |
| 21 | + |
| 22 | + |
| 23 | +def _find_good_shape(num_plots: int, portrait: bool = False) -> tuple[int, int]: |
| 24 | + """Find a good shape (rows, columns) for a grid of plots which should be such that |
| 25 | + rows*columns >= num_plots and rows is as close to columns as possible and rows*columns is minimal. |
| 26 | + There are sophisticated ways to do this, which are way beyond what is needed here, so a simple heuristic based on sqrt(num_plots) is used. |
| 27 | + Args: |
| 28 | + num_plots (int): number of plots to distribute |
| 29 | + portrait (bool, optional): whether the plots should be in portrait mode. If yes, the rows will become the larger number and cols the smaller, otherwise, it will be the other way round. Defaults to False |
| 30 | +
|
| 31 | + Returns: |
| 32 | + tuple[int, int]: shape of the grid (rows, columns) |
| 33 | + """ |
| 34 | + root = np.sqrt(num_plots) |
| 35 | + root_int = np.rint(root) |
| 36 | + |
| 37 | + a = int(np.floor(root)) |
| 38 | + b = int(np.ceil(root)) |
| 39 | + |
| 40 | + a_1 = int(a - 1) |
| 41 | + b_1 = int(b + 1) |
| 42 | + guesses = [ |
| 43 | + (x, y) |
| 44 | + for x, y in [ |
| 45 | + (a, b), |
| 46 | + (a_1, b_1), |
| 47 | + (a, b_1), |
| 48 | + (a_1, b), |
| 49 | + ] |
| 50 | + if x * y >= num_plots |
| 51 | + ] |
| 52 | + best_guess = guesses[np.argmin([x * y for x, y in guesses])] |
| 53 | + |
| 54 | + if np.isclose(root, root_int): |
| 55 | + return int(root_int), int(root_int) |
| 56 | + elif best_guess[0] * best_guess[1] >= num_plots: |
| 57 | + return ( |
| 58 | + (np.min(best_guess), np.max(best_guess)) |
| 59 | + if not portrait |
| 60 | + else (np.max(best_guess), np.min(best_guess)) |
| 61 | + ) |
| 62 | + else: |
| 63 | + return b, b |
| 64 | + |
| 65 | + |
| 66 | +def _animate_with_func3D_facet( |
| 67 | + filename: str, |
| 68 | + data: dict[str, list[np.ndarray]], |
| 69 | + titles: dict[str, list[str]], |
| 70 | + plotfuncs: dict[str, Callable], |
| 71 | + show_cmap: bool = False, |
| 72 | + cmap: str | np.ndarray = "viridis", |
| 73 | + portrait: bool = False, |
| 74 | + linked_views: bool = True, |
| 75 | + with_titles: bool = True, |
| 76 | + plotter_kwargs: dict = {}, |
| 77 | +) -> str: |
| 78 | + """Animate a set of data with a corresponding set of plot functions. |
| 79 | + The The list of data for each plot must be of equal length. Each element |
| 80 | + of these lists will be one frame in the animation. |
| 81 | + The result is stored as an .mp4 file. |
| 82 | + Args: |
| 83 | + filename (str): A filename or path to store the animation. This function automatically adds an .mp4 extension to the given filename. |
| 84 | + data (dict[str, list[np.ndarray]]): The data to animate in each suplot. The keys are the labels for the plots and the values are lists of data to plot. Each element of these lists will be one frame in the animation. The lists must be of equal length. |
| 85 | + titles (dict[str, list[str]]): The titles for each subplot. The keys are the labels for the plots and the values are lists of titles to display. Each element of these lists will be the title for the corresponding frame in the animation. Useful to have timestep labels for instance. |
| 86 | + plotfuncs (dict[str, Callable]): Function to write the data into a frame in the plotter. Signature: func(data, plotter, show_cmap:bool = show_cmap, cmap=cmap) |
| 87 | + show_cmap (bool, optional): whether to show the colormap in each plot or not. Defaults to False. |
| 88 | + cmap (str | np.ndarray, optional): matplotlib colormap name. Defaults to "viridis". |
| 89 | + portrait (bool, optional): Aspect ratio mode. if this is true, the larger dimension of the plot table will be the rows. Otherwise, the larger dimension will be the columns. Defaults to False. |
| 90 | + linked_views (bool, optional): Whether to link all the views together for interactive plotting. If true, they will move in unison if one is moved. Defaults to True. |
| 91 | + with_titles (bool, optional): If the labels should be used as titles. Defaults to True. |
| 92 | + plotter_kwargs (dict, optional): Other keyword arguments passed to the plotter constructor. Defaults to {}. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + str: path to where the given animation is stored. |
| 96 | + """ |
| 97 | + |
| 98 | + def create_frame(label): |
| 99 | + for i in range(layout[0]): |
| 100 | + for j in range(layout[1]): |
| 101 | + plotter.subplot(i, j) |
| 102 | + |
| 103 | + current_label = next(label) |
| 104 | + |
| 105 | + if with_titles: |
| 106 | + plotter.add_text(titles[current_label][0]) |
| 107 | + |
| 108 | + plotfuncs[current_label]( |
| 109 | + data[current_label][0], plotter, show_cmap=show_cmap, cmap=cmap |
| 110 | + ) |
| 111 | + |
| 112 | + plotter.write_frame() |
| 113 | + |
| 114 | + layout = _find_good_shape(len(data), portrait=portrait) |
| 115 | + |
| 116 | + plotter = pv.Plotter(shape=layout, **plotter_kwargs) |
| 117 | + |
| 118 | + plotter.open_movie(filename) |
| 119 | + |
| 120 | + label = iter(data.keys()) |
| 121 | + |
| 122 | + create_frame(label) |
| 123 | + |
| 124 | + if linked_views: |
| 125 | + plotter.link_views() |
| 126 | + |
| 127 | + current_label = next(iter(data.keys())) |
| 128 | + |
| 129 | + for i in range(1, len(data[current_label])): |
| 130 | + label = iter(data.keys()) |
| 131 | + create_frame(label) |
| 132 | + |
| 133 | + plotter.close() |
| 134 | + |
| 135 | + return filename |
| 136 | + |
| 137 | + |
| 138 | +def _plot3Dfacet( |
| 139 | + data: dict[str, np.ndarray], |
| 140 | + plotfuncs: dict[str, Callable], |
| 141 | + show_cmap: bool = False, |
| 142 | + cmap: str | np.ndarray = "viridis", |
| 143 | + portrait: bool = False, |
| 144 | + linked_views: bool = True, |
| 145 | + with_titles: bool = True, |
| 146 | + plotter_kwargs: dict = {}, |
| 147 | +) -> pv.Plotter: |
| 148 | + """Plot a set of data with a corresponding set of plot functions into a grid of subplots, similar to seaborn facetplots. |
| 149 | +
|
| 150 | + Args: |
| 151 | + data (dict[str, np.ndarray]): Dictionary of data to plot. The keys are the labels for the plots and the values are the data to plot. |
| 152 | + plotfuncs (dict[str, Callable]): Functions that take the data and a pyvista plotter object and plot the data into the each subplot. |
| 153 | + show_cmap (bool, optional): whether to show the colormap in each plot or not. Defaults to False. |
| 154 | + cmap (str | np.ndarray, optional): matplotlib colormap name. Defaults to "viridis". |
| 155 | + portrait (bool, optional): Aspect ratio mode. if this is true, the larger dimension of the plot table will be the rows. Otherwise, the larger dimension will be the columns. Defaults to False. |
| 156 | + linked_views (bool, optional): Whether to link all the views together for interactive plotting. If true, they will move in unison if one is moved. Defaults to True. |
| 157 | + with_titles (bool, optional): If the labels should be used as titles. Defaults to True. |
| 158 | + plotter_kwargs (dict, optional): Other keyword arguments passed to the plotter constructor. Defaults to {}. |
| 159 | +
|
| 160 | + Returns: |
| 161 | + pv.Plotter: pyvista plotter object. Call plotter.show() to display the plot. |
| 162 | + """ |
| 163 | + |
| 164 | + layout = _find_good_shape(len(data), portrait=portrait) |
| 165 | + |
| 166 | + plotter = pv.Plotter(shape=layout, **plotter_kwargs) |
| 167 | + |
| 168 | + label = iter(data.keys()) |
| 169 | + |
| 170 | + for i in range(layout[0]): |
| 171 | + for j in range(layout[1]): |
| 172 | + plotter.subplot(i, j) |
| 173 | + current_label = next(label) |
| 174 | + if with_titles: |
| 175 | + plotter.add_text(current_label) |
| 176 | + plotfuncs[current_label]( |
| 177 | + data[current_label], plotter, show_cmap=show_cmap, cmap=cmap |
| 178 | + ) |
| 179 | + |
| 180 | + if linked_views: |
| 181 | + plotter.link_views() |
| 182 | + |
| 183 | + return plotter |
0 commit comments