Skip to content

Commit

Permalink
duck-work
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jan 9, 2025
1 parent cdb98c2 commit 81065d8
Show file tree
Hide file tree
Showing 17 changed files with 49 additions and 72 deletions.
13 changes: 5 additions & 8 deletions dascore/core/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pydantic import ConfigDict, Field, PlainValidator, model_validator
from typing_extensions import Self

import dascore as dc
from dascore.constants import (
VALID_DATA_CATEGORIES,
VALID_DATA_TYPES,
Expand All @@ -23,12 +22,11 @@
to_str,
)
from dascore.utils.models import (
StrTupleStrSerialized,
IntTupleStrSerialized,
DascoreBaseModel,
IntTupleStrSerialized,
StrTupleStrSerialized,
UnitQuantity,
frozen_dict_serializer,
frozen_dict_validator,
)

str_validator = PlainValidator(to_str)
Expand All @@ -46,7 +44,7 @@ def _to_coord_summary(coord_dict) -> FrozenDict[str, CoordSummary]:
for i, v in coord_dict.items():
if hasattr(v, "to_summary"):
v = v.to_summary()
elif isinstance(v,CoordSummary):
elif isinstance(v, CoordSummary):
pass
else:
v = CoordSummary(**v)
Expand Down Expand Up @@ -133,7 +131,6 @@ class PatchAttrs(DascoreBaseModel):
frozen_dict_serializer,
] = Field(default_factory=dict)


@model_validator(mode="before")
@classmethod
def _get_dims(cls, data: Any) -> Any:
Expand All @@ -144,8 +141,8 @@ def _get_dims(cls, data: Any) -> Any:
coords = data.get("coords", {})
dims = getattr(coords, "dims", None)
if dims is None and isinstance(coords, dict):
dims = coords.get('dims', ())
data['dims'] = dims
dims = coords.get("dims", ())
data["dims"] = dims
return data

def __getitem__(self, item):
Expand Down
4 changes: 3 additions & 1 deletion dascore/core/coordmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ def _divide_kwargs(kwargs):
update_coords = update

def update_from_attrs(
self, attrs: Mapping | dc.PatchAttrs, data=None,
self,
attrs: Mapping | dc.PatchAttrs,
data=None,
) -> tuple[Self, dc.PatchAttrs]:
"""
Update coordinates from attrs.
Expand Down
10 changes: 5 additions & 5 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@
from dascore.utils.models import (
ArrayLike,
DascoreBaseModel,
UnitQuantity,
IntTupleStrSerialized,
StrTupleStrSerialized,
UnitQuantity,
)
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float

Expand Down Expand Up @@ -104,7 +104,7 @@ class CoordSummary(DascoreBaseModel):
step: step_type | None = None
units: UnitQuantity | None = None
dims: StrTupleStrSerialized = ()
name: str = ''
name: str = ""

@model_serializer(when_used="json")
def ser_model(self) -> dict[str, str]:
Expand Down Expand Up @@ -710,15 +710,15 @@ def get_attrs_dict(self, name):
out[f"{name}_units"] = self.units
return out

def to_summary(self, dims=(), name='') -> CoordSummary:
def to_summary(self, dims=(), name="") -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=self.min(),
max=self.max(),
step=self.step,
dtype=self.dtype,
units=self.units,
shape = self.shape,
shape=self.shape,
dims=dims,
name=name,
)
Expand Down Expand Up @@ -979,7 +979,7 @@ def change_length(self, length: int) -> Self:
assert self.ndim == 1, "change_length only works on 1D coords."
return get_coord(shape=(length,))

def to_summary(self, dims=(), name='') -> CoordSummary:
def to_summary(self, dims=(), name="") -> CoordSummary:
"""Get the summary info about the coord."""
return CoordSummary(
min=np.nan,
Expand Down
1 change: 1 addition & 0 deletions dascore/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Patch:
if there is a conflict between information contained in both, the coords
will be recalculated.
"""

