Skip to content

Commit

Permalink
start refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jan 8, 2025
1 parent bc1a02f commit 186991a
Show file tree
Hide file tree
Showing 22 changed files with 361 additions and 496 deletions.
2 changes: 1 addition & 1 deletion dascore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations
from rich import print # noqa

from dascore.core.patch import Patch
from dascore.core.patch import Patch, PatchSummary
from dascore.core.attrs import PatchAttrs
from dascore.core.spool import BaseSpool, spool
from dascore.core.coordmanager import get_coord_manager, CoordManager
Expand Down
179 changes: 11 additions & 168 deletions dascore/core/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,33 @@

from __future__ import annotations

import warnings
from collections.abc import Mapping
from typing import Annotated, Any, Literal
from typing import Annotated, Literal

import numpy as np
from pydantic import ConfigDict, Field, PlainValidator, model_validator
from pydantic import ConfigDict, Field, PlainValidator
from typing_extensions import Self

import dascore as dc
from dascore.constants import (
VALID_DATA_CATEGORIES,
VALID_DATA_TYPES,
max_lens,
)
from dascore.core.coords import BaseCoord, CoordSummary
from dascore.core.coords import CoordSummary
from dascore.utils.attrs import separate_coord_info
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import (
to_str,
)
from dascore.utils.models import (
CommaSeparatedStr,
DascoreBaseModel,
StrTupleStrSerialized,
UnitQuantity,
frozen_dict_serializer,
frozen_dict_validator,
)

str_validator = PlainValidator(to_str)
_coord_summary_suffixes = set(CoordSummary.model_fields)
_coord_required = {"min", "max"}


def _get_coords_dict(data_dict):
    """
    Collect flat coordinate attributes into a nested "coords" entry.

    For example, if time_min and time_step are in data_dict, they are
    grouped into the coords sub dict under "time". The returned dict is
    the attribute dict produced by separate_coord_info with "coords" set
    to a mapping of coordinate name -> CoordSummary.
    """

    def _get_dims(data_dict):
        """Find the dimension names; returns None when none are found."""
        if "dims" in data_dict:
            found = data_dict["dims"]
        elif hasattr(coord := data_dict.get("coords"), "dims"):
            found = coord.dims
        else:
            found = None
        # A comma-separated string is split into a tuple of names.
        return tuple(found.split(",")) if isinstance(found, str) else found

    coord_info, new_attrs = separate_coord_info(
        data_dict, _get_dims(data_dict), required=("min", "max")
    )
    summaries = {}
    for name, info in coord_info.items():
        summaries[name] = dc.core.CoordSummary(**info)
    new_attrs["coords"] = summaries
    return new_attrs


class PatchAttrs(DascoreBaseModel):
"""
The expected attributes for a Patch.
Expand Down Expand Up @@ -118,34 +85,11 @@ class PatchAttrs(DascoreBaseModel):
network: str = Field(
default="", max_length=max_lens["network"], description="A network code."
)
history: str | tuple[str, ...] = Field(
history: StrTupleStrSerialized = Field(
default_factory=tuple,
description="A list of processing performed on the patch.",
)

dims: CommaSeparatedStr = Field(
default="",
max_length=max_lens["dims"],
description="A tuple of comma-separated dimensions names.",
)

coords: Annotated[
FrozenDict[str, CoordSummary],
frozen_dict_validator,
frozen_dict_serializer,
] = Field(default_factory=dict)

@model_validator(mode="before")
@classmethod
def parse_coord_attributes(cls, data: Any) -> Any:
    """
    Gather coordinate attributes into the "coords" dict before validation.

    Non-dict input is passed through untouched.
    """
    if not isinstance(data, dict):
        return data
    parsed = _get_coords_dict(data)
    # When dims is not supplied, fall back to the coordinate names.
    if "dims" not in parsed:
        parsed["dims"] = tuple(parsed["coords"])
    return parsed

def __getitem__(self, item):
return getattr(self, item)

Expand All @@ -155,22 +99,6 @@ def __setitem__(self, key, value):
def __len__(self):
return len(self.model_dump())

def __getattr__(self, item):
    """
    Enable dynamic attributes such as time_min, time_max, etc.

    Names made of exactly two underscore-separated parts are looked up
    as ``{coord}_{field}`` on the matching entry of ``self.coords``.
    The legacy spelling ``d_{coord}`` is mapped to ``{coord}_step`` and
    emits a DeprecationWarning. Anything else is deferred to the parent
    class's ``__getattr__``.
    """
    parts = item.split("_")
    # this only works on names like time_max, distance_step, etc.
    if len(parts) != 2:
        return super().__getattr__(item)
    first, second = parts
    if first == "d":
        # Legacy form: d_time is an alias for time_step.
        first, second = second, "step"
        msg = f"{item} is deprecated, use {first}_{second} instead."
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
    if first not in self.coords:
        return super().__getattr__(item)
    return getattr(self.coords[first], second)

def get(self, item, default=None):
"""dict-like get method."""
try:
Expand All @@ -182,69 +110,11 @@ def items(self):
"""Yield (attribute, values) just like dict.items()."""
yield from self.model_dump().items()

def coords_from_dims(self) -> Mapping[str, BaseCoord]:
    """
    Return coordinates from dimensions assuming evenly sampled.

    Each name in ``dim_tuple`` is mapped to the result of calling
    ``to_coord()`` on the matching entry of ``coords``.
    """
    return {dim: self.coords[dim].to_coord() for dim in self.dim_tuple}

@classmethod
def from_dict(
    cls,
    attr_map: Mapping | PatchAttrs,
) -> Self:
    """
    Get a new instance of the PatchAttrs.

    Parameters
    ----------
    attr_map
        Anything convertible to a dict that contains the attr info.
        An existing instance of this class is returned as is, and None
        produces an instance created without any arguments.
    """
    if isinstance(attr_map, cls):
        return attr_map
    if attr_map is None:
        return cls()
    return cls(**attr_map)

@property
def dim_tuple(self):
    """Return a tuple of dimensions. The dims attr is a comma-separated string."""
    dims = self.dims
    return tuple(dims.split(",")) if dims else ()

def rename_dimension(self, **kwargs):
    """
    Rename one or more dimensions if in kwargs. Return new PatchAttrs.

    Keyword names are the current dimension names and values are the new
    names. Keywords which are not current dimensions are ignored; when
    none match, the same instance is returned.
    """
    current_dims = self.dim_tuple
    to_rename = set(kwargs) & set(current_dims)
    if not to_rename:
        return self
    contents = self.model_dump(exclude_defaults=True)
    coord_dump = contents.get("coords", {})
    renamed = list(current_dims)
    for old_name in to_rename:
        new_name = kwargs[old_name]
        renamed[renamed.index(old_name)] = new_name
        # Move the coordinate entry to its new key (None if it was absent).
        coord_dump[new_name] = coord_dump.pop(old_name, None)
    contents["dims"] = tuple(renamed)
    return self.__class__(**contents)

def update(self, **kwargs) -> Self:
"""Update an attribute in the model, return new model."""
coord_info, attr_info = separate_coord_info(kwargs, dims=self.dim_tuple)
_, attr_info = separate_coord_info(kwargs)
out = self.model_dump(exclude_unset=True)
out.update(attr_info)
out_coord_dict = out["coords"]
for name, coord_dict in coord_info.items():
if name not in out_coord_dict:
out_coord_dict[name] = coord_dict
else:
out_coord_dict[name].update(coord_dict)
# silly check to clear coords
if not kwargs.get("coords", True):
out["coords"] = {}
return self.__class__(**out)

def drop_private(self) -> Self:
Expand All @@ -253,35 +123,8 @@ def drop_private(self) -> Self:
out = {i: v for i, v in contents.items() if not i.startswith("_")}
return self.__class__(**out)

def flat_dump(self, dim_tuple=False, exclude=None) -> dict:
"""
Flatten the coordinates and dump to dict.
Parameters
----------
dim_tuple
If True, return dimensional tuple instead of range. EG, the
output will have {time: (min, max)} rather than
{time_min: ..., time_max: ...,}. This is useful because it can
be passed to read, scan, select, etc.
exclude
keys to exclude.
"""
out = self.model_dump(exclude=exclude)
for coord_name, coord in out.pop("coords").items():
names = list(coord)
if dim_tuple:
names = sorted(set(names) - {"min", "max"})
out[coord_name] = (coord["min"], coord["max"])
for name in names:
out[f"{coord_name}_{name}"] = coord[name]
# ensure step has right type if nullish
step_name = f"{coord_name}_step"
step, start = out[step_name], coord["min"]
if step is None:
is_time = isinstance(start, np.datetime64 | np.timedelta64)
if is_time:
out[step_name] = np.timedelta64("NaT")
elif isinstance(start, float | np.floating):
out[step_name] = np.nan
return out
@classmethod
def from_dict(cls, obj: Mapping | Self | None) -> Self:
    """
    Get an instance of this class from a mapping of attributes.

    Parameters
    ----------
    obj
        A mapping whose items are passed to the constructor as keyword
        arguments. An existing instance of this class is returned as
        is, and None produces an instance created without arguments.
    """
    if isinstance(obj, cls):
        return obj
    # None cannot be unpacked with **; treat it as "no attributes" as the
    # earlier implementation of this method did.
    if obj is None:
        return cls()
    return cls(**obj)
46 changes: 42 additions & 4 deletions dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,7 @@ def update_from_attrs(
for unused_dim in set(coord_info) - set(coords.dims):
for key, val in coord_info[unused_dim].items():
attr_info[f"{unused_dim}_{key}"] = val
attr_info["coords"] = coords.to_summary_dict()
attr_info["dims"] = coords.dims
attrs = dc.PatchAttrs.from_dict(attr_info)
attrs = dc.PatchAttrs(**attr_info)
return coords, attrs

def sort(
Expand Down Expand Up @@ -984,9 +982,24 @@ def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]:
dim_map = self.dim_map
out = {}
for name, coord in self.coord_map.items():
out[name] = coord.to_summary(dims=dim_map[name])
out[name] = coord.to_summary(dims=dim_map[name], name=name)
return out

def to_summary(self) -> "CoordManagerSummary":
    """
    Convert the coordinates in the coord manager to Coord summaries.

    Returns
    -------
    A CoordManagerSummary with this manager's ``dims`` and ``dim_map``
    and a ``coord_map`` holding the result of ``to_summary`` for each
    coordinate.
    """
    dim_map = self.dim_map
    summaries = {
        name: coord.to_summary(dims=dim_map[name], name=name)
        for name, coord in self.coord_map.items()
    }
    # The summaries must be passed as ``coord_map``: that is the field
    # CoordManagerSummary declares and the keyword to_coord_manager uses.
    # NOTE(review): ``summary=True`` is kept from the original call; no
    # matching field is visible here - confirm it is needed.
    return CoordManagerSummary(
        dim_map=dim_map,
        coord_map=summaries,
        dims=self.dims,
        summary=True,
    )

def get_coord(self, coord_name: str) -> BaseCoord:
"""
Retrieve a single coordinate from the coordinate manager.
Expand Down Expand Up @@ -1197,3 +1210,28 @@ def _maybe_coord_from_nested(name, coord, new_dims):
if new_dims:
dims = tuple(list(dims) + new_dims)
return c_map, d_map, dims


class CoordManagerSummary(CoordManager):
    """A coordinate manager whose coord_map holds CoordSummary objects."""

    # Maps each coordinate name to its summary.
    coord_map: Annotated[
        FrozenDict[str, CoordSummary],
        frozen_dict_validator,
        frozen_dict_serializer,
    ]

    def to_coord_manager(self):
        """
        Build a CoordManager from the summaries.

        Each summary is expanded with ``to_coord``; dims and dim_map are
        carried over unchanged. This only works if the coordinates were
        evenly sampled/sorted.
        """
        expanded = {
            name: summary.to_coord() for name, summary in self.coord_map.items()
        }
        return CoordManager(
            coord_map=expanded,
            dim_map=self.dim_map,
            dims=self.dims,
        )
18 changes: 16 additions & 2 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from dascore.utils.models import (
ArrayLike,
DascoreBaseModel,
IntTupleStrSerialized,
StrTupleStrSerialized,
UnitQuantity,
)
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float
Expand Down Expand Up @@ -98,6 +100,10 @@ class CoordSummary(DascoreBaseModel):
dtype: str
min: min_max_type
max: min_max_type
shape: IntTupleStrSerialized
ndim: int
dims: StrTupleStrSerialized
name: str
step: step_type | None = None
units: UnitQuantity | None = None

Expand Down Expand Up @@ -703,14 +709,18 @@ def get_attrs_dict(self, name):
out[f"{name}_units"] = self.units
return out

def to_summary(self, dims=()) -> CoordSummary:
def to_summary(self, dims, name) -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=self.min(),
max=self.max(),
step=self.step,
dtype=self.dtype,
units=self.units,
shape=self.shape,
ndim=self.ndim,
dims=dims,
name=name,
)

def update(self, **kwargs):
Expand Down Expand Up @@ -969,14 +979,18 @@ def change_length(self, length: int) -> Self:
assert self.ndim == 1, "change_length only works on 1D coords."
return get_coord(shape=(length,))

def to_summary(self, dims=()) -> CoordSummary:
def to_summary(self, dims, name) -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=np.nan,
max=np.nan,
step=np.nan,
dtype=self.dtype,
units=None,
dims=dims,
name=name,
shape=self.shape,
ndim=self.ndim,
)


Expand Down
Loading

0 comments on commit 186991a

Please sign in to comment.