Skip to content

Commit

Permalink
start refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jan 7, 2025
1 parent bc1a02f commit cdb98c2
Show file tree
Hide file tree
Showing 23 changed files with 326 additions and 217 deletions.
95 changes: 47 additions & 48 deletions dascore/core/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
to_str,
)
from dascore.utils.models import (
CommaSeparatedStr,
StrTupleStrSerialized,
IntTupleStrSerialized,
DascoreBaseModel,
UnitQuantity,
frozen_dict_serializer,
Expand All @@ -35,31 +36,22 @@
_coord_required = {"min", "max"}


def _get_coords_dict(data_dict):
"""
Add coords dict to data dict, pop out any coordinate attributes.
For example, if time_min, time_step are in data_dict, these will be
grouped into the coords sub dict under "time".
"""

def _get_dims(data_dict):
"""Try to get dim tuple."""
dims = None
if "dims" in data_dict:
dims = data_dict["dims"]
elif hasattr(coord := data_dict.get("coords"), "dims"):
dims = coord.dims
if isinstance(dims, str):
dims = tuple(dims.split(","))
return dims

dims = _get_dims(data_dict)
coord_info, new_attrs = separate_coord_info(
data_dict, dims, required=("min", "max")
)
new_attrs["coords"] = {i: dc.core.CoordSummary(**v) for i, v in coord_info.items()}
return new_attrs
def _to_coord_summary(coord_dict) -> FrozenDict[str, CoordSummary]:
"""Convert a dict of potential coord info to a coord summary dict."""
# We already have a summary dict, just return.
if hasattr(coord_dict, "to_summary_dict"):
return coord_dict.to_summary_dict()
# Otherwise, build up summary dict contents.
out = {}
for i, v in coord_dict.items():
if hasattr(v, "to_summary"):
v = v.to_summary()
elif isinstance(v,CoordSummary):
pass
else:
v = CoordSummary(**v)
out[i] = v
return FrozenDict(out)


class PatchAttrs(DascoreBaseModel):
Expand Down Expand Up @@ -92,6 +84,12 @@ class PatchAttrs(DascoreBaseModel):
data_type: Annotated[Literal[VALID_DATA_TYPES], str_validator] = Field(
description="Describes the quantity being measured.", default=""
)
dtype: str = Field(
description="The data type of the patch array (e.g, f32).",
)
shape: IntTupleStrSerialized = Field(
description="The shape of the patch array.",
)
data_category: Annotated[Literal[VALID_DATA_CATEGORIES], str_validator] = Field(
description="Describes the type of data.",
default="",
Expand Down Expand Up @@ -123,27 +121,31 @@ class PatchAttrs(DascoreBaseModel):
description="A list of processing performed on the patch.",
)

dims: CommaSeparatedStr = Field(
default="",
dims: StrTupleStrSerialized = Field(
default=(),
max_length=max_lens["dims"],
description="A tuple of comma-separated dimensions names.",
)

coords: Annotated[
FrozenDict[str, CoordSummary],
frozen_dict_validator,
PlainValidator(_to_coord_summary),
frozen_dict_serializer,
] = Field(default_factory=dict)


@model_validator(mode="before")
@classmethod
def parse_coord_attributes(cls, data: Any) -> Any:
def _get_dims(cls, data: Any) -> Any:
"""Parse the coordinate attributes into coord dict."""
if isinstance(data, dict):
data = _get_coords_dict(data)
# add dims as coords if dims is not included.
if "dims" not in data:
data["dims"] = tuple(data["coords"])
# Add dims from coords if they aren't found.
dims = data.get("dims")
if not dims:
coords = data.get("coords", {})
dims = getattr(coords, "dims", None)
if dims is None and isinstance(coords, dict):
dims = coords.get('dims', ())
data['dims'] = dims
return data

def __getitem__(self, item):
Expand Down Expand Up @@ -185,14 +187,15 @@ def items(self):
def coords_from_dims(self) -> Mapping[str, BaseCoord]:
"""Return coordinates from dimensions assuming evenly sampled."""
out = {}
for dim in self.dim_tuple:
for dim in self.dims:
out[dim] = self.coords[dim].to_coord()
return out

@classmethod
def from_dict(
cls,
attr_map: Mapping | PatchAttrs,
data=None,
) -> Self:
"""
Get a new instance of the PatchAttrs.
Expand All @@ -205,26 +208,22 @@ def from_dict(
attr_map
Anything convertible to a dict that contains the attr info.
"""
data_info = {}
if data is not None:
data_info = {"dtype": data.dtype.str, "shape": data.shape}
if isinstance(attr_map, cls):
return attr_map
return attr_map.update(**data_info)
out = {} if attr_map is None else attr_map
out.update(**data_info)
return cls(**out)

@property
def dim_tuple(self):
"""Return a tuple of dimensions. The dims attr is a string."""
dim_str = self.dims
if not dim_str:
return tuple()
return tuple(self.dims.split(","))

def rename_dimension(self, **kwargs):
"""Rename one or more dimensions if in kwargs. Return new PatchAttrs."""
if not (dims := set(kwargs) & set(self.dim_tuple)):
if not (dims := set(kwargs) & set(self.dims)):
return self
new = self.model_dump(exclude_defaults=True)
coords = new.get("coords", {})
new_dims = list(self.dim_tuple)
new_dims = list(self.dims)
for old_name, new_name in {x: kwargs[x] for x in dims}.items():
new_dims[new_dims.index(old_name)] = new_name
coords[new_name] = coords.pop(old_name, None)
Expand All @@ -233,7 +232,7 @@ def rename_dimension(self, **kwargs):

def update(self, **kwargs) -> Self:
"""Update an attribute in the model, return new model."""
coord_info, attr_info = separate_coord_info(kwargs, dims=self.dim_tuple)
coord_info, attr_info = separate_coord_info(kwargs, dims=self.dims)
out = self.model_dump(exclude_unset=True)
out.update(attr_info)
out_coord_dict = out["coords"]
Expand Down
6 changes: 3 additions & 3 deletions dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _divide_kwargs(kwargs):
update_coords = update

def update_from_attrs(
self, attrs: Mapping | dc.PatchAttrs
self, attrs: Mapping | dc.PatchAttrs, data=None,
) -> tuple[Self, dc.PatchAttrs]:
"""
Update coordinates from attrs.
Expand Down Expand Up @@ -321,7 +321,7 @@ def update_from_attrs(
attr_info[f"{unused_dim}_{key}"] = val
attr_info["coords"] = coords.to_summary_dict()
attr_info["dims"] = coords.dims
attrs = dc.PatchAttrs.from_dict(attr_info)
attrs = dc.PatchAttrs.from_dict(attr_info, data=data)
return coords, attrs

def sort(
Expand Down Expand Up @@ -984,7 +984,7 @@ def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]:
dim_map = self.dim_map
out = {}
for name, coord in self.coord_map.items():
out[name] = coord.to_summary(dims=dim_map[name])
out[name] = coord.to_summary(dims=dim_map[name], name=name)
return out

def get_coord(self, coord_name: str) -> BaseCoord:
Expand Down
21 changes: 17 additions & 4 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
ArrayLike,
DascoreBaseModel,
UnitQuantity,
IntTupleStrSerialized,
StrTupleStrSerialized,
)
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float

Expand Down Expand Up @@ -98,8 +100,11 @@ class CoordSummary(DascoreBaseModel):
dtype: str
min: min_max_type
max: min_max_type
shape: IntTupleStrSerialized
step: step_type | None = None
units: UnitQuantity | None = None
dims: StrTupleStrSerialized = ()
name: str = ''

@model_serializer(when_used="json")
def ser_model(self) -> dict[str, str]:
Expand Down Expand Up @@ -581,7 +586,9 @@ def new(self, **kwargs):
# Need to ensure new data is used in constructor, not old shape
if "values" in kwargs:
info.pop("shape", None)

# Get rid of name and dims from summary.
kwargs.pop("name", None)
kwargs.pop("dims", None)
info.update(kwargs)
return get_coord(**info)

Expand Down Expand Up @@ -703,14 +710,17 @@ def get_attrs_dict(self, name):
out[f"{name}_units"] = self.units
return out

def to_summary(self, dims=()) -> CoordSummary:
def to_summary(self, dims=(), name='') -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=self.min(),
max=self.max(),
step=self.step,
dtype=self.dtype,
units=self.units,
shape = self.shape,
dims=dims,
name=name,
)

def update(self, **kwargs):
Expand Down Expand Up @@ -969,14 +979,17 @@ def change_length(self, length: int) -> Self:
assert self.ndim == 1, "change_length only works on 1D coords."
return get_coord(shape=(length,))

def to_summary(self, dims=()) -> CoordSummary:
def to_summary(self, dims=(), name='') -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=np.nan,
max=np.nan,
step=np.nan,
dtype=self.dtype,
units=None,
units=self.units,
dims=dims,
shape=(),
name=name,
)


Expand Down
11 changes: 5 additions & 6 deletions dascore/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class Patch:
if there is a conflict between information contained in both, the coords
will be recalculated.
"""

data: ArrayLike
coords: CoordManager
dims: tuple[str, ...]
Expand All @@ -81,7 +80,7 @@ def __init__(
if coords is None and attrs is not None:
attrs = dc.PatchAttrs.from_dict(attrs)
coords = attrs.coords_from_dims()
dims = dims if dims is not None else attrs.dim_tuple
dims = dims if dims is not None else attrs.dims
# Ensure required info is here
non_attrs = [x is None for x in [data, coords, dims]]
if any(non_attrs) and not all(non_attrs):
Expand All @@ -93,13 +92,13 @@ def __init__(
# the only case we allow attrs to include coords is if they are both
# dicts, in which case attrs might have unit info for coords.
if isinstance(attrs, Mapping) and attrs:
coords, attrs = coords.update_from_attrs(attrs)
coords, attrs = coords.update_from_attrs(attrs, data)
else:
# ensure attrs conforms to coords
attrs = dc.PatchAttrs.from_dict(attrs).update(coords=coords)
assert coords.dims == attrs.dim_tuple, "dim mismatch on coords and attrs"
self._coords = coords
attrs = dc.PatchAttrs.from_dict(attrs, data=data).update(coords=coords)
assert coords.dims == attrs.dims, "dim mismatch on coords and attrs"
self._attrs = attrs
self._coords = coords
self._data = array(self.coords.validate_data(data))

def __eq__(self, other):
Expand Down
52 changes: 25 additions & 27 deletions dascore/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,38 @@ def random_patch(
"""
# get input data
rand = np.random.RandomState(13)
array = rand.random(shape)
array: np.ndarray = rand.random(shape)
# create attrs
t1 = np.atleast_1d(np.datetime64(time_min))[0]
d1 = np.atleast_1d(distance_min)
attrs = dict(
distance_step=distance_step,
time_step=to_timedelta64(time_step),
category="DAS",
time_min=t1,
network=network,
station=station,
tag=tag,
time_units="s",
distance_units="m",
)
# need to pop out dim attrs if coordinates provided.
if time_array is not None:
attrs.pop("time_min")
# need to keep time_step if time_array is len 1 to get coord range
if len(time_array) > 1:
attrs.pop("time_step")
time_coord = dc.get_coord(
data=time_array,
step=time_step if time_array.size <= 1 else None,
units='s',
)
else:
time_array = dascore.core.get_coord(
data=t1 + np.arange(array.shape[1]) * attrs["time_step"],
step=attrs["time_step"],
units=attrs["time_units"],
time_coord = dascore.core.get_coord(
data=t1 + np.arange(array.shape[1]) * to_timedelta64(time_step),
step=to_timedelta64(time_step),
units="s",
)
if dist_array is not None:
attrs.pop("distance_step")
dist_coord = dc.get_coord(data=dist_array)
else:
dist_array = dascore.core.get_coord(
data=d1 + np.arange(array.shape[0]) * attrs["distance_step"],
step=attrs["distance_step"],
units=attrs["distance_units"],
dist_coord = dascore.core.get_coord(
data=d1 + np.arange(array.shape[0]) * distance_step,
step=distance_step,
units="m",
)
coords = dict(distance=dist_array, time=time_array)
coords = dict(distance=dist_coord, time=time_coord)
# assemble and output.
out = dict(data=array, coords=coords, attrs=attrs, dims=("distance", "time"))
patch = dc.Patch(**out)
Expand Down Expand Up @@ -722,11 +718,13 @@ def get_example_spool(example_name="random_das", **kwargs) -> dc.BaseSpool:
"""
if example_name not in EXAMPLE_SPOOLS:
# Allow the example spool to be a data registry file.
with suppress(ValueError):
return dc.spool(fetch(example_name))
msg = (
f"No example spool registered with name {example_name} "
f"Registered example spools are {list(EXAMPLE_SPOOLS)}"
)
raise UnknownExampleError(msg)
try:
path = fetch(example_name)
except ValueError:
msg = (
f"No example spool registered with name {example_name} "
f"Registered example spools are {list(EXAMPLE_SPOOLS)}"
)
raise UnknownExampleError(msg)
return dc.spool(path)
return EXAMPLE_SPOOLS[example_name](**kwargs)
Loading

0 comments on commit cdb98c2

Please sign in to comment.