data: ArrayLike
coords: CoordManager
dims: tuple[str, ...]
Expand Down
2 changes: 1 addition & 1 deletion dascore/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def random_patch(
time_coord = dc.get_coord(
data=time_array,
step=time_step if time_array.size <= 1 else None,
units='s',
units="s",
)
else:
time_coord = dascore.core.get_coord(
Expand Down
2 changes: 1 addition & 1 deletion dascore/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import _iter_filesystem, cached_method, iterate, warn_or_raise
from dascore.utils.models import (
StrTupleStrSerialized,
DascoreBaseModel,
DateTime64,
StrTupleStrSerialized,
TimeDelta64,
)
from dascore.utils.pd import _model_list_to_df
Expand Down
2 changes: 1 addition & 1 deletion dascore/io/dasdae/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ def scan(self, resource: H5Reader, **kwargs):
A path to the file.
"""
file_format = self.name
version = resource.attrs['__DASDAE_version__']
version = resource.attrs["__DASDAE_version__"]
return _get_contents_from_patch_groups(resource, version, file_format)
41 changes: 16 additions & 25 deletions dascore/io/dasdae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@

from __future__ import annotations

import pickle
from contextlib import suppress

import numpy as np
from tables import NodeError

import dascore as dc
from dascore.core.attrs import PatchAttrs
from dascore.core.coordmanager import get_coord_manager
from dascore.core.coords import get_coord
from dascore.utils.misc import suppress_warnings
from dascore.utils.time import to_int
from dascore.utils.misc import unbyte
from dascore.utils.hdf5 import Empty
from dascore.utils.misc import suppress_warnings, unbyte
from dascore.utils.time import to_int

# --- Functions for writing DASDAE format

Expand All @@ -33,7 +29,7 @@ def _santize_pytables(some_dict):
continue
# Get rid of empty enum.
if isinstance(val, Empty):
val = ''
val = ""
out[i] = val
return out

Expand Down Expand Up @@ -90,15 +86,12 @@ def _save_array(data, name, group, h5):
return array_node




def _save_coords(patch, patch_group, h5):
"""Save coordinates."""
cm = patch.coords
for name, coord in cm.coord_map.items():
summary = (
coord.to_summary(name=name, dims=cm.dims[name])
.model_dump(exclude_defaults=True)
summary = coord.to_summary(name=name, dims=cm.dims[name]).model_dump(
exclude_defaults=True
)
breakpoint()
# First save coordinate arrays
Expand All @@ -108,9 +101,6 @@ def _save_coords(patch, patch_group, h5):
dataset.attrs.update(summary)





def _save_patch(patch, wave_group, h5, name):
"""Save the patch to disk."""
patch_group = _create_or_get_group(h5, wave_group, name)
Expand Down Expand Up @@ -229,22 +219,23 @@ def _get_coord_info(info, group):
attrs = _santize_pytables(dict(ds.attrs))
# Need to get old dimensions from c_dims in attrs.
if "dims" not in attrs:
attrs['dims'] = info.get(f"_cdims_{name}", name)
attrs["dims"] = info.get(f"_cdims_{name}", name)
# The summary info is not stored in attrs; need to read coord array.
c_info = {}
if 'min' not in attrs:
if "min" not in attrs:
c_summary = (
dc.core.get_coord(data=ds[:])
.to_summary()
.model_dump(exclude_unset=True, exclude_defaults=True)
)
c_info.update(c_summary)

c_info.update({
"dtype": ds.dtype.str,
'shape': ds.shape,
"name": name,
}
c_info.update(
{
"dtype": ds.dtype.str,
"shape": ds.shape,
"name": name,
}
)
coords[name] = c_info
return coords
Expand All @@ -261,10 +252,10 @@ def _get_patch_content_from_group(group):
value = np.atleast_1d(value)[0]
out[new_key] = value
# Add coord info.
out['coords'] = _get_coord_info(out, group)
out["coords"] = _get_coord_info(out, group)
# Add data info.
out['shape'] = group['data'].shape
out['dtype'] = group['data'].dtype.str
out["shape"] = group["data"].shape
out["dtype"] = group["data"].dtype.str
# rename dims
out["dims"] = out.pop("_dims")
return out
4 changes: 1 addition & 3 deletions dascore/io/h5simple/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def read(self, resource: H5Reader, snap=True, **kwargs) -> SpoolType:
patch = dc.Patch(coords=new_cm, data=new_data[:], attrs=attrs)
return dc.spool([patch])

def scan(
self, resource: H5Reader, snap=True, **kwargs
) -> list[dc.PatchAttrs]:
def scan(self, resource: H5Reader, snap=True, **kwargs) -> list[dc.PatchAttrs]:
"""Get the attributes of a h5simple file."""
attrs, cm, data = _get_attrs_coords_and_data(resource, snap, self)
attrs["coords"] = cm.to_summary_dict()
Expand Down
2 changes: 0 additions & 2 deletions dascore/utils/duck.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
"""
Utilities for working with DuckDB.
"""


2 changes: 0 additions & 2 deletions dascore/utils/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pandas as pd
import tables
from h5py import File as H5pyFile
from h5py import Empty
from packaging.version import parse as get_version
from pandas.io.common import stringify_path
from tables import ClosedNodeError
Expand All @@ -44,7 +43,6 @@
)
from dascore.utils.time import get_max_min_times, to_datetime64, to_int, to_timedelta64


HDF5ExtError = tables.HDF5ExtError
NoSuchNodeError = tables.NoSuchNodeError
NodeError = tables.NodeError
Expand Down
5 changes: 3 additions & 2 deletions dascore/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
def _str_to_int_tuple(value):
"""Convert a string of ints to a tuple."""
if isinstance(value, str):
return tuple(int(x) for x in value.split(','))
return tuple(int(x) for x in value.split(","))
return value


# A datetime64
DateTime64 = Annotated[
np.datetime64,
Expand Down Expand Up @@ -70,7 +71,7 @@ def _str_to_int_tuple(value):
IntTupleStrSerialized = Annotated[
tuple[int, ...],
PlainValidator(_str_to_int_tuple),
PlainSerializer(lambda x: ",".join((str(y) for y in x))),
PlainSerializer(lambda x: ",".join(str(y) for y in x)),
]

FrozenDictType = Annotated[
Expand Down
17 changes: 4 additions & 13 deletions dascore/utils/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

from __future__ import annotations


import fnmatch
import os
from collections import defaultdict
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Collection, Iterable, Mapping, Sequence
from functools import cache
from typing import Iterable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -559,7 +557,7 @@ def rolling_df(df, window, step=None, axis=0, center=False):


def get_attrs_coords_patch_table(
patch_or_attrs: Iterable[dc.PatchAttrs | dc.Patch | dc.BaseSpool],
patch_or_attrs: Iterable[dc.PatchAttrs | dc.Patch | dc.BaseSpool],
) -> tuple(pd.DataFrame, pd.DataFrame, pd.DataFrame):
"""
Get seperated attributes, coordinates, and patch tables from attrs.
Expand All @@ -569,11 +567,12 @@ def get_attrs_coords_patch_table(
patch_or_attrs
An iterable with patch content.
"""

def get_coord_dict(attr, num):
"""Get the coordinate information from the attrs."""
out = []
for coord in attr.values():
coord['id'] = num
coord["id"] = num
out.append(coord)
return out

Expand All @@ -588,11 +587,3 @@ def get_coord_dict(attr, num):
breakpoint()
coord_info.extend(get_coord_dict(attr.pop("coords", {}), num))
breakpoint()








11 changes: 6 additions & 5 deletions tests/test_core/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def attrs_coords_2() -> PatchAttrs:
"""Add non-standard coords to attrs."""
coords = {"depth": {"min": 10.0, "max": 12.0, "dtype": "<f8", "shape": (12,)}}
attrs = {
"coords": coords, "another_name": "FooBar", "dtype": "<i4", "shape": (12,22)
"coords": coords,
"another_name": "FooBar",
"dtype": "<i4",
"shape": (12, 22),
}
return PatchAttrs(**attrs)

Expand Down Expand Up @@ -131,9 +134,7 @@ def test_extra_attrs_not_in_dump_random_attrs(self, random_attrs):
def test_supports_extra_attrs(self, random_attrs):
"""The attr dict should allow extra attributes."""
model_dump = random_attrs.model_dump()
out = PatchAttrs(
bob="doesnt", bill_min=12, bob_max="2012-01-12", **model_dump
)
out = PatchAttrs(bob="doesnt", bill_min=12, bob_max="2012-01-12", **model_dump)
assert out.bob == "doesnt"
assert out.bill_min == 12

Expand Down Expand Up @@ -173,7 +174,7 @@ def test_coords_to_coord_summary(self):
"dims": ("time", "distance"),
}
attr = dc.PatchAttrs(**out)
assert attr.dims == out['dims']
assert attr.dims == out["dims"]
for name, coord in attr.coords.items():
assert isinstance(coord, CoordSummary)

Expand Down
1 change: 0 additions & 1 deletion tests/test_utils/test_hdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def test_metadata_created(self, tmp_path_factory):
assert meta is not None



class TestH5MatchesStructure:
"""Tests for the h5 matches structure function."""

Expand Down
Loading

0 comments on commit 81065d8

Please sign in to comment.