From cdb98c2d070aeb6d11ca538b2d7c9151f98ff5c7 Mon Sep 17 00:00:00 2001 From: derrick chambers Date: Mon, 6 Jan 2025 18:00:39 -0800 Subject: [PATCH] start refactor --- dascore/core/attrs.py | 95 +++++++++++----------- dascore/core/coordmanager.py | 6 +- dascore/core/coords.py | 21 ++++- dascore/core/patch.py | 11 ++- dascore/examples.py | 52 ++++++------ dascore/io/core.py | 4 +- dascore/io/dasdae/core.py | 30 ++----- dascore/io/dasdae/utils.py | 113 ++++++++++++++++++++++----- dascore/io/h5simple/core.py | 6 +- dascore/io/xml_binary/utils.py | 4 +- dascore/proc/basic.py | 2 +- dascore/utils/duck.py | 5 ++ dascore/utils/hdf5.py | 2 + dascore/utils/models.py | 23 +++++- dascore/utils/pd.py | 42 ++++++++++ pyproject.toml | 1 + tests/test_core/test_attrs.py | 67 +++++++--------- tests/test_core/test_coordmanager.py | 4 +- tests/test_core/test_patch.py | 2 +- tests/test_io/test_common_io.py | 10 +-- tests/test_utils/test_hdf_utils.py | 12 --- tests/test_utils/test_io_utils.py | 22 +----- tests/test_utils/test_pd.py | 9 +++ 23 files changed, 326 insertions(+), 217 deletions(-) create mode 100644 dascore/utils/duck.py diff --git a/dascore/core/attrs.py b/dascore/core/attrs.py index 32e1315d..c4e978a8 100644 --- a/dascore/core/attrs.py +++ b/dascore/core/attrs.py @@ -23,7 +23,8 @@ to_str, ) from dascore.utils.models import ( - CommaSeparatedStr, + StrTupleStrSerialized, + IntTupleStrSerialized, DascoreBaseModel, UnitQuantity, frozen_dict_serializer, @@ -35,31 +36,22 @@ _coord_required = {"min", "max"} -def _get_coords_dict(data_dict): - """ - Add coords dict to data dict, pop out any coordinate attributes. - - For example, if time_min, time_step are in data_dict, these will be - grouped into the coords sub dict under "time". - """ - - def _get_dims(data_dict): - """Try to get dim tuple.""" - dims = None - if "dims" in data_dict: - dims = data_dict["dims"] - elif hasattr(coord := data_dict.get("coords"), "dims"): - dims = coord.dims - if isinstance(dims, str): - dims = tuple(dims.split(",")) - return dims - - dims = _get_dims(data_dict) - coord_info, new_attrs = separate_coord_info( - data_dict, dims, required=("min", "max") - ) - new_attrs["coords"] = {i: dc.core.CoordSummary(**v) for i, v in coord_info.items()} - return new_attrs +def _to_coord_summary(coord_dict) -> FrozenDict[str, CoordSummary]: + """Convert a dict of potential coord info to a coord summary dict.""" + # We already have a summary dict, just return. + if hasattr(coord_dict, "to_summary_dict"): + return coord_dict.to_summary_dict() + # Otherwise, build up summary dict contents. 
+ out = {} + for i, v in coord_dict.items(): + if hasattr(v, "to_summary"): + v = v.to_summary() + elif isinstance(v,CoordSummary): + pass + else: + v = CoordSummary(**v) + out[i] = v + return FrozenDict(out) class PatchAttrs(DascoreBaseModel): @@ -92,6 +84,12 @@ class PatchAttrs(DascoreBaseModel): data_type: Annotated[Literal[VALID_DATA_TYPES], str_validator] = Field( description="Describes the quantity being measured.", default="" ) + dtype: str = Field( + description="The data type of the patch array (e.g, f32).", + ) + shape: IntTupleStrSerialized = Field( + description="The shape of the patch array.", + ) data_category: Annotated[Literal[VALID_DATA_CATEGORIES], str_validator] = Field( description="Describes the type of data.", default="", @@ -123,27 +121,31 @@ class PatchAttrs(DascoreBaseModel): description="A list of processing performed on the patch.", ) - dims: CommaSeparatedStr = Field( - default="", + dims: StrTupleStrSerialized = Field( + default=(), max_length=max_lens["dims"], description="A tuple of comma-separated dimensions names.", ) coords: Annotated[ FrozenDict[str, CoordSummary], - frozen_dict_validator, + PlainValidator(_to_coord_summary), frozen_dict_serializer, ] = Field(default_factory=dict) + @model_validator(mode="before") @classmethod - def parse_coord_attributes(cls, data: Any) -> Any: + def _get_dims(cls, data: Any) -> Any: """Parse the coordinate attributes into coord dict.""" - if isinstance(data, dict): - data = _get_coords_dict(data) - # add dims as coords if dims is not included. - if "dims" not in data: - data["dims"] = tuple(data["coords"]) + # Add dims from coords if they aren't found. + dims = data.get("dims") + if not dims: + coords = data.get("coords", {}) + dims = getattr(coords, "dims", None) + if dims is None and isinstance(coords, dict): + dims = coords.get('dims', ()) + data['dims'] = dims return data def __getitem__(self, item): @@ -185,7 +187,7 @@ def items(self): def coords_from_dims(self) -> Mapping[str, BaseCoord]: """Return coordinates from dimensions assuming evenly sampled.""" out = {} - for dim in self.dim_tuple: + for dim in self.dims: out[dim] = self.coords[dim].to_coord() return out @@ -193,6 +195,7 @@ def coords_from_dims(self) -> Mapping[str, BaseCoord]: def from_dict( cls, attr_map: Mapping | PatchAttrs, + data=None, ) -> Self: """ Get a new instance of the PatchAttrs. @@ -205,26 +208,22 @@ def from_dict( attr_map Anything convertible to a dict that contains the attr info. """ + data_info = {} + if data is not None: + data_info = {"dtype": data.dtype.str, "shape": data.shape} if isinstance(attr_map, cls): - return attr_map + return attr_map.update(**data_info) out = {} if attr_map is None else attr_map + out.update(**data_info) return cls(**out) - @property - def dim_tuple(self): - """Return a tuple of dimensions. The dims attr is a string.""" - dim_str = self.dims - if not dim_str: - return tuple() - return tuple(self.dims.split(",")) - def rename_dimension(self, **kwargs): """Rename one or more dimensions if in kwargs. 
Return new PatchAttrs.""" - if not (dims := set(kwargs) & set(self.dim_tuple)): + if not (dims := set(kwargs) & set(self.dims)): return self new = self.model_dump(exclude_defaults=True) coords = new.get("coords", {}) - new_dims = list(self.dim_tuple) + new_dims = list(self.dims) for old_name, new_name in {x: kwargs[x] for x in dims}.items(): new_dims[new_dims.index(old_name)] = new_name coords[new_name] = coords.pop(old_name, None) @@ -233,7 +232,7 @@ def rename_dimension(self, **kwargs): def update(self, **kwargs) -> Self: """Update an attribute in the model, return new model.""" - coord_info, attr_info = separate_coord_info(kwargs, dims=self.dim_tuple) + coord_info, attr_info = separate_coord_info(kwargs, dims=self.dims) out = self.model_dump(exclude_unset=True) out.update(attr_info) out_coord_dict = out["coords"] diff --git a/dascore/core/coordmanager.py b/dascore/core/coordmanager.py index 93ace694..a53c1b96 100644 --- a/dascore/core/coordmanager.py +++ b/dascore/core/coordmanager.py @@ -289,7 +289,7 @@ def _divide_kwargs(kwargs): update_coords = update def update_from_attrs( - self, attrs: Mapping | dc.PatchAttrs + self, attrs: Mapping | dc.PatchAttrs, data=None, ) -> tuple[Self, dc.PatchAttrs]: """ Update coordinates from attrs. @@ -321,7 +321,7 @@ def update_from_attrs( attr_info[f"{unused_dim}_{key}"] = val attr_info["coords"] = coords.to_summary_dict() attr_info["dims"] = coords.dims - attrs = dc.PatchAttrs.from_dict(attr_info) + attrs = dc.PatchAttrs.from_dict(attr_info, data=data) return coords, attrs def sort( @@ -984,7 +984,7 @@ def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]: dim_map = self.dim_map out = {} for name, coord in self.coord_map.items(): - out[name] = coord.to_summary(dims=dim_map[name]) + out[name] = coord.to_summary(dims=dim_map[name], name=name) return out def get_coord(self, coord_name: str) -> BaseCoord: diff --git a/dascore/core/coords.py b/dascore/core/coords.py index af61bac2..931bd26e 100644 --- a/dascore/core/coords.py +++ b/dascore/core/coords.py @@ -48,6 +48,8 @@ ArrayLike, DascoreBaseModel, UnitQuantity, + IntTupleStrSerialized, + StrTupleStrSerialized, ) from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float @@ -98,8 +100,11 @@ class CoordSummary(DascoreBaseModel): dtype: str min: min_max_type max: min_max_type + shape: IntTupleStrSerialized step: step_type | None = None units: UnitQuantity | None = None + dims: StrTupleStrSerialized = () + name: str = '' @model_serializer(when_used="json") def ser_model(self) -> dict[str, str]: @@ -581,7 +586,9 @@ def new(self, **kwargs): # Need to ensure new data is used in constructor, not old shape if "values" in kwargs: info.pop("shape", None) - + # Get rid of name and dims from summary. + kwargs.pop("name", None) + kwargs.pop("dims", None) info.update(kwargs) return get_coord(**info) @@ -703,7 +710,7 @@ def get_attrs_dict(self, name): out[f"{name}_units"] = self.units return out - def to_summary(self, dims=()) -> CoordSummary: + def to_summary(self, dims=(), name='') -> CoordSummary: """Get the summary info about the coord.""" return CoordSummary( min=self.min(), @@ -711,6 +718,9 @@ def to_summary(self, dims=()) -> CoordSummary: step=self.step, dtype=self.dtype, units=self.units, + shape = self.shape, + dims=dims, + name=name, ) def update(self, **kwargs): @@ -969,14 +979,17 @@ def change_length(self, length: int) -> Self: assert self.ndim == 1, "change_length only works on 1D coords." 
return get_coord(shape=(length,)) - def to_summary(self, dims=()) -> CoordSummary: + def to_summary(self, dims=(), name='') -> CoordSummary: """Get the summary info about the coord.""" return CoordSummary( min=np.nan, max=np.nan, step=np.nan, dtype=self.dtype, - units=None, + units=self.units, + dims=dims, + shape=(), + name=name, ) diff --git a/dascore/core/patch.py b/dascore/core/patch.py index 8573868a..9d07ccfa 100644 --- a/dascore/core/patch.py +++ b/dascore/core/patch.py @@ -60,7 +60,6 @@ class Patch: if there is a conflict between information contained in both, the coords will be recalculated. """ - data: ArrayLike coords: CoordManager dims: tuple[str, ...] @@ -81,7 +80,7 @@ def __init__( if coords is None and attrs is not None: attrs = dc.PatchAttrs.from_dict(attrs) coords = attrs.coords_from_dims() - dims = dims if dims is not None else attrs.dim_tuple + dims = dims if dims is not None else attrs.dims # Ensure required info is here non_attrs = [x is None for x in [data, coords, dims]] if any(non_attrs) and not all(non_attrs): @@ -93,13 +92,13 @@ def __init__( # the only case we allow attrs to include coords is if they are both # dicts, in which case attrs might have unit info for coords. if isinstance(attrs, Mapping) and attrs: - coords, attrs = coords.update_from_attrs(attrs) + coords, attrs = coords.update_from_attrs(attrs, data) else: # ensure attrs conforms to coords - attrs = dc.PatchAttrs.from_dict(attrs).update(coords=coords) - assert coords.dims == attrs.dim_tuple, "dim mismatch on coords and attrs" - self._coords = coords + attrs = dc.PatchAttrs.from_dict(attrs, data=data).update(coords=coords) + assert coords.dims == attrs.dims, "dim mismatch on coords and attrs" self._attrs = attrs + self._coords = coords self._data = array(self.coords.validate_data(data)) def __eq__(self, other): diff --git a/dascore/examples.py b/dascore/examples.py index 61690aaf..b55cf8f4 100644 --- a/dascore/examples.py +++ b/dascore/examples.py @@ -69,42 +69,38 @@ def random_patch( """ # get input data rand = np.random.RandomState(13) - array = rand.random(shape) + array: np.ndarray = rand.random(shape) # create attrs t1 = np.atleast_1d(np.datetime64(time_min))[0] d1 = np.atleast_1d(distance_min) attrs = dict( - distance_step=distance_step, - time_step=to_timedelta64(time_step), category="DAS", - time_min=t1, network=network, station=station, tag=tag, - time_units="s", - distance_units="m", ) # need to pop out dim attrs if coordinates provided. 
if time_array is not None: - attrs.pop("time_min") - # need to keep time_step if time_array is len 1 to get coord range - if len(time_array) > 1: - attrs.pop("time_step") + time_coord = dc.get_coord( + data=time_array, + step=time_step if time_array.size <= 1 else None, + units='s', + ) else: - time_array = dascore.core.get_coord( - data=t1 + np.arange(array.shape[1]) * attrs["time_step"], - step=attrs["time_step"], - units=attrs["time_units"], + time_coord = dascore.core.get_coord( + data=t1 + np.arange(array.shape[1]) * to_timedelta64(time_step), + step=to_timedelta64(time_step), + units="s", ) if dist_array is not None: - attrs.pop("distance_step") + dist_coord = dc.get_coord(data=dist_array) else: - dist_array = dascore.core.get_coord( - data=d1 + np.arange(array.shape[0]) * attrs["distance_step"], - step=attrs["distance_step"], - units=attrs["distance_units"], + dist_coord = dascore.core.get_coord( + data=d1 + np.arange(array.shape[0]) * distance_step, + step=distance_step, + units="m", ) - coords = dict(distance=dist_array, time=time_array) + coords = dict(distance=dist_coord, time=time_coord) # assemble and output. out = dict(data=array, coords=coords, attrs=attrs, dims=("distance", "time")) patch = dc.Patch(**out) @@ -722,11 +718,13 @@ def get_example_spool(example_name="random_das", **kwargs) -> dc.BaseSpool: """ if example_name not in EXAMPLE_SPOOLS: # Allow the example spool to be a data registry file. - with suppress(ValueError): - return dc.spool(fetch(example_name)) - msg = ( - f"No example spool registered with name {example_name} " - f"Registered example spools are {list(EXAMPLE_SPOOLS)}" - ) - raise UnknownExampleError(msg) + try: + path = fetch(example_name) + except ValueError: + msg = ( + f"No example spool registered with name {example_name} " + f"Registered example spools are {list(EXAMPLE_SPOOLS)}" + ) + raise UnknownExampleError(msg) + return dc.spool(path) return EXAMPLE_SPOOLS[example_name](**kwargs) diff --git a/dascore/io/core.py b/dascore/io/core.py index bb13944f..0a92422e 100644 --- a/dascore/io/core.py +++ b/dascore/io/core.py @@ -40,7 +40,7 @@ from dascore.utils.mapping import FrozenDict from dascore.utils.misc import _iter_filesystem, cached_method, iterate, warn_or_raise from dascore.utils.models import ( - CommaSeparatedStr, + StrTupleStrSerialized, DascoreBaseModel, DateTime64, TimeDelta64, @@ -68,7 +68,7 @@ class PatchFileSummary(DascoreBaseModel): tag: str = Field("", max_length=max_lens["tag"]) station: str = Field("", max_length=max_lens["station"]) network: str = Field("", max_length=max_lens["network"]) - dims: CommaSeparatedStr = Field("", max_length=max_lens["dims"]) + dims: StrTupleStrSerialized = Field("", max_length=max_lens["dims"]) time_min: DateTime64 = np.datetime64("NaT") time_max: DateTime64 = np.datetime64("NaT") time_step: TimeDelta64 = np.timedelta64("NaT") diff --git a/dascore/io/dasdae/core.py b/dascore/io/dasdae/core.py index 2f084a85..1d15f7c7 100644 --- a/dascore/io/dasdae/core.py +++ b/dascore/io/dasdae/core.py @@ -11,10 +11,9 @@ from dascore.io import FiberIO from dascore.utils.hdf5 import ( H5Reader, + H5Writer, HDFPatchIndexManager, NodeError, - PyTablesReader, - PyTablesWriter, ) from dascore.utils.misc import unbyte from dascore.utils.patch import get_patch_names @@ -52,7 +51,7 @@ class DASDAEV1(FiberIO): preferred_extensions = ("h5", "hdf5") version = "1" - def write(self, spool: SpoolType, resource: PyTablesWriter, index=False, **kwargs): + def write(self, spool: SpoolType, resource: H5Writer, index=False, **kwargs): """ 
        Write a collection of patches to a DASDAE file.

@@ -109,7 +109,7 @@ def get_format(self, resource: H5Reader, **kwargs) -> tuple[str, str] | bool:
             version = unbyte(attrs.get("__DASDAE_version__", ""))
             return file_format, version
 
-    def read(self, resource: PyTablesReader, **kwargs) -> SpoolType:
+    def read(self, resource: H5Reader, **kwargs) -> SpoolType:
         """Read a dascore file."""
         patches = []
         try:
@@ -120,7 +120,7 @@ def read(self, resource: PyTablesReader, **kwargs) -> SpoolType:
             patches.append(_read_patch(patch_group, **kwargs))
         return dc.spool(patches)
 
-    def scan(self, resource: PyTablesReader, **kwargs):
+    def scan(self, resource: H5Reader, **kwargs):
         """
         Get the patch info from the file.
 
@@ -133,20 +133,6 @@ def scan(self, resource: PyTablesReader, **kwargs):
         resource
             A path to the file.
         """
-        indexer = HDFPatchIndexManager(resource.filename)
-        if indexer.has_index:
-            # We need to change the path back to the file rather than internal
-            # HDF5 path so it works with FileSpool and such.
-            records = indexer.get_index().assign(path=str(resource)).to_dict("records")
-            return [dc.PatchAttrs(**x) for x in records]
-        else:
-            file_format = self.name
-            version = resource.root._v_attrs.__DASDAE_version__
-            return _get_contents_from_patch_groups(resource, version, file_format)
-
-    def index(self, path):
-        """Index the dasdae file."""
-        indexer = HDFPatchIndexManager(path)
-        if not indexer.has_index:
-            df = dc.scan_to_df(path)
-            indexer.write_update(df)
+        file_format = self.name
+        version = resource.attrs["__DASDAE_version__"]
+        return _get_contents_from_patch_groups(resource, version, file_format)
diff --git a/dascore/io/dasdae/utils.py b/dascore/io/dasdae/utils.py
index f4dcbd9a..75d64354 100644
--- a/dascore/io/dasdae/utils.py
+++ b/dascore/io/dasdae/utils.py
@@ -2,6 +2,9 @@
 
 from __future__ import annotations
 
+import pickle
+from contextlib import suppress
+
 import numpy as np
 from tables import NodeError
 
@@ -11,10 +14,30 @@
 from dascore.core.coords import get_coord
 from dascore.utils.misc import suppress_warnings
 from dascore.utils.time import to_int
+from dascore.utils.misc import unbyte
+from dascore.utils.hdf5 import Empty
 
 # --- Functions for writing DASDAE format
 
 
+def _sanitize_pytables(some_dict):
+    """Drop reserved pytables keys from an attrs dict and decode byte values."""
+    pytables_names = {"CLASS", "FLAVOR", "TITLE", "VERSION"}
+    out = {}
+    for i, v in some_dict.items():
+        if i in pytables_names:
+            continue
+        try:
+            val = unbyte(v)
+        except ValueError:
+            continue
+        # Replace empty h5py attribute values with an empty string.
+        if isinstance(val, Empty):
+            val = ""
+        out[i] = val
+    return out
+
+
 def _create_or_get_group(h5, group, name):
     """Create a new group or get existing."""
     try:
@@ -37,7 +61,7 @@ def _create_or_squash_array(h5, group, name, data):
 
 def _write_meta(hfile, file_version):
     """Write metadata to hdf5 file."""
-    attrs = hfile.root._v_attrs
+    attrs = hfile.attrs
     attrs["__format__"] = "DASDAE"
     attrs["__DASDAE_version__"] = file_version
     attrs["__dascore__version__"] = dc.__version__
@@ -49,8 +73,8 @@ def _save_attrs_and_dims(patch, patch_group):
     # TODO will need to test if objects are serializable
     attr_dict = patch.attrs.model_dump(exclude_unset=True)
     for i, v in attr_dict.items():
-        patch_group._v_attrs[f"_attrs_{i}"] = v
-    patch_group._v_attrs["_dims"] = ",".join(patch.dims)
+        patch_group.attrs[f"_attrs_{i}"] = v
+    patch_group.attrs["_dims"] = ",".join(patch.dims)
 
 
 def _save_array(data, name, group, h5):
@@ -61,22 +85,30 @@ def _save_array(data, name, group, h5):
     if is_dt or is_td:
         data = to_int(data)
     array_node = _create_or_squash_array(h5, group, name, data)
-    array_node._v_attrs["is_datetime64"] = is_dt
-    array_node._v_attrs["is_timedelta64"] = is_td
+    array_node.attrs["is_datetime64"] = is_dt
+    array_node.attrs["is_timedelta64"] = is_td
+    return array_node
 
 
 def _save_coords(patch, patch_group, h5):
     """Save coordinates."""
     cm = patch.coords
     for name, coord in cm.coord_map.items():
-        dims = cm.dim_map[name]
+        summary = (
+            coord.to_summary(name=name, dims=cm.dim_map[name])
+            .model_dump(exclude_defaults=True)
+        )
         # First save coordinate arrays
         data = coord.values
         save_name = f"_coord_{name}"
-        _save_array(data, save_name, patch_group, h5)
-        # then save dimensions of coordinates
-        save_name = f"_cdims_{name}"
-        patch_group._v_attrs[save_name] = ",".join(dims)
+        dataset = _save_array(data, save_name, patch_group, h5)
+        # Then attach the coord summary to the saved dataset.
+        dataset.attrs.update(summary)
 
 
 def _save_patch(patch, wave_group, h5, name):
@@ -95,7 +127,7 @@ def _save_patch(patch, wave_group, h5, name):
 def _get_attrs(patch_group):
     """Get the saved attributes form the group attrs."""
     out = {}
-    attrs = [x for x in patch_group._v_attrs._f_list() if x.startswith("_attrs_")]
+    attrs = [x for x in patch_group.attrs if x.startswith("_attrs_")]
     for attr_name in attrs:
         key = attr_name.replace("_attrs_", "")
         val = patch_group._v_attrs[attr_name]
@@ -110,9 +142,9 @@ def _get_attrs(patch_group):
 def _read_array(table_array):
     """Read an array into numpy."""
     data = table_array[:]
-    if table_array._v_attrs["is_datetime64"]:
+    if table_array.attrs["is_datetime64"]:
         data = data.view("datetime64[ns]")
-    if table_array._v_attrs["is_timedelta64"]:
+    if table_array.attrs["is_timedelta64"]:
         data = data.view("timedelta64[ns]")
     return data
 
@@ -131,10 +163,10 @@ def _get_coords(patch_group, dims, attrs2):
         )
         coord_dict[name] = coord
     # associates coordinates with dimensions
-    c_dims = [x for x in patch_group._v_attrs._f_list() if x.startswith("_cdims")]
+    c_dims = [x for x in patch_group.attrs if x.startswith("_cdims")]
     for coord_name in c_dims:
         name = coord_name.replace("_cdims_", "")
-        value = patch_group._v_attrs[coord_name]
+        value = patch_group.attrs[coord_name]
         assert name in coord_dict, "Should already have loaded coordinate array"
         coord_dim_dict[name] = (tuple(value.split(".")), coord_dict[name])
     # add dimensions to coordinates that have them.
@@ -144,7 +176,7 @@ def _get_coords(patch_group, dims, attrs2):
 
 def _get_dims(patch_group):
     """Get the dims tuple from the patch group."""
-    dims = patch_group._v_attrs["_dims"]
+    dims = patch_group.attrs["_dims"]
     if not dims:
         out = ()
     else:
@@ -170,7 +202,7 @@ def _read_patch(patch_group, **kwargs):
 def _get_contents_from_patch_groups(h5, file_version, file_format="DASDAE"):
     """Get the contents from each patch group."""
     out = []
-    for group in h5.iter_nodes("/waveforms"):
+    for name, group in h5["/waveforms"].items():
         contents = _get_patch_content_from_group(group)
         # populate file info
         contents["file_version"] = file_version
@@ -179,21 +211,60 @@ def _get_contents_from_patch_groups(h5, file_version, file_format="DASDAE"):
     # suppressing warnings because old dasdae files will issue warning
     # due to d_dim rather than dim_step. TODO fix test files in the future
     with suppress_warnings(DeprecationWarning):
         out.append(dc.PatchAttrs(**contents))
     return out
 
 
+def _get_coord_info(info, group):
+    """Get the summary dict for each saved coordinate dataset."""
+    coords = {}
+    coord_ds_names = tuple(x for x in group if x.startswith("_coord_"))
+    for ds_name in coord_ds_names:
+        name = ds_name.replace("_coord_", "")
+        ds = group[ds_name]
+        attrs = _sanitize_pytables(dict(ds.attrs))
+        # Need to get old dimensions from c_dims in attrs.
+        if "dims" not in attrs:
+            attrs["dims"] = info.get(f"_cdims_{name}", name)
+        # The summary info is not stored in attrs; need to read coord array.
+        c_info = {}
+        if "min" not in attrs:
+            c_summary = (
+                dc.core.get_coord(data=ds[:])
+                .to_summary()
+                .model_dump(exclude_unset=True, exclude_defaults=True)
+            )
+            c_info.update(c_summary)
+        c_info.update(
+            {
+                "dtype": ds.dtype.str,
+                "shape": ds.shape,
+                "name": name,
+            }
+        )
+        coords[name] = c_info
+    return coords
+
+
 def _get_patch_content_from_group(group):
     """Get patch content from a single node."""
-    attrs = group._v_attrs
     out = {}
-    for key in attrs._f_list():
-        value = getattr(attrs, key)
+    attrs = _sanitize_pytables(dict(group.attrs))
+    for key, value in attrs.items():
         new_key = key.replace("_attrs_", "")
         # need to unpack 0 dim arrays.
         if isinstance(value, np.ndarray) and not value.shape:
             value = np.atleast_1d(value)[0]
         out[new_key] = value
+    # Add coord info.
+    out["coords"] = _get_coord_info(out, group)
+    # Add data info.
+    out["shape"] = group["data"].shape
+    out["dtype"] = group["data"].dtype.str
     # rename dims
     out["dims"] = out.pop("_dims")
     return out
diff --git a/dascore/io/h5simple/core.py b/dascore/io/h5simple/core.py
index e772f72b..669b5c98 100644
--- a/dascore/io/h5simple/core.py
+++ b/dascore/io/h5simple/core.py
@@ -5,7 +5,7 @@
 import dascore as dc
 from dascore.constants import SpoolType
 from dascore.io import FiberIO
-from dascore.utils.hdf5 import H5Reader, PyTablesReader
+from dascore.utils.hdf5 import H5Reader
 
 from .utils import _get_attrs_coords_and_data, _is_h5simple, _maybe_trim_data
 
@@ -23,7 +23,7 @@ def get_format(self, resource: H5Reader, **kwargs) -> tuple[str, str] | bool:
             return self.name, self.version
         return False
 
-    def read(self, resource: PyTablesReader, snap=True, **kwargs) -> SpoolType:
+    def read(self, resource: H5Reader, snap=True, **kwargs) -> SpoolType:
         """
         Read a simple h5 file.
@@ -42,7 +42,7 @@ def read(self, resource: PyTablesReader, snap=True, **kwargs) -> SpoolType: return dc.spool([patch]) def scan( - self, resource: PyTablesReader, snap=True, **kwargs + self, resource: H5Reader, snap=True, **kwargs ) -> list[dc.PatchAttrs]: """Get the attributes of a h5simple file.""" attrs, cm, data = _get_attrs_coords_and_data(resource, snap, self) diff --git a/dascore/io/xml_binary/utils.py b/dascore/io/xml_binary/utils.py index a9dc9b5a..2415f031 100644 --- a/dascore/io/xml_binary/utils.py +++ b/dascore/io/xml_binary/utils.py @@ -159,7 +159,7 @@ def _read_single_file(path, metadata, time, distance): distance_coord = attr.coords["distance"].to_coord() cm = get_coord_manager( {"time": time_coord, "distance": distance_coord}, - dims=attr.dim_tuple, + dims=attr.dims, ) memmap = np.memmap(path, dtype=metadata.data_type) size = np.prod(cm.shape) @@ -167,7 +167,7 @@ def _read_single_file(path, metadata, time, distance): data = memmap.reshape(cm.shape) patch = dc.Patch( data=data, - dims=attr.dim_tuple, + dims=attr.dims, coords=cm, attrs=attr.update(coord=None), ) diff --git a/dascore/proc/basic.py b/dascore/proc/basic.py index c68ceba8..6b6164dc 100644 --- a/dascore/proc/basic.py +++ b/dascore/proc/basic.py @@ -185,7 +185,7 @@ def update( dims = coords.dims if isinstance(coords, CoordManager) else self.dims coords = get_coord_manager(coords, dims) if attrs: - coords, attrs = coords.update_from_attrs(attrs) + coords, attrs = coords.update_from_attrs(attrs, data) else: _attrs = dc.PatchAttrs.from_dict(attrs or self.attrs) attrs = _attrs.update(coords=coords, dims=coords.dims) diff --git a/dascore/utils/duck.py b/dascore/utils/duck.py new file mode 100644 index 00000000..a1084cbf --- /dev/null +++ b/dascore/utils/duck.py @@ -0,0 +1,5 @@ +""" +Utilities for working with DuckDB. +""" + + diff --git a/dascore/utils/hdf5.py b/dascore/utils/hdf5.py index 1f92dc63..eb56535d 100644 --- a/dascore/utils/hdf5.py +++ b/dascore/utils/hdf5.py @@ -19,6 +19,7 @@ import pandas as pd import tables from h5py import File as H5pyFile +from h5py import Empty from packaging.version import parse as get_version from pandas.io.common import stringify_path from tables import ClosedNodeError @@ -43,6 +44,7 @@ ) from dascore.utils.time import get_max_min_times, to_datetime64, to_int, to_timedelta64 + HDF5ExtError = tables.HDF5ExtError NoSuchNodeError = tables.NoSuchNodeError NodeError = tables.NodeError diff --git a/dascore/utils/models.py b/dascore/utils/models.py index 5192a598..779938cf 100644 --- a/dascore/utils/models.py +++ b/dascore/utils/models.py @@ -18,11 +18,18 @@ from dascore.utils.time import to_datetime64, to_timedelta64 # --- A list of custom types with appropriate serialization/deserialization -# these can just be use with pydantic type-hints. +# these can just be used with pydantic type-hints. frozen_dict_validator = PlainValidator(lambda x: FrozenDict(x)) frozen_dict_serializer = PlainSerializer(lambda x: dict(x)) + +def _str_to_int_tuple(value): + """Convert a string of ints to a tuple.""" + if isinstance(value, str): + return tuple(int(x) for x in value.split(',')) + return value + # A datetime64 DateTime64 = Annotated[ np.datetime64, @@ -52,8 +59,18 @@ PlainSerializer(get_quantity_str), ] -CommaSeparatedStr = Annotated[ - str, PlainValidator(lambda x: x if isinstance(x, str) else ",".join(x)) +# A str that should be parsed as a tuple but serialized as a string. 
+StrTupleStrSerialized = Annotated[ + tuple[str, ...], + PlainValidator(lambda x: tuple(x.split(",")) if isinstance(x, str) else tuple(x)), + PlainSerializer(lambda x: ",".join(x)), +] + +# A tuple of ints that should serialize to CSVs. +IntTupleStrSerialized = Annotated[ + tuple[int, ...], + PlainValidator(_str_to_int_tuple), + PlainSerializer(lambda x: ",".join((str(y) for y in x))), ] FrozenDictType = Annotated[ diff --git a/dascore/utils/pd.py b/dascore/utils/pd.py index 1bef0757..9dd6bcbe 100644 --- a/dascore/utils/pd.py +++ b/dascore/utils/pd.py @@ -2,11 +2,13 @@ from __future__ import annotations + import fnmatch import os from collections import defaultdict from collections.abc import Collection, Mapping, Sequence from functools import cache +from typing import Iterable import numpy as np import pandas as pd @@ -554,3 +556,43 @@ def rolling_df(df, window, step=None, axis=0, center=False): """ df = df if not axis else df.T # silly deprecated axis argument. return df.rolling(window=window, step=step, center=center) + + +def get_attrs_coords_patch_table( + patch_or_attrs: Iterable[dc.PatchAttrs | dc.Patch | dc.BaseSpool], +) -> tuple(pd.DataFrame, pd.DataFrame, pd.DataFrame): + """ + Get seperated attributes, coordinates, and patch tables from attrs. + + Parameters + ---------- + patch_or_attrs + An iterable with patch content. + """ + def get_coord_dict(attr, num): + """Get the coordinate information from the attrs.""" + out = [] + for coord in attr.values(): + coord['id'] = num + out.append(coord) + return out + + patch_info, coord_info, attr_info = [], [], [] + for num, attr in enumerate(patch_or_attrs): + if isinstance(attr, dc.Patch): + attr = attr.attrs + # Make sure we are working with a dict. + attr = attr.model_dump() if hasattr(attr, "model_dump") else attr + coords = attr.pop("coords", {}) + coord_names = tuple(coords.values()) + breakpoint() + coord_info.extend(get_coord_dict(attr.pop("coords", {}), num)) + breakpoint() + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index 5ed0bb88..e73a3684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "tables>=3.7", "typing_extensions", "pint", + "duckdb>=1.1" ] [project.optional-dependencies] diff --git a/tests/test_core/test_attrs.py b/tests/test_core/test_attrs.py index e3b23040..2fe62f94 100644 --- a/tests/test_core/test_attrs.py +++ b/tests/test_core/test_attrs.py @@ -16,6 +16,8 @@ MORE_COORDS_ATTRS = [] +BASE_ATTRS = {"shape": (12, 34), "dtype": " PatchAttrs: @@ -29,22 +31,14 @@ def random_attrs(random_patch) -> PatchAttrs: return random_patch.attrs -@pytest.fixture(scope="session") -@register_func(MORE_COORDS_ATTRS) -def attrs_coords_1() -> PatchAttrs: - """Add non-standard coords to attrs.""" - attrs = {"depth_min": 10.0, "depth_max": 12.0, "another_name": "FooBar"} - out = PatchAttrs(**attrs) - assert "depth" in out.coords - return out - - @pytest.fixture(scope="session") @register_func(MORE_COORDS_ATTRS) def attrs_coords_2() -> PatchAttrs: """Add non-standard coords to attrs.""" - coords = {"depth": {"min": 10.0, "max": 12.0}} - attrs = {"coords": coords, "another_name": "FooBar"} + coords = {"depth": {"min": 10.0, "max": 12.0, "dtype": " 1: pytest.skip("Haven't implemented test for multipatch files.") attrs_init = attrs_from_file[0] - for dim in attrs_init.dim_tuple: + for dim in attrs_init.dims: start = getattr(attrs_init, f"{dim}_min") stop = getattr(attrs_init, f"{dim}_max") duration = stop - start @@ -427,7 +427,7 @@ def test_scan_attrs_match_patch_attrs(self, 
data_file_path): for pat_attrs1, scan_attrs2 in zip(patch_attrs_list, scan_attrs_list): assert pat_attrs1.dims == scan_attrs2.dims # first compare dimensions are related attributes - for dim in pat_attrs1.dim_tuple: + for dim in pat_attrs1.dims: assert getattr(pat_attrs1, f"{dim}_min") == getattr( pat_attrs1, f"{dim}_min" ) diff --git a/tests/test_utils/test_hdf_utils.py b/tests/test_utils/test_hdf_utils.py index 4540249c..b6a49c4f 100644 --- a/tests/test_utils/test_hdf_utils.py +++ b/tests/test_utils/test_hdf_utils.py @@ -15,7 +15,6 @@ from dascore.utils.downloader import fetch from dascore.utils.hdf5 import ( HDFPatchIndexManager, - PyTablesWriter, extract_h5_attrs, h5_matches_structure, open_hdf5_file, @@ -164,17 +163,6 @@ def test_metadata_created(self, tmp_path_factory): assert meta is not None -class TestHDFReaders: - """Tests for HDF5 readers.""" - - def test_get_handle(self, tmp_path_factory): - """Ensure we can get a handle with the class.""" - path = tmp_path_factory.mktemp("hdf_handle_test") / "test_file.h5" - handle = PyTablesWriter.get_handle(path) - assert isinstance(handle, tables.File) - handle_2 = PyTablesWriter.get_handle(handle) - assert isinstance(handle_2, tables.File) - class TestH5MatchesStructure: """Tests for the h5 matches structure function.""" diff --git a/tests/test_utils/test_io_utils.py b/tests/test_utils/test_io_utils.py index 696c1410..8c28fa07 100644 --- a/tests/test_utils/test_io_utils.py +++ b/tests/test_utils/test_io_utils.py @@ -6,17 +6,16 @@ from pathlib import Path import pytest -from tables import File import dascore as dc from dascore.exceptions import PatchConversionError -from dascore.utils.hdf5 import HDF5Reader, HDF5Writer from dascore.utils.io import ( BinaryReader, BinaryWriter, IOResourceManager, get_handle_from_resource, ) +from dascore.utils.hdf5 import H5Reader, H5Writer class _BadType: @@ -53,18 +52,6 @@ def test_path_to_buffered_writer(self, tmp_path): assert isinstance(handle, BufferedWriter) handle.close() - def test_path_to_hdf5_reader(self, generic_hdf5): - """Ensure we get a reader from tmp path reader.""" - handle = get_handle_from_resource(generic_hdf5, HDF5Reader) - assert isinstance(handle, File) - handle.close() - - def test_path_to_hdf5_writer(self, tmp_path): - """Ensure we get a reader from tmp path reader.""" - path = tmp_path / "test_hdf_writer.h5" - handle = get_handle_from_resource(path, HDF5Writer) - assert isinstance(handle, File) - def test_get_path(self, tmp_path): """Ensure we can get a path.""" path = get_handle_from_resource(tmp_path, Path) @@ -90,9 +77,9 @@ def test_not_implemented(self): with pytest.raises(NotImplementedError): get_handle_from_resource(bad_instance, BinaryWriter) with pytest.raises(NotImplementedError): - get_handle_from_resource(bad_instance, HDF5Writer) + get_handle_from_resource(bad_instance, H5Reader) with pytest.raises(NotImplementedError): - get_handle_from_resource(bad_instance, HDF5Reader) + get_handle_from_resource(bad_instance, H5Writer) class TestIOResourceManager: @@ -107,9 +94,8 @@ def test_basic_context_manager(self, tmp_path): assert isinstance(my_str, str) path = man.get_resource(Path) assert isinstance(path, Path) - hf = man.get_resource(HDF5Writer) + hf = man.get_resource(H5Writer) fi = man.get_resource(BinaryWriter) - # Why didn't pytables implement the stream like pythons? assert hf.isopen assert not fi.closed # after the context manager exists everything should be closed. 
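The new get_attrs_coords_patch_table helper in dascore/utils/pd.py (exercised by the test change below) is still a stub in this patch, but the table-splitting idea it works toward can be shown with a small standalone sketch. Everything here is illustrative only: the example_attrs records and column names are hypothetical and not part of the patch.

import pandas as pd

# Hypothetical flattened attrs records, roughly the shape the refactored
# PatchAttrs.model_dump() produces (each record carries a "coords" sub-dict).
example_attrs = [
    {
        "station": "s1",
        "dims": ("distance", "time"),
        "coords": {
            "time": {"name": "time", "min": 0.0, "max": 8.0, "shape": (2000,)},
            "distance": {"name": "distance", "min": 0.0, "max": 300.0, "shape": (300,)},
        },
    },
]

patch_rows, coord_rows, attr_rows = [], [], []
for patch_id, record in enumerate(example_attrs):
    record = dict(record)  # don't mutate the caller's dict
    coords = record.pop("coords", {})
    # One row per coordinate, keyed back to its patch.
    for coord in coords.values():
        coord_rows.append({**coord, "patch_id": patch_id})
    # Remaining scalar attrs form the attribute table.
    attr_rows.append({**record, "patch_id": patch_id})
    patch_rows.append({"patch_id": patch_id, "dims": record.get("dims")})

patch_df = pd.DataFrame(patch_rows)
coord_df = pd.DataFrame(coord_rows)
attr_df = pd.DataFrame(attr_rows)

Keeping the coordinate summaries in their own table keyed by a patch id is what makes it practical to query them with a relational engine later, which is presumably the motivation for the new (still empty) dascore/utils/duck.py module and the duckdb dependency added to pyproject.toml.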
diff --git a/tests/test_utils/test_pd.py b/tests/test_utils/test_pd.py
index 90eecaf5..c644648e 100644
--- a/tests/test_utils/test_pd.py
+++ b/tests/test_utils/test_pd.py
@@ -16,6 +16,7 @@
     filter_df,
     get_interval_columns,
     patch_to_dataframe,
+    get_attrs_coords_patch_table,
 )
 from dascore.utils.time import to_datetime64, to_timedelta64
 
@@ -373,3 +374,13 @@ def test_raises(self, example_df_2):
         msg = "Cannot chunk spool or dataframe"
         with pytest.raises(ParameterError, match=msg):
             get_interval_columns(example_df_2, "money")
+
+
+class TestAttrsCoordPatchTables:
+    """Test suite for getting the attribute, coordinate, and patch tables."""
+
+    def test_from_spool(self, random_spool):
+        """Ensure a spool is split into three separate dataframes."""
+        patch_df, coord_df, attr_df = get_attrs_coords_patch_table(random_spool)
+        for df in (patch_df, coord_df, attr_df):
+            assert isinstance(df, pd.DataFrame)
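As a reference for how the new serialized-tuple fields behave, here is a minimal sketch using pydantic v2 directly. The Summary model is hypothetical; only the two Annotated type definitions mirror what this patch adds to dascore/utils/models.py.

from typing import Annotated

from pydantic import BaseModel, PlainSerializer, PlainValidator

# Parsed as a tuple of strings, serialized back to a comma-separated string.
StrTupleStrSerialized = Annotated[
    tuple[str, ...],
    PlainValidator(lambda x: tuple(x.split(",")) if isinstance(x, str) else tuple(x)),
    PlainSerializer(lambda x: ",".join(x)),
]

# Parsed as a tuple of ints, serialized back to a comma-separated string.
IntTupleStrSerialized = Annotated[
    tuple[int, ...],
    PlainValidator(
        lambda x: tuple(int(y) for y in x.split(",")) if isinstance(x, str) else tuple(x)
    ),
    PlainSerializer(lambda x: ",".join(str(y) for y in x)),
]


class Summary(BaseModel):
    """Hypothetical model, only used to exercise the annotated types."""

    dims: StrTupleStrSerialized = ()
    shape: IntTupleStrSerialized = ()


# Strings written by older files are parsed back into tuples...
summary = Summary(dims="distance,time", shape="30,100")
assert summary.dims == ("distance", "time")
assert summary.shape == (30, 100)
# ...and tuples serialize back to the compact comma-separated form.
assert summary.model_dump(mode="json") == {"dims": "distance,time", "shape": "30,100"}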