diff --git a/dascore/__init__.py b/dascore/__init__.py index f8a95846..1675b73e 100644 --- a/dascore/__init__.py +++ b/dascore/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations from rich import print # noqa -from dascore.core.patch import Patch +from dascore.core.patch import Patch, PatchSummary from dascore.core.attrs import PatchAttrs from dascore.core.spool import BaseSpool, spool from dascore.core.coordmanager import get_coord_manager, CoordManager diff --git a/dascore/core/attrs.py b/dascore/core/attrs.py index 32e1315d..42911313 100644 --- a/dascore/core/attrs.py +++ b/dascore/core/attrs.py @@ -2,32 +2,26 @@ from __future__ import annotations -import warnings from collections.abc import Mapping -from typing import Annotated, Any, Literal +from typing import Annotated, Literal -import numpy as np -from pydantic import ConfigDict, Field, PlainValidator, model_validator +from pydantic import ConfigDict, Field, PlainValidator from typing_extensions import Self -import dascore as dc from dascore.constants import ( VALID_DATA_CATEGORIES, VALID_DATA_TYPES, max_lens, ) -from dascore.core.coords import BaseCoord, CoordSummary +from dascore.core.coords import CoordSummary from dascore.utils.attrs import separate_coord_info -from dascore.utils.mapping import FrozenDict from dascore.utils.misc import ( to_str, ) from dascore.utils.models import ( - CommaSeparatedStr, DascoreBaseModel, + StrTupleStrSerialized, UnitQuantity, - frozen_dict_serializer, - frozen_dict_validator, ) str_validator = PlainValidator(to_str) @@ -35,33 +29,6 @@ _coord_required = {"min", "max"} -def _get_coords_dict(data_dict): - """ - Add coords dict to data dict, pop out any coordinate attributes. - - For example, if time_min, time_step are in data_dict, these will be - grouped into the coords sub dict under "time". - """ - - def _get_dims(data_dict): - """Try to get dim tuple.""" - dims = None - if "dims" in data_dict: - dims = data_dict["dims"] - elif hasattr(coord := data_dict.get("coords"), "dims"): - dims = coord.dims - if isinstance(dims, str): - dims = tuple(dims.split(",")) - return dims - - dims = _get_dims(data_dict) - coord_info, new_attrs = separate_coord_info( - data_dict, dims, required=("min", "max") - ) - new_attrs["coords"] = {i: dc.core.CoordSummary(**v) for i, v in coord_info.items()} - return new_attrs - - class PatchAttrs(DascoreBaseModel): """ The expected attributes for a Patch. @@ -118,34 +85,11 @@ class PatchAttrs(DascoreBaseModel): network: str = Field( default="", max_length=max_lens["network"], description="A network code." ) - history: str | tuple[str, ...] = Field( + history: StrTupleStrSerialized = Field( default_factory=tuple, description="A list of processing performed on the patch.", ) - dims: CommaSeparatedStr = Field( - default="", - max_length=max_lens["dims"], - description="A tuple of comma-separated dimensions names.", - ) - - coords: Annotated[ - FrozenDict[str, CoordSummary], - frozen_dict_validator, - frozen_dict_serializer, - ] = Field(default_factory=dict) - - @model_validator(mode="before") - @classmethod - def parse_coord_attributes(cls, data: Any) -> Any: - """Parse the coordinate attributes into coord dict.""" - if isinstance(data, dict): - data = _get_coords_dict(data) - # add dims as coords if dims is not included. 
- if "dims" not in data: - data["dims"] = tuple(data["coords"]) - return data - def __getitem__(self, item): return getattr(self, item) @@ -155,22 +99,6 @@ def __setitem__(self, key, value): def __len__(self): return len(self.model_dump()) - def __getattr__(self, item): - """Enables dynamic attributes such as time_min, time_max, etc.""" - split = item.split("_") - # this only works on names like time_max, distance_step, etc. - if not len(split) == 2: - return super().__getattr__(item) - first, second = split - if first == "d": - first, second = second, "step" - msg = f"{item} is depreciated, use {first}_{second} instead." - warnings.warn(msg, DeprecationWarning, stacklevel=2) - if first not in self.coords: - return super().__getattr__(item) - coord_sum = self.coords[first] - return getattr(coord_sum, second) - def get(self, item, default=None): """dict-like get method.""" try: @@ -182,69 +110,11 @@ def items(self): """Yield (attribute, values) just like dict.items().""" yield from self.model_dump().items() - def coords_from_dims(self) -> Mapping[str, BaseCoord]: - """Return coordinates from dimensions assuming evenly sampled.""" - out = {} - for dim in self.dim_tuple: - out[dim] = self.coords[dim].to_coord() - return out - - @classmethod - def from_dict( - cls, - attr_map: Mapping | PatchAttrs, - ) -> Self: - """ - Get a new instance of the PatchAttrs. - - Optionally, give preference to data contained in a - [`CoordManager`](`dascore.core.coordmanager.CoordManager`). - - Parameters - ---------- - attr_map - Anything convertible to a dict that contains the attr info. - """ - if isinstance(attr_map, cls): - return attr_map - out = {} if attr_map is None else attr_map - return cls(**out) - - @property - def dim_tuple(self): - """Return a tuple of dimensions. The dims attr is a string.""" - dim_str = self.dims - if not dim_str: - return tuple() - return tuple(self.dims.split(",")) - - def rename_dimension(self, **kwargs): - """Rename one or more dimensions if in kwargs. Return new PatchAttrs.""" - if not (dims := set(kwargs) & set(self.dim_tuple)): - return self - new = self.model_dump(exclude_defaults=True) - coords = new.get("coords", {}) - new_dims = list(self.dim_tuple) - for old_name, new_name in {x: kwargs[x] for x in dims}.items(): - new_dims[new_dims.index(old_name)] = new_name - coords[new_name] = coords.pop(old_name, None) - new["dims"] = tuple(new_dims) - return self.__class__(**new) - def update(self, **kwargs) -> Self: """Update an attribute in the model, return new model.""" - coord_info, attr_info = separate_coord_info(kwargs, dims=self.dim_tuple) + _, attr_info = separate_coord_info(kwargs) out = self.model_dump(exclude_unset=True) out.update(attr_info) - out_coord_dict = out["coords"] - for name, coord_dict in coord_info.items(): - if name not in out_coord_dict: - out_coord_dict[name] = coord_dict - else: - out_coord_dict[name].update(coord_dict) - # silly check to clear coords - if not kwargs.get("coords", True): - out["coords"] = {} return self.__class__(**out) def drop_private(self) -> Self: @@ -253,35 +123,8 @@ def drop_private(self) -> Self: out = {i: v for i, v in contents.items() if not i.startswith("_")} return self.__class__(**out) - def flat_dump(self, dim_tuple=False, exclude=None) -> dict: - """ - Flatten the coordinates and dump to dict. - - Parameters - ---------- - dim_tuple - If True, return dimensional tuple instead of range. EG, the - output will have {time: (min, max)} rather than - {time_min: ..., time_max: ...,}. 
This is useful because it can
-            be passed to read, scan, select, etc.
-        exclude
-            keys to exclude.
-        """
-        out = self.model_dump(exclude=exclude)
-        for coord_name, coord in out.pop("coords").items():
-            names = list(coord)
-            if dim_tuple:
-                names = sorted(set(names) - {"min", "max"})
-                out[coord_name] = (coord["min"], coord["max"])
-            for name in names:
-                out[f"{coord_name}_{name}"] = coord[name]
-            # ensure step has right type if nullish
-            step_name = f"{coord_name}_step"
-            step, start = out[step_name], coord["min"]
-            if step is None:
-                is_time = isinstance(start, np.datetime64 | np.timedelta64)
-                if is_time:
-                    out[step_name] = np.timedelta64("NaT")
-                elif isinstance(start, float | np.floating):
-                    out[step_name] = np.nan
-        return out
+    @classmethod
+    def from_dict(cls, obj: Mapping | Self) -> Self:
+        """Get a PatchAttrs instance from a mapping; pass instances through."""
+        if isinstance(obj, cls):
+            return obj
+        return cls(**obj)
diff --git a/dascore/core/coordmanager.py b/dascore/core/coordmanager.py
index 93ace694..eb1613b8 100644
--- a/dascore/core/coordmanager.py
+++ b/dascore/core/coordmanager.py
@@ -319,9 +319,7 @@ def update_from_attrs(
         for unused_dim in set(coord_info) - set(coords.dims):
             for key, val in coord_info[unused_dim].items():
                 attr_info[f"{unused_dim}_{key}"] = val
-        attr_info["coords"] = coords.to_summary_dict()
-        attr_info["dims"] = coords.dims
-        attrs = dc.PatchAttrs.from_dict(attr_info)
+        attrs = dc.PatchAttrs(**attr_info)
         return coords, attrs

     def sort(
@@ -984,9 +982,24 @@ def to_summary_dict(self) -> dict[str, CoordSummary | tuple[str, ...]]:
         dim_map = self.dim_map
         out = {}
         for name, coord in self.coord_map.items():
-            out[name] = coord.to_summary(dims=dim_map[name])
+            out[name] = coord.to_summary(dims=dim_map[name], name=name)
         return out

+    def to_summary(self) -> CoordManagerSummary:
+        """Convert the coordinates in the coord manager to coord summaries."""
+        new_map = {}
+        for name, coord in self.coord_map.items():
+            dims = self.dim_map[name]
+            new_map[name] = coord.to_summary(dims=dims, name=name)
+        return CoordManagerSummary(
+            coord_map=new_map,
+            dim_map=self.dim_map,
+            dims=self.dims,
+        )
+
     def get_coord(self, coord_name: str) -> BaseCoord:
         """
         Retrieve a single coordinate from the coordinate manager.
@@ -1197,3 +1210,28 @@ def _maybe_coord_from_nested(name, coord, new_dims):
     if new_dims:
         dims = tuple(list(dims) + new_dims)
     return c_map, d_map, dims
+
+
+class CoordManagerSummary(CoordManager):
+    """A coordinate manager with summary coordinates."""
+
+    coord_map: Annotated[
+        FrozenDict[str, CoordSummary],
+        frozen_dict_validator,
+        frozen_dict_serializer,
+    ]
+
+    def to_coord_manager(self):
+        """
+        Convert the summary to a coordinate manager.
+
+        This only works if the coordinates were evenly sampled/sorted.
+ """ + out = {} + for name, coord in self.coord_map.items(): + out[name] = coord.to_coord() + return CoordManager( + coord_map=out, + dim_map=self.dim_map, + dims=self.dims, + ) diff --git a/dascore/core/coords.py b/dascore/core/coords.py index af61bac2..01340245 100644 --- a/dascore/core/coords.py +++ b/dascore/core/coords.py @@ -47,6 +47,8 @@ from dascore.utils.models import ( ArrayLike, DascoreBaseModel, + IntTupleStrSerialized, + StrTupleStrSerialized, UnitQuantity, ) from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float @@ -98,6 +100,10 @@ class CoordSummary(DascoreBaseModel): dtype: str min: min_max_type max: min_max_type + shape: IntTupleStrSerialized + ndim: int + dims: StrTupleStrSerialized + name: str step: step_type | None = None units: UnitQuantity | None = None @@ -703,7 +709,7 @@ def get_attrs_dict(self, name): out[f"{name}_units"] = self.units return out - def to_summary(self, dims=()) -> CoordSummary: + def to_summary(self, dims, name) -> CoordSummary: """Get the summary info about the coord.""" return CoordSummary( min=self.min(), @@ -711,6 +717,10 @@ def to_summary(self, dims=()) -> CoordSummary: step=self.step, dtype=self.dtype, units=self.units, + shape=self.shape, + ndim=self.ndim, + dims=dims, + name=name, ) def update(self, **kwargs): @@ -969,7 +979,7 @@ def change_length(self, length: int) -> Self: assert self.ndim == 1, "change_length only works on 1D coords." return get_coord(shape=(length,)) - def to_summary(self, dims=()) -> CoordSummary: + def to_summary(self, dims, name) -> CoordSummary: """Get the summary info about the coord.""" return CoordSummary( min=np.nan, @@ -977,6 +987,10 @@ def to_summary(self, dims=()) -> CoordSummary: step=np.nan, dtype=self.dtype, units=None, + dims=dims, + name=name, + shape=self.shape, + ndim=self.ndim, ) diff --git a/dascore/core/patch.py b/dascore/core/patch.py index 8573868a..c2e49d74 100644 --- a/dascore/core/patch.py +++ b/dascore/core/patch.py @@ -4,8 +4,10 @@ import warnings from collections.abc import Mapping, Sequence +from typing import Annotated import numpy as np +from pydantic import PlainValidator from rich.text import Text from typing_extensions import Self @@ -15,10 +17,14 @@ from dascore import transform from dascore.compat import DataArray, array from dascore.core.attrs import PatchAttrs -from dascore.core.coordmanager import CoordManager, get_coord_manager +from dascore.core.coordmanager import ( + CoordManager, + CoordManagerSummary, + get_coord_manager, +) from dascore.core.coords import BaseCoord from dascore.utils.display import array_to_text, attrs_to_text, get_dascore_text -from dascore.utils.models import ArrayLike +from dascore.utils.models import ArrayLike, ArraySummary, DascoreBaseModel from dascore.utils.patch import check_patch_attrs, check_patch_coords, get_patch_names from dascore.utils.time import to_float from dascore.viz import VizPatchNameSpace @@ -77,11 +83,6 @@ def __init__( data, attrs, coords = data.data, data.attrs, data.coords if dims is None and isinstance(coords, CoordManager): dims = coords.dims - # Try to generate coords from ranges in attrs - if coords is None and attrs is not None: - attrs = dc.PatchAttrs.from_dict(attrs) - coords = attrs.coords_from_dims() - dims = dims if dims is not None else attrs.dim_tuple # Ensure required info is here non_attrs = [x is None for x in [data, coords, dims]] if any(non_attrs) and not all(non_attrs): @@ -97,7 +98,6 @@ def __init__( else: # ensure attrs conforms to coords attrs = 
dc.PatchAttrs.from_dict(attrs).update(coords=coords)
-            assert coords.dims == attrs.dim_tuple, "dim mismatch on coords and attrs"
         self._coords = coords
         self._attrs = attrs
         self._data = array(self.coords.validate_data(data))
@@ -377,3 +377,51 @@ def tran(self) -> Self:
     def io(self) -> dc.io.PatchIO:
         """Return a patch IO object for saving patches to various formats."""
         return dc.io.PatchIO(self)
+
+    def to_summary(
+        self,
+        uri=None,
+        resource_format=None,
+        resource_version=None,
+    ) -> PatchSummary:
+        """
+        Summarize the contents of the Patch.
+        """
+        uri = uri if uri is not None else self.get_patch_name()
+        psum = PatchSummary(
+            uri=uri,
+            coords=self.coords.to_summary(),
+            attrs=self.attrs,
+            data=ArraySummary.from_array(self.data),
+            # Guard against None so the str fields validate.
+            resource_format=resource_format or "",
+            resource_version=resource_version or "",
+        )
+        return psum
+
+
+class PatchSummary(DascoreBaseModel):
+    """
+    A class for summarizing the metadata of the Patch.
+    """
+
+    uri: str
+    resource_format: str = ""
+    resource_version: str = ""
+
+    data: Annotated[ArraySummary, PlainValidator(ArraySummary.from_array)]
+
+    attrs: PatchAttrs
+    coords: CoordManagerSummary
+
+    def to_summary(
+        self,
+        uri=None,
+        resource_format=None,
+        resource_version=None,
+    ) -> Self:
+        """
+        Return the PatchSummary.
+
+        This is here to be compatible with Patch.to_summary.
+        """
+        return self
diff --git a/dascore/examples.py b/dascore/examples.py
index 61690aaf..f941119a 100644
--- a/dascore/examples.py
+++ b/dascore/examples.py
@@ -4,7 +4,6 @@
 import tempfile
 from collections.abc import Sequence
-from contextlib import suppress
 from pathlib import Path

 import numpy as np
@@ -69,42 +68,38 @@ def random_patch(
     """
     # get input data
     rand = np.random.RandomState(13)
-    array = rand.random(shape)
+    array: np.ndarray = rand.random(shape)
     # create attrs
     t1 = np.atleast_1d(np.datetime64(time_min))[0]
     d1 = np.atleast_1d(distance_min)
     attrs = dict(
-        distance_step=distance_step,
-        time_step=to_timedelta64(time_step),
         category="DAS",
-        time_min=t1,
         network=network,
         station=station,
         tag=tag,
-        time_units="s",
-        distance_units="m",
     )
-    # need to pop out dim attrs if coordinates provided.
+    # Use provided coordinate arrays, otherwise build evenly sampled coords.
     if time_array is not None:
-        attrs.pop("time_min")
-        # need to keep time_step if time_array is len 1 to get coord range
-        if len(time_array) > 1:
-            attrs.pop("time_step")
+        time_coord = dc.get_coord(
+            data=time_array,
+            # keep time_step if time_array is len 1 to get a coord range.
+            step=to_timedelta64(time_step) if time_array.size <= 1 else None,
+            units="s",
+        )
     else:
-        time_array = dascore.core.get_coord(
-            data=t1 + np.arange(array.shape[1]) * attrs["time_step"],
-            step=attrs["time_step"],
-            units=attrs["time_units"],
+        time_coord = dascore.core.get_coord(
+            data=t1 + np.arange(array.shape[1]) * to_timedelta64(time_step),
+            step=to_timedelta64(time_step),
+            units="s",
         )
     if dist_array is not None:
-        attrs.pop("distance_step")
+        dist_coord = dc.get_coord(data=dist_array)
     else:
-        dist_array = dascore.core.get_coord(
-            data=d1 + np.arange(array.shape[0]) * attrs["distance_step"],
-            step=attrs["distance_step"],
-            units=attrs["distance_units"],
+        dist_coord = dascore.core.get_coord(
+            data=d1 + np.arange(array.shape[0]) * distance_step,
+            step=distance_step,
+            units="m",
         )
-    coords = dict(distance=dist_array, time=time_array)
+    coords = dict(distance=dist_coord, time=time_coord)
     # assemble and output.
out = dict(data=array, coords=coords, attrs=attrs, dims=("distance", "time"))
     patch = dc.Patch(**out)
@@ -672,13 +667,15 @@ def get_example_patch(example_name="random_das", **kwargs) -> dc.Patch:
     """
     if example_name not in EXAMPLE_PATCHES:
         # Allow the example name to be a data registry entry.
-        with suppress(ValueError):
-            return dc.spool(fetch(example_name))[0]
-        msg = (
-            f"No example patch registered with name {example_name} "
-            f"Registered example patches are {list(EXAMPLE_PATCHES)}"
-        )
-        raise UnknownExampleError(msg)
+        try:
+            path = fetch(example_name)
+        except ValueError:
+            msg = (
+                f"No example patch registered with name {example_name}. "
+                f"Registered example patches are {list(EXAMPLE_PATCHES)}"
+            )
+            raise UnknownExampleError(msg) from None
+        return dc.spool(path)[0]
     return EXAMPLE_PATCHES[example_name](**kwargs)
@@ -722,11 +719,13 @@ def get_example_spool(example_name="random_das", **kwargs) -> dc.BaseSpool:
     """
     if example_name not in EXAMPLE_SPOOLS:
         # Allow the example spool to be a data registry file.
-        with suppress(ValueError):
-            return dc.spool(fetch(example_name))
-        msg = (
-            f"No example spool registered with name {example_name} "
-            f"Registered example spools are {list(EXAMPLE_SPOOLS)}"
-        )
-        raise UnknownExampleError(msg)
+        try:
+            path = fetch(example_name)
+        except ValueError:
+            msg = (
+                f"No example spool registered with name {example_name}. "
+                f"Registered example spools are {list(EXAMPLE_SPOOLS)}"
+            )
+            raise UnknownExampleError(msg) from None
+        return dc.spool(path)
     return EXAMPLE_SPOOLS[example_name](**kwargs)
diff --git a/dascore/io/ap_sensing/core.py b/dascore/io/ap_sensing/core.py
index 6ccfd25a..4fdf58ee 100644
--- a/dascore/io/ap_sensing/core.py
+++ b/dascore/io/ap_sensing/core.py
@@ -11,7 +11,7 @@
 from dascore.io import FiberIO
 from dascore.utils.hdf5 import H5Reader

-from .utils import _get_attrs_dict, _get_patch, _get_version_string
+from .utils import _get_patch, _get_version_string


 class APSensingPatchAttrs(dc.PatchAttrs):
@@ -41,17 +41,15 @@ def get_format(self, resource: H5Reader, **kwargs) -> tuple[str, str] | bool:
         if version_str:
             return self.name, version_str

-    def scan(self, resource: H5Reader, **kwargs) -> list[dc.PatchAttrs]:
+    def scan(self, resource: H5Reader, **kwargs) -> list[dc.PatchSummary]:
         """Scan an AP sensing file, return summary info about the contents."""
-        file_version = _get_version_string(resource)
-        extras = {
-            "path": resource.filename,
-            "file_format": self.name,
-            "file_version": str(file_version),
+        info = {
+            "uri": resource.filename,
+            "resource_format": self.name,
+            "resource_version": _get_version_string(resource),
         }
-        attrs = _get_attrs_dict(resource)
-        attrs.update(extras)
-        return [APSensingPatchAttrs(**attrs)]
+        patch = _get_patch(resource, load_data=False)
+        return [patch.to_summary(**info)]

     def read(
         self,
diff --git a/dascore/io/ap_sensing/utils.py b/dascore/io/ap_sensing/utils.py
index a30857e3..90740808 100644
--- a/dascore/io/ap_sensing/utils.py
+++ b/dascore/io/ap_sensing/utils.py
@@ -78,7 +78,6 @@ def _get_attrs_dict(resource):
     daq = resource["DAQ"]
     pserver = resource["ProcessingServer"]
     out = dict(
-        coords=_get_coords(resource),
         data_category="DAS",
         instrumet_id=unbyte(_maybe_unpack(daq["SerialNumber"])),
         gauge_length=_maybe_unpack(pserver["GaugeLength"]),
@@ -87,13 +86,21 @@ def _get_attrs_dict(resource):
     return out


-def _get_patch(resource, time=None, distance=None, attr_cls=dc.PatchAttrs):
+def _get_patch(
+    resource,
+    time=None,
+    distance=None,
+    attr_cls=dc.PatchAttrs,
+    load_data=True,
+    **kwargs,
+):
     """Get a patch
from ap_sensing file."""
     attrs = _get_attrs_dict(resource)
-    coords = attrs["coords"]
+    coords = _get_coords(resource)
     data = resource["DAS"]
     if time is not None or distance is not None:
         coords, data = coords.select(array=data, time=time, distance=distance)
-    attrs["coords"] = coords
     attrs = attr_cls.model_validate(attrs)
-    return dc.Patch(data=data[:], coords=coords, attrs=attrs)
+    data = data[:] if load_data else data
+    return dc.Patch(data=data, coords=coords, attrs=attrs, **kwargs)
diff --git a/dascore/io/core.py b/dascore/io/core.py
index bb13944f..ebcf01a4 100644
--- a/dascore/io/core.py
+++ b/dascore/io/core.py
@@ -477,8 +477,8 @@ def read(self, resource, **kwargs) -> SpoolType:
         msg = f"FiberIO: {self.name} has no read method"
         raise NotImplementedError(msg)

-    def scan(self, resource, **kwargs) -> list[dc.PatchAttrs]:
-        """Returns a list of summary info for patches contained in file."""
+    def scan(self, resource, **kwargs) -> list[dc.PatchSummary]:
+        """Returns a list of summary info for patches contained in source."""
         # default scan method reads in the file and returns required attributes
         # however, this can be very slow, so each parser should implement scan
         # when possible.
@@ -489,9 +489,10 @@ def scan(self, resource, **kwargs) -> list[dc.PatchSummary]:
             raise NotImplementedError(msg)
         out = []
         for pa in spool:
-            new = pa.attrs.update(
-                file_format=self.name,
-                path=str(resource),
+            new = pa.to_summary(
+                resource_format=self.name,
+                resource_version=self.version,
+                uri=str(resource),
             )
             out.append(new)
         return out
diff --git a/dascore/io/dasdae/core.py b/dascore/io/dasdae/core.py
index 2f084a85..ea57910e 100644
--- a/dascore/io/dasdae/core.py
+++ b/dascore/io/dasdae/core.py
@@ -11,10 +11,9 @@
 from dascore.io import FiberIO
 from dascore.utils.hdf5 import (
     H5Reader,
+    H5Writer,
     HDFPatchIndexManager,
     NodeError,
-    PyTablesReader,
-    PyTablesWriter,
 )
 from dascore.utils.misc import unbyte
 from dascore.utils.patch import get_patch_names
@@ -52,7 +51,7 @@ class DASDAEV1(FiberIO):
     preferred_extensions = ("h5", "hdf5")
     version = "1"

-    def write(self, spool: SpoolType, resource: PyTablesWriter, index=False, **kwargs):
+    def write(self, spool: SpoolType, resource: H5Writer, index=False, **kwargs):
         """
         Write a collection of patches to a DASDAE file.

@@ -109,8 +109,8 @@ def get_format(self, resource: H5Reader, **kwargs) -> tuple[str, str] | bool:
         version = unbyte(attrs.get("__DASDAE_version__", ""))
         return file_format, version

-    def read(self, resource: PyTablesReader, **kwargs) -> SpoolType:
-        """Read a dascore file."""
+    def read(self, resource: H5Reader, **kwargs) -> SpoolType:
+        """Read a DASDAE file."""
         patches = []
         try:
-            waveform_group = resource.root["/waveforms"]
+            waveform_group = resource["/waveforms"]
@@ -120,7 +120,7 @@ def read(self, resource: H5Reader, **kwargs) -> SpoolType:
             patches.append(_read_patch(patch_group, **kwargs))
         return dc.spool(patches)

-    def scan(self, resource: PyTablesReader, **kwargs):
+    def scan(self, resource: H5Reader, **kwargs):
         """
         Get the patch info from the file.

@@ -133,20 +133,6 @@
         resource
             A path to the file.
""" - indexer = HDFPatchIndexManager(resource.filename) - if indexer.has_index: - # We need to change the path back to the file rather than internal - # HDF5 path so it works with FileSpool and such. - records = indexer.get_index().assign(path=str(resource)).to_dict("records") - return [dc.PatchAttrs(**x) for x in records] - else: - file_format = self.name - version = resource.root._v_attrs.__DASDAE_version__ - return _get_contents_from_patch_groups(resource, version, file_format) - - def index(self, path): - """Index the dasdae file.""" - indexer = HDFPatchIndexManager(path) - if not indexer.has_index: - df = dc.scan_to_df(path) - indexer.write_update(df) + file_format = self.name + version = resource.attrs["__DASDAE_version__"] + return _get_contents_from_patch_groups(resource, version, file_format) diff --git a/dascore/io/dasdae/utils.py b/dascore/io/dasdae/utils.py index f4dcbd9a..98598a65 100644 --- a/dascore/io/dasdae/utils.py +++ b/dascore/io/dasdae/utils.py @@ -9,12 +9,31 @@ from dascore.core.attrs import PatchAttrs from dascore.core.coordmanager import get_coord_manager from dascore.core.coords import get_coord -from dascore.utils.misc import suppress_warnings +from dascore.utils.hdf5 import Empty +from dascore.utils.misc import suppress_warnings, unbyte from dascore.utils.time import to_int # --- Functions for writing DASDAE format +def _santize_pytables(some_dict): + """Remove pytables names from a dict, remove any pickle-able things.""" + pytables_names = {"CLASS", "FLAVOR", "TITLE", "VERSION"} + out = {} + for i, v in some_dict.items(): + if i in pytables_names: + continue + try: + val = unbyte(v) + except ValueError: + continue + # Get rid of empty enum. + if isinstance(val, Empty): + val = "" + out[i] = val + return out + + def _create_or_get_group(h5, group, name): """Create a new group or get existing.""" try: @@ -26,6 +45,7 @@ def _create_or_get_group(h5, group, name): def _create_or_squash_array(h5, group, name, data): """Create a new array, if it exists delete and re-create.""" + breakpoint() try: array = h5.create_array(group, name, data) except NodeError: @@ -37,7 +57,7 @@ def _create_or_squash_array(h5, group, name, data): def _write_meta(hfile, file_version): """Write metadata to hdf5 file.""" - attrs = hfile.root._v_attrs + attrs = hfile.attrs attrs["__format__"] = "DASDAE" attrs["__DASDAE_version__"] = file_version attrs["__dascore__version__"] = dc.__version__ @@ -49,8 +69,8 @@ def _save_attrs_and_dims(patch, patch_group): # TODO will need to test if objects are serializable attr_dict = patch.attrs.model_dump(exclude_unset=True) for i, v in attr_dict.items(): - patch_group._v_attrs[f"_attrs_{i}"] = v - patch_group._v_attrs["_dims"] = ",".join(patch.dims) + patch_group.attrs[f"_attrs_{i}"] = v + patch_group.attrs["_dims"] = ",".join(patch.dims) def _save_array(data, name, group, h5): @@ -61,22 +81,24 @@ def _save_array(data, name, group, h5): if is_dt or is_td: data = to_int(data) array_node = _create_or_squash_array(h5, group, name, data) - array_node._v_attrs["is_datetime64"] = is_dt - array_node._v_attrs["is_timedelta64"] = is_td + array_node.attrs["is_datetime64"] = is_dt + array_node.attrs["is_timedelta64"] = is_td + return array_node def _save_coords(patch, patch_group, h5): """Save coordinates.""" cm = patch.coords for name, coord in cm.coord_map.items(): - dims = cm.dim_map[name] + summary = coord.to_summary(name=name, dims=cm.dims[name]).model_dump( + exclude_defaults=True + ) + breakpoint() # First save coordinate arrays data = coord.values save_name = 
f"_coord_{name}" - _save_array(data, save_name, patch_group, h5) - # then save dimensions of coordinates - save_name = f"_cdims_{name}" - patch_group._v_attrs[save_name] = ",".join(dims) + dataset = _save_array(data, save_name, patch_group, h5) + dataset.attrs.update(summary) def _save_patch(patch, wave_group, h5, name): @@ -95,7 +117,7 @@ def _save_patch(patch, wave_group, h5, name): def _get_attrs(patch_group): """Get the saved attributes form the group attrs.""" out = {} - attrs = [x for x in patch_group._v_attrs._f_list() if x.startswith("_attrs_")] + attrs = [x for x in patch_group.attrs if x.startswith("_attrs_")] for attr_name in attrs: key = attr_name.replace("_attrs_", "") val = patch_group._v_attrs[attr_name] @@ -110,9 +132,9 @@ def _get_attrs(patch_group): def _read_array(table_array): """Read an array into numpy.""" data = table_array[:] - if table_array._v_attrs["is_datetime64"]: + if table_array.attrs["is_datetime64"]: data = data.view("datetime64[ns]") - if table_array._v_attrs["is_timedelta64"]: + if table_array.attrs["is_timedelta64"]: data = data.view("timedelta64[ns]") return data @@ -131,10 +153,10 @@ def _get_coords(patch_group, dims, attrs2): ) coord_dict[name] = coord # associates coordinates with dimensions - c_dims = [x for x in patch_group._v_attrs._f_list() if x.startswith("_cdims")] + c_dims = [x for x in patch_group.attrs if x.startswith("_cdims")] for coord_name in c_dims: name = coord_name.replace("_cdims_", "") - value = patch_group._v_attrs[coord_name] + value = patch_group.attrs[coord_name] assert name in coord_dict, "Should already have loaded coordinate array" coord_dim_dict[name] = (tuple(value.split(".")), coord_dict[name]) # add dimensions to coordinates that have them. @@ -144,7 +166,7 @@ def _get_coords(patch_group, dims, attrs2): def _get_dims(patch_group): """Get the dims tuple from the patch group.""" - dims = patch_group._v_attrs["_dims"] + dims = patch_group.attrs["_dims"] if not dims: out = () else: @@ -152,7 +174,7 @@ def _get_dims(patch_group): return out -def _read_patch(patch_group, **kwargs): +def _read_patch(patch_group, load_data=True, **kwargs): """Read a patch group, return Patch.""" attrs = _get_attrs(patch_group) dims = _get_dims(patch_group) @@ -163,14 +185,16 @@ def _read_patch(patch_group, **kwargs): if kwargs: coords, data = coords.select(array=patch_group["data"], **kwargs) else: - data = patch_group["data"][:] + data = patch_group["data"] + if load_data: + data = data[:] return dc.Patch(data=data, coords=coords, dims=dims, attrs=attrs) def _get_contents_from_patch_groups(h5, file_version, file_format="DASDAE"): """Get the contents from each patch group.""" out = [] - for group in h5.iter_nodes("/waveforms"): + for name, group in h5[("/waveforms")].items(): contents = _get_patch_content_from_group(group) # populate file info contents["file_version"] = file_version @@ -179,21 +203,61 @@ def _get_contents_from_patch_groups(h5, file_version, file_format="DASDAE"): # suppressing warnings because old dasdae files will issue warning # due to d_dim rather than dim_step. 
TODO fix test files in the future
         with suppress_warnings(DeprecationWarning):
             out.append(dc.PatchAttrs(**contents))
+    return out
+
+
+def _get_coord_info(info, group):
+    """Get the coord summary dict for each coordinate dataset."""
+    coords = {}
+    coord_ds_names = tuple(x for x in group if x.startswith("_coord_"))
+    for ds_name in coord_ds_names:
+        name = ds_name.replace("_coord_", "")
+        ds = group[ds_name]
+        attrs = _sanitize_pytables(dict(ds.attrs))
+        # Need to get old dimensions from c_dims in attrs.
+        if "dims" not in attrs:
+            attrs["dims"] = info.get(f"_cdims_{name}", name)
+        # Start from the stored attrs; older files don't store the summary
+        # info, so fall back to reading the coord array.
+        c_info = dict(attrs)
+        if "min" not in attrs:
+            c_summary = (
+                dc.core.get_coord(data=ds[:])
+                .to_summary(dims=attrs["dims"], name=name)
+                .model_dump(exclude_unset=True, exclude_defaults=True)
+            )
+            c_info.update(c_summary)
+        c_info.update(
+            {
+                "dtype": ds.dtype.str,
+                "shape": ds.shape,
+                "name": name,
+            }
+        )
+        coords[name] = c_info
+    return coords
+
+
 def _get_patch_content_from_group(group):
     """Get patch content from a single node."""
-    attrs = group._v_attrs
     out = {}
-    for key in attrs._f_list():
-        value = getattr(attrs, key)
+    attrs = _sanitize_pytables(dict(group.attrs))
+    for key, value in attrs.items():
         new_key = key.replace("_attrs_", "")
         # need to unpack 0 dim arrays.
         if isinstance(value, np.ndarray) and not value.shape:
             value = np.atleast_1d(value)[0]
         out[new_key] = value
+    # Add coord info.
+    out["coords"] = _get_coord_info(out, group)
+    # Add data info.
+    out["shape"] = group["data"].shape
+    out["dtype"] = group["data"].dtype.str
     # rename dims
     out["dims"] = out.pop("_dims")
     return out
diff --git a/dascore/io/h5simple/core.py b/dascore/io/h5simple/core.py
index e772f72b..5b0c0098 100644
--- a/dascore/io/h5simple/core.py
+++ b/dascore/io/h5simple/core.py
@@ -5,7 +5,7 @@
 import dascore as dc
 from dascore.constants import SpoolType
 from dascore.io import FiberIO
-from dascore.utils.hdf5 import H5Reader, PyTablesReader
+from dascore.utils.hdf5 import H5Reader

 from .utils import _get_attrs_coords_and_data, _is_h5simple, _maybe_trim_data

@@ -23,7 +23,7 @@ def get_format(self, resource: H5Reader, **kwargs) -> tuple[str, str] | bool:
             return self.name, self.version
         return False

-    def read(self, resource: PyTablesReader, snap=True, **kwargs) -> SpoolType:
+    def read(self, resource: H5Reader, snap=True, **kwargs) -> SpoolType:
         """
         Read a simple h5 file.

@@ -41,9 +41,7 @@ def read(self, resource: H5Reader, snap=True, **kwargs) -> SpoolType:
         patch = dc.Patch(coords=new_cm, data=new_data[:], attrs=attrs)
         return dc.spool([patch])

-    def scan(
-        self, resource: PyTablesReader, snap=True, **kwargs
-    ) -> list[dc.PatchAttrs]:
+    def scan(self, resource: H5Reader, snap=True, **kwargs) -> list[dc.PatchAttrs]:
         """Get the attributes of a h5simple file."""
         attrs, cm, data = _get_attrs_coords_and_data(resource, snap, self)
         attrs["coords"] = cm.to_summary_dict()
diff --git a/dascore/io/h5simple/utils.py b/dascore/io/h5simple/utils.py
index 177c5d9b..cb6cf215 100644
--- a/dascore/io/h5simple/utils.py
+++ b/dascore/io/h5simple/utils.py
@@ -1,4 +1,4 @@
-"""Utilities for terra15."""
+"""Utilities for simple h5 files."""

 from __future__ import annotations
diff --git a/dascore/proc/basic.py b/dascore/proc/basic.py
index c68ceba8..0a4540df 100644
--- a/dascore/proc/basic.py
+++ b/dascore/proc/basic.py
@@ -100,7 +100,6 @@ def _fast_attr_update(self, attrs):
     # since we update history so often, we make a fast track for it.
new_attrs = self.attrs.model_dump(exclude_unset=True)
     new_attrs.update(attrs)
-    # pop out coords so new coords has priority.
     if len(attrs) == 1 and "history" in attrs:
         return _fast_attr_update(self, PatchAttrs(**new_attrs))
     new_coords, new_attrs = self.coords.update_from_attrs(new_attrs)
@@ -187,7 +186,8 @@ def update(
     if attrs:
         coords, attrs = coords.update_from_attrs(attrs)
     else:
-        _attrs = dc.PatchAttrs.from_dict(attrs or self.attrs)
+        attrs = attrs if attrs else self.attrs.model_dump(exclude_unset=True)
+        _attrs = dc.PatchAttrs(**attrs)
         attrs = _attrs.update(coords=coords, dims=coords.dims)
     return self.__class__(data=data, coords=coords, attrs=attrs, dims=coords.dims)
diff --git a/dascore/proc/coords.py b/dascore/proc/coords.py
index e1fbc691..2356de6e 100644
--- a/dascore/proc/coords.py
+++ b/dascore/proc/coords.py
@@ -210,8 +210,7 @@ def rename_coords(self: PatchType, **kwargs) -> PatchType:
     >>> assert 'fragrance' in pa2.dims
     """
     new_coord = self.coords.rename_coord(**kwargs)
-    attrs = self.attrs.rename_dimension(**kwargs)
-    return self.new(coords=new_coord, dims=new_coord.dims, attrs=attrs)
+    return self.new(coords=new_coord, dims=new_coord.dims)


 @patch_function()
diff --git a/dascore/utils/attrs.py b/dascore/utils/attrs.py
index 34741d80..fea1a972 100644
--- a/dascore/utils/attrs.py
+++ b/dascore/utils/attrs.py
@@ -360,7 +360,6 @@ def _pop_keys(obj, out):
     # Check if dims need to be updated.
     new_dims = _get_dims(obj)
     if new_dims and new_dims != dims:
-        obj["dims"] = new_dims
         dims = new_dims
     # this is already a dict of coord info.
     if dims and set(dims).issubset(set(obj)):
@@ -368,6 +367,4 @@ def _pop_keys(obj, out):
     _get_coords_from_coord_level(obj, coord_dict)
     _get_coords_from_top_level(obj, coord_dict, dims)
     _pop_keys(obj, coord_dict)
-    if "dims" not in obj and dims is not None:
-        obj["dims"] = dims
     return coord_dict, obj
diff --git a/dascore/utils/display.py b/dascore/utils/display.py
index fbe93580..fb45065e 100644
--- a/dascore/utils/display.py
+++ b/dascore/utils/display.py
@@ -132,7 +132,7 @@ def array_to_text(data, units=None) -> Text:

 def attrs_to_text(attrs) -> Text:
     """Convert pydantic model to text."""
-    attrs = dc.PatchAttrs.from_dict(attrs).model_dump(exclude_defaults=True)
+    attrs = dc.PatchAttrs(**dict(attrs)).model_dump(exclude_defaults=True)
     # pop coords and dims since they show up in other places.
attrs.pop("coords", None), attrs.pop("dims", None) txt = Text("➤ ") + Text("Attributes", style=dascore_styles["dc_yellow"]) diff --git a/dascore/utils/hdf5.py b/dascore/utils/hdf5.py index 1f92dc63..ed6d8d75 100644 --- a/dascore/utils/hdf5.py +++ b/dascore/utils/hdf5.py @@ -18,6 +18,7 @@ import numpy as np import pandas as pd import tables +from h5py import Empty # noqa (we purposely re-import this other places) from h5py import File as H5pyFile from packaging.version import parse as get_version from pandas.io.common import stringify_path diff --git a/dascore/utils/models.py b/dascore/utils/models.py index 5192a598..2e6cbf14 100644 --- a/dascore/utils/models.py +++ b/dascore/utils/models.py @@ -23,6 +23,23 @@ frozen_dict_validator = PlainValidator(lambda x: FrozenDict(x)) frozen_dict_serializer = PlainSerializer(lambda x: dict(x)) + +def _str_to_int_tuple(value): + """Convert a string of ints to a tuple.""" + if isinstance(value, str): + return tuple(int(x) for x in value.split(",")) + return value + + +def _str_to_str_tuple(value: str) -> tuple[str, ...]: + """Ensure a tuple of strings is returned.""" + if not value: + return () + elif isinstance(value, str): + return tuple(value.split(",")) + return tuple(value) + + # A datetime64 DateTime64 = Annotated[ np.datetime64, @@ -56,6 +73,19 @@ str, PlainValidator(lambda x: x if isinstance(x, str) else ",".join(x)) ] +StrTupleStrSerialized = Annotated[ + tuple[str, ...], + PlainValidator(_str_to_str_tuple), + PlainSerializer(lambda x: ",".join(x)), +] + +# A tuple of ints that should serialize to CSVs. +IntTupleStrSerialized = Annotated[ + tuple[int, ...], + PlainValidator(_str_to_int_tuple), + PlainSerializer(lambda x: ",".join(str(y) for y in x)), +] + FrozenDictType = Annotated[ FrozenDict, frozen_dict_validator, @@ -119,3 +149,18 @@ def get_summary_df(cls): return out __eq__ = sensible_model_equals + + +class ArraySummary(DascoreBaseModel): + """ + A class for summarizing arrays. + """ + + dtype: str + shape: tuple[int, ...] 
+ ndim: int + + @classmethod + def from_array(cls, array): + """Init the summary from an array.""" + return cls(dtype=array.dtype, shape=array.shape, ndim=array.ndim) diff --git a/tests/test_core/test_attrs.py b/tests/test_core/test_attrs.py index e3b23040..2d5c19c7 100644 --- a/tests/test_core/test_attrs.py +++ b/tests/test_core/test_attrs.py @@ -2,7 +2,6 @@ from __future__ import annotations -import numpy as np import pandas as pd import pytest from pydantic import ValidationError @@ -11,15 +10,14 @@ from dascore.core.attrs import ( PatchAttrs, ) -from dascore.core.coords import CoordSummary, get_coord from dascore.utils.misc import register_func MORE_COORDS_ATTRS = [] @pytest.fixture(scope="class") -def random_summary(random_patch) -> PatchAttrs: - """Return the summary of the random patch.""" +def random_attrs_loaded(random_patch) -> PatchAttrs: + """Return the random patch attrs dumped then loaded.""" return PatchAttrs.model_validate(random_patch.attrs.model_dump()) @@ -62,13 +60,8 @@ class TestPatchAttrs: def test_get(self, random_attrs): """Ensure get returns existing values.""" - out = random_attrs.get("time_min") - assert out == random_attrs.time_min - - def test_get_existing_key(self, random_attrs): - """Ensure get returns existing values.""" - out = random_attrs.get("time_min") - assert out == random_attrs.time_min + out = random_attrs.get("category") + assert out == random_attrs.category def test_get_no_key(self, random_attrs): """Ensure missing keys return default value.""" @@ -82,169 +75,40 @@ def test_immutable(self, random_attrs): with pytest.raises(ValidationError, match="Instance is frozen"): random_attrs["bob"] = 1 - def test_coords_with_coord_keys(self): - """Ensure coords with base keys work.""" - coords = {"distance": get_coord(data=np.arange(100))} - out = PatchAttrs(**{"coords": coords}) - assert out.coords - assert "distance" in out.coords - for _name, val in out.coords.items(): - assert isinstance(val, CoordSummary) - - def test_coords_with_coord_manager(self, random_patch): - """Ensure coords with a coord manager works.""" - cm = random_patch.coords - out = PatchAttrs(**{"coords": cm}) - assert out.coords - assert set(cm.coord_map) == set(out.coords) - assert out.dims == ",".join(cm.dims) - - def test_coords_are_coord_summary(self, more_coords_attrs): - """All the coordinates should be Coordinate Summarys not dict.""" - for _, coord_sum in more_coords_attrs.coords.items(): - assert isinstance(coord_sum, CoordSummary) - - def test_access_min_max_step_etc(self, more_coords_attrs): - """Ensure min, max, step etc can be accessed for new coordinate.""" - expected_attrs = ["_min", "_max", "_step", "_units", "_dtype"] - for eat in expected_attrs: - attr_name = "depth" + eat - assert hasattr(more_coords_attrs, attr_name) - val = getattr(more_coords_attrs, attr_name) - if eat == "min": - assert val == 10.0 - if eat == "max": - assert val == 12.0 - - def test_deprecated_d_set(self): - """Ensure setting attributes d_whatever is deprecated.""" - with pytest.warns(DeprecationWarning): - PatchAttrs(time_min=1, time_max=10, d_time=10) - - def test_deprecated_d_get(self, random_attrs): - """Access attr.d_{whatever} is deprecated.""" - with pytest.warns(DeprecationWarning): - _ = random_attrs.d_time - - def test_access_coords(self, more_coords_attrs): - """Ensure coordinates can be accessed as well.""" - assert "depth" in more_coords_attrs.coords - assert more_coords_attrs.coords["depth"].min == more_coords_attrs.depth_min - - def test_extra_attrs_not_in_dump(self, 
more_coords_attrs): - """When using the extra attrs, they shouldn't show up in the dump.""" - dump = more_coords_attrs.model_dump() - not_expected = {"depth_min", "depth_max", "depth_step"} - assert not_expected.isdisjoint(set(dump)) - - def test_extra_attrs_not_in_dump_random_attrs(self, random_attrs): - """When using the extra attrs, they shouldn't show up in the dump.""" - dump = random_attrs.model_dump() - not_expected = {"time_min", "time_max", "time_step"} - assert not_expected.isdisjoint(set(dump)) - def test_supports_extra_attrs(self): """The attr dict should allow extra attributes.""" out = PatchAttrs(bob="doesnt", bill_min=12, bob_max="2012-01-12") assert out.bob == "doesnt" assert out.bill_min == 12 - def test_flat_dump(self, more_coords_attrs): - """Ensure flat dump flattens out the coords.""" - out = more_coords_attrs.flat_dump() - expected = { - "depth_min", - "depth_max", - "depth_step", - "depth_units", - "depth_dtype", - } - assert set(out).issuperset(expected) - - def test_flat_dump_coords(self, more_coords_attrs): - """Ensure flat dim with dim_tuple works.""" - attrs = more_coords_attrs - out = attrs.flat_dump(dim_tuple=True) - assert "depth" in out - depth = attrs.coords["depth"] - dep_min, dep_max = depth.min, depth.max - assert out["depth"] == (dep_min, dep_max) - - def test_coords_to_coord_summary(self): - """Coordinates included in coords should be converted to coord summary.""" - out = { - "station": "01", - "coords": { - "time": get_coord(start=0, stop=10, step=1), - "distance": get_coord(start=10, stop=100, step=1, units="m"), - }, - } - attr = dc.PatchAttrs(**out) - assert attr.dims == ",".join(("time", "distance")) - for name, coord in attr.coords.items(): - assert isinstance(coord, CoordSummary) - def test_items(self, random_patch): """Ensure items works like a dict.""" attrs = random_patch.attrs out = dict(attrs.items()) assert out == attrs.model_dump() - def test_dims_match_attrs(self, random_patch): - """Ensure the dims from patch attrs matches patch dims.""" - pat = random_patch.rename_coords(distance="channel") - assert pat.dims == pat.attrs.dim_tuple - class TestSummaryAttrs: """Tests for summarizing a schema.""" - def test_attrs_reconstructed(self, random_patch, random_summary): + def test_attrs_reconstructed(self, random_patch, random_attrs_loaded): """Ensure all the expected attrs are extracted.""" - summary1 = dict(random_summary) + summary1 = dict(random_attrs_loaded) attrs = dict(random_patch.attrs) common_keys = set(summary1) & set(attrs) for key in common_keys: assert summary1[key] == attrs[key] - def test_can_jsonize(self, random_summary): + def test_can_jsonize(self, random_attrs_loaded): """Ensure the summary can be converted to json.""" - json = random_summary.model_dump_json() + json = random_attrs_loaded.model_dump_json() assert isinstance(json, str) - def test_can_roundrip(self, random_summary): + def test_can_roundrip(self, random_attrs_loaded): """Ensure json can be round-tripped.""" - json = random_summary.model_dump_json() + json = random_attrs_loaded.model_dump_json() random_summary2 = PatchAttrs.model_validate_json(json) - assert random_summary2 == random_summary - - def test_from_dict(self, random_attrs): - """Test new method for more intuitive init.""" - out = PatchAttrs.from_dict(random_attrs) - assert out == random_attrs - new_dict = dict(random_attrs) - new_dict["data_units"] = "m/s" - out = PatchAttrs.from_dict(new_dict) - assert isinstance(out, PatchAttrs) - - -class TestRenameDimension: - """Ensure rename dimension works.""" 
-
-    def test_simple_rename(self, random_attrs):
-        """Ensure renaming a dimension works."""
-        attrs = random_attrs
-        new_name = "money"
-        time_ind = attrs.dim_tuple.index("time")
-        out = attrs.rename_dimension(time=new_name)
-        assert new_name in out.dims
-        assert out.dim_tuple[time_ind] == new_name
-        assert len(out.dim_tuple) == len(attrs.dim_tuple)
-
-    def test_empty_rename(self, random_attrs):
-        """Passing no kwargs should return same attrs."""
-        attrs = random_attrs.rename_dimension()
-        assert attrs == random_attrs
+        assert random_summary2 == random_attrs_loaded


 class TestDropPrivate:
@@ -258,40 +122,15 @@ def test_simple_drop(self):
         assert "extra_attr" in attr_dict


-class TestMisc:
-    """Misc small tests."""
-
-    def test_schema_deprecated(self):
-        """Ensure schema module emits deprecation warning."""
-        with pytest.warns(DeprecationWarning):
-            from dascore.core.schema import PatchAttrs  # noqa
-
-    def test_get_attrs_non_dim_coordinates(self, random_patch_with_lat_lon):
-        """
-        Ensure only dims show up in dims even when coord manager has many
-        coordinates.
-        """
-        patch = random_patch_with_lat_lon
-        cm = patch.coords
-        attrs = dc.PatchAttrs(coords=cm)
-        assert attrs.dim_tuple == cm.dims
-
-
 class TestUpdateAttrs:
     """Tests for updating attributes."""

     def test_attrs_can_update(self, random_attrs):
-        """Ensure attributes can update coordinates."""
-        attrs = random_attrs.update(distance_units="miles")
-        expected = dc.get_quantity("miles")
-        assert dc.get_quantity(attrs.coords["distance"].units) == expected
-
-    def test_update_from_coords(self, random_patch):
-        """Ensure attrs.update updates dims from coords."""
-        attrs = random_patch.attrs
-        new_patch = random_patch.rename_coords(distance="channel")
-        new = attrs.update(coords=new_patch.coords)
-        assert new.dim_tuple == new_patch.dims
+        """Ensure attributes can be updated."""
+        # Also check that unit strings are converted to quantities.
+        attrs = random_attrs.update(data_units="m/s")
+        expected = dc.get_quantity("m/s")
+        assert attrs.data_units == expected


class TestGetAttrSummary:
diff --git a/tests/test_core/test_coords.py b/tests/test_core/test_coords.py
index 0f5f3f11..4bb41aa6 100644
--- a/tests/test_core/test_coords.py
+++ b/tests/test_core/test_coords.py
@@ -385,7 +385,7 @@ class TestCoordSummary:
     @pytest.fixture(scope="session")
     def summary(self, coord) -> CoordSummary:
         """Convert each coord to a summary."""
-        return coord.to_summary()
+        return coord.to_summary(dims="distance", name="distance")

     def test_dtype_consistent(self, summary, coord):
         """Ensure the datatype is preserved."""
@@ -397,32 +397,33 @@ def test_dtype_consistent(self, summary, coord):

     def test_to_summary(self, coord):
         """Ensure all coords can be converted to a summary."""
-        out = coord.to_summary()
+        out = coord.to_summary(name="time", dims="time")
         assert isinstance(out, CoordSummary)

     def test_coord_range_round_trip(self, coord):
         """Coord ranges should round-trip to summaries and back."""
         if not coord.evenly_sampled:
             return
-        summary = coord.to_summary()
+        summary = coord.to_summary(name="time", dims="time")
         back = summary.to_coord()
         assert back == coord
-        if not back.to_summary() == summary:
-            coord.to_summary()
-
-        assert back.to_summary() == summary
+        assert back.to_summary(name="time", dims="time") == summary

     def test_to_summary_raises(self, random_coord):
         """Ensure to_summary raises if not evenly sampled."""
         match = "Cannot convert summary which is not evenly sampled"
         with pytest.raises(CoordError, match=match):
-            random_coord.to_summary().to_coord()
+            random_coord.to_summary(name="time", dims="time").to_coord()

     @pytest.mark.parametrize("data", cast_data_list)
     def test_dtype_inferred(self, data):
         """Ensure the dtypes are correctly determined if not specified."""
-        out = CoordSummary(**data)
+        required = {"name": "time", "dims": "time", "shape": (21,), "ndim": 1}
+        out = CoordSummary(**(data | required))
         assert np.dtype(out.dtype) == np.dtype(type(data["min"]))


@@ -1440,7 +1441,7 @@ def test_order_by_samples(self, basic_non_coord):

     def test_to_summary(self, basic_non_coord):
         """Ensure we can convert non coord to summary."""
-        summary = basic_non_coord.to_summary()
+        summary = basic_non_coord.to_summary(name="time", dims="time")
         assert isinstance(summary, CoordSummary)

     def test_equals_to_other_coord(self, basic_non_coord):
diff --git a/tests/test_utils/test_hdf_utils.py b/tests/test_utils/test_hdf_utils.py
index 4540249c..434eb369 100644
--- a/tests/test_utils/test_hdf_utils.py
+++ b/tests/test_utils/test_hdf_utils.py
@@ -15,7 +15,6 @@
 from dascore.utils.downloader import fetch
 from dascore.utils.hdf5 import (
     HDFPatchIndexManager,
-    PyTablesWriter,
     extract_h5_attrs,
     h5_matches_structure,
     open_hdf5_file,
@@ -164,18 +163,6 @@ def test_metadata_created(self, tmp_path_factory):
         assert meta is not None


-class TestHDFReaders:
-    """Tests for HDF5 readers."""
-
-    def test_get_handle(self, tmp_path_factory):
-        """Ensure we can get a handle with the class."""
-        path = tmp_path_factory.mktemp("hdf_handle_test") / "test_file.h5"
-        handle = PyTablesWriter.get_handle(path)
-        assert isinstance(handle, tables.File)
-        handle_2 = PyTablesWriter.get_handle(handle)
-        assert isinstance(handle_2, tables.File)
-
-
class TestH5MatchesStructure:
    """Tests for the h5 matches structure function."""
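Usage sketch (illustrative only, not part of the patch): assuming the in-progress summary API above lands as written (`Patch.to_summary`, `PatchSummary`, and their export from `dascore.__init__` all come from this diff), exercising it would look roughly like:

    import dascore as dc

    patch = dc.get_example_patch()
    # Summarize the patch as if it had been scanned from a file.
    summary = patch.to_summary(uri="example.h5", resource_format="DASDAE")
    # PatchSummary is a pydantic model, so it should round-trip through JSON.
    loaded = dc.PatchSummary.model_validate_json(summary.model_dump_json())
    assert loaded == summary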