Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a transform option to line plots. #222

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
748 changes: 748 additions & 0 deletions doc/devlog/2024-12-19-v0.9.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions epymorph/geography/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ def nodes(self) -> int:
"""The number of nodes in this scope."""
return len(self.node_ids)

def index_of(self, node_id: str) -> int:
    """Return the index of a given node by its ID string.

    Parameters
    ----------
    node_id : str
        the ID of the node to look up

    Returns
    -------
    int
        the position of the first entry of `node_ids` that equals `node_id`

    Raises
    ------
    ValueError
        if the given ID isn't in the scope
    """
    idxs, *_ = np.where(self.node_ids == node_id)
    if len(idxs) == 0:
        raise ValueError(f"'{node_id}' not present in geo scope.")
    # np.where yields NumPy integer scalars; convert so the result honors the
    # declared built-in `int` return type.
    return int(idxs[0])

@property
def labels_option(self) -> NDArray[np.str_] | None:
"""An optional text label for each node. If this returns None,
Expand Down
21 changes: 21 additions & 0 deletions epymorph/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,27 @@ def _to_selection(
to_index = (other.end_date - self.time_frame.start_date).days + 1
return TimeSelection(self.time_frame, (slice(from_index, to_index), step))

def days(
    self,
    from_day: int,
    to_day: int,
    step: int | None = None,
) -> TimeSelection:
    """Subset the time frame by providing a start and end simulation day
    (inclusive).

    Parameters
    ----------
    from_day : int
        the starting simulation day of the range, as an index
    to_day : int
        the last included simulation day of the range, as an index
    step : int, optional
        if given, narrow the selection to a specific tau step (by index) within
        the date range; by default include all steps

    Returns
    -------
    TimeSelection
        a selection on this object's time frame covering the requested days
    """
    # A slice excludes its stop index, so add one to keep `to_day` inside the
    # selection, matching the inclusive contract documented above.
    return TimeSelection(self.time_frame, (slice(from_day, to_day + 1), step))

def range(
self,
from_date: date | str,
Expand Down
36 changes: 36 additions & 0 deletions epymorph/tools/out_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def line(
ordering: Literal["location", "quantity"] = "location",
time_format: Literal["auto", "date", "day"] = "auto",
title: str | None = None,
transform: Callable[[pd.DataFrame], pd.DataFrame] | None = None,
) -> None:
"""
Renders a line plot using matplotlib showing the given selections.
Expand Down Expand Up @@ -186,6 +187,21 @@ def line(
may be ignored.
title : str, optional
a title to draw on the plot
transform : Callable[[pd.DataFrame], pd.DataFrame], optional
allows you to specify an arbitrary transform function for the source
dataframe before we plot it, e.g., to rescale the values.
The function will be called once per geo/quantity group -- one per line,
essentially -- with a dataframe that contains just the data for that group.
The dataframe given as the argument is the result of applying
all selections and the projection if specified.
You should return a dataframe with the same format, where the
values of the data column have been modified for your purposes.

Dataframe columns:
- "time": the time series column
- "geo": the node ID (same value per group)
- "quantity": the label of the quantity (same value per group)
- "value": the data column
"""

# Adjust figsize to make room for an outside legend, if needed
Expand All @@ -210,6 +226,7 @@ def line(
label_format=label_format,
ordering=ordering,
time_format=time_format,
transform=transform,
)
# Make sure the plot does not grow if we widened the figure.
x, y, w, h = ax.get_position(original=True).bounds
Expand Down Expand Up @@ -264,6 +281,7 @@ def line_plt(
line_kwargs: list[dict] | None = None,
ordering: Literal["location", "quantity"] = "location",
time_format: Literal["auto", "date", "day"] = "auto",
transform: Callable[[pd.DataFrame], pd.DataFrame] | None = None,
) -> list[Line2D]:
"""
Draws lines onto the given matplotlib Axes to show the given selections.
Expand Down Expand Up @@ -304,6 +322,21 @@ def line_plt(
simulation with the first day being 0.
If the system cannot convert to the requested time format, this argument
may be ignored.
transform : Callable[[pd.DataFrame], pd.DataFrame], optional
allows you to specify an arbitrary transform function for the source
dataframe before we plot it, e.g., to rescale the values.
The function will be called once per geo/quantity group -- one per line,
essentially -- with a dataframe that contains just the data for that group.
The dataframe given as the argument is the result of applying
all selections and the projection if specified.
You should return a dataframe with the same format, where the
values of the data column have been modified for your purposes.

Dataframe columns:
- "time": the time series column
- "geo": the node ID (same value per group)
- "quantity": the label of the quantity (same value per group)
- "value": the data column

Returns
-------
Expand All @@ -313,6 +346,8 @@ def line_plt(
"""
if line_kwargs is None or len(line_kwargs) == 0:
line_kwargs = [{}]
if transform is None:
transform = identity

data_df = munge(self.output, geo, time, quantity)

Expand Down Expand Up @@ -346,6 +381,7 @@ def line_plt(
q_label = q_mapping[q_label_dis]
label = label_format.format(n=n_label, q=q_label)
curr_kwargs = {"label": label, **kwargs}
data = transform(data.assign(quantity=q_label))
ls = ax.plot(data["time"], data["value"], **curr_kwargs)
lines.extend(ls)
return lines
Expand Down
Loading