diff --git a/dascore/core/patch.py b/dascore/core/patch.py
index b5215ede..ad18f337 100644
--- a/dascore/core/patch.py
+++ b/dascore/core/patch.py
@@ -408,7 +408,7 @@ class PatchSummary(DascoreBaseModel):
         "path",
         "format_version",
         "format_name",
-        "acquistion_id",
+        "acquisition_id",
         "tag",
     )
@@ -425,7 +425,7 @@ def to_summary(
         """
         return self
 
-    def _attrs_to_patch_info(self, attr_info, patch_info, patch_key):
+    def _attrs_to_patch_info(self, attr_info, patch_info, patch_key, spool_key):
         """Transfer some attrs to the patch info."""
         out = []
         for key in self._attrs_to_patch_keys:
@@ -433,20 +433,24 @@ def _attrs_to_patch_info(self, attr_info, patch_info, patch_key):
             patch_info[key] = value
         # flatten remaining attrs
         for item, value in attr_info.items():
-            out.append(dict(name=item, value=value, patch_key=patch_key))
+            out.append(
+                dict(name=item, value=value, patch_key=patch_key, spool_key=spool_key)
+            )
         return out
 
-    def _reshape_coords(self, patch_info, coord_info, patch_key):
+    def _reshape_coords(self, patch_info, coord_info, patch_key, spool_key):
         """Move some coord info over to patch info."""
         patch_info["dims"] = coord_info.pop("dims")
         coord_list = list(coord_info["coord_map"].values())
         for coord in coord_list:
             coord["patch_key"] = patch_key  # ensure patch key is in coord.
+            coord["spool_key"] = spool_key  # ensure spool key is in coord.
         return coord_list
 
     def to_patch_coords_attrs_info(
         self,
         patch_key,
+        spool_key=0,
     ) -> tuple[list[dict], list[dict], list[dict]]:
         """
         Convert the PatchSummary to three lists of dicts.
@@ -456,8 +460,7 @@ def to_patch_coords_attrs_info(
         attrs = self.attrs.model_dump(exclude_unset=True)
         coords = self.coords.model_dump(exclude_unset=True)
         patch_info = self.data.model_dump(exclude_unset=True)
-
-        patch_info["patch_key"] = patch_key
-        attrs = self._attrs_to_patch_info(attrs, patch_info, patch_key)
-        coords = self._reshape_coords(patch_info, coords, patch_key)
+        patch_info["patch_key"], patch_info["spool_key"] = patch_key, spool_key
+        attrs = self._attrs_to_patch_info(attrs, patch_info, patch_key, spool_key)
+        coords = self._reshape_coords(patch_info, coords, patch_key, spool_key)
         return patch_info, coords, attrs
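+
+    # For reference: each element of the returned attrs list is a flat
+    # mapping like {"name": ..., "value": ..., "patch_key": ..., "spool_key": ...};
+    # _patch_summary_to_dataframes (dascore/utils/pd.py) later stacks these
+    # into the attribute dataframe.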
diff --git a/dascore/core/spool.py b/dascore/core/spool.py
index 796938a9..00fff262 100644
--- a/dascore/core/spool.py
+++ b/dascore/core/spool.py
@@ -77,11 +77,9 @@ def __rich__(self):
         text += Text(" Patches)") if patch_len != 1 else Text(" Patch)")
         return text
 
-    def __str__(self):
+    def __repr__(self):
         return str(self.__rich__())
 
-    __repr__ = __str__
-
     def __eq__(self, other) -> bool:
         """Simple equality checks on spools."""
 
@@ -338,7 +336,7 @@ def viz(self):
             "'Spool' has no 'viz' namespace. "
             "Apply 'viz' on a Patch object. "
             "(you can merge a subset of the spool into a single patch using "
-            "the Chunk function. i.e., spool.chunk(time=None)[0].viz.waterfall())"
+            "the Chunk function. e.g., spool.chunk(time=None)[0].viz.waterfall())"
         )
         raise AttributeError(msg)
 
diff --git a/dascore/io/core.py b/dascore/io/core.py
index 3ab3a36b..0ade3584 100644
--- a/dascore/io/core.py
+++ b/dascore/io/core.py
@@ -9,10 +9,9 @@
 import os.path
 import warnings
 from collections import defaultdict
-from collections.abc import Generator, Iterable
+from collections.abc import Generator
 from functools import cache, cached_property, wraps
 from importlib.metadata import entry_points
-from itertools import chain
 from pathlib import Path
 from typing import Literal, get_type_hints
 
@@ -36,6 +35,7 @@
 from dascore.utils.io import IOResourceManager, get_handle_from_resource
 from dascore.utils.mapping import FrozenDict
 from dascore.utils.misc import cached_method, iterate, warn_or_raise
+from dascore.utils.pd import _patch_summary_to_dataframes
 from dascore.utils.progress import track
 
@@ -662,23 +662,6 @@ def _assemble_summary_df(patch_df, coord_df, attr_df):
     return out.assign(**defaults)
 
-
-def _patch_summary_to_dataframes(
-    patch_summaries: Iterable[dc.PatchSummary],
-) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
-    """Convert a sequence of Patch Summaries to dataframes."""
-    patch_list, coord_list, attr_list = [], [], []
-    for num, summary in enumerate(iterate(patch_summaries)):
-        patch_in, coord_in, attr_in = summary.to_patch_coords_attrs_info(num)
-        patch_list.append(patch_in)
-        coord_list.append(coord_in)
-        attr_list.append(attr_in)
-    patch_df = pd.DataFrame(patch_list).set_index("patch_key")
-    # The coords and attrs are nested lists so we need to flatten them.
-    coord_df = pd.DataFrame(list(chain.from_iterable(coord_list)))
-    attr_df = pd.DataFrame(list(chain.from_iterable(attr_list)))
-    return patch_df, coord_df, attr_df
-
-
 def _iterate_scan_inputs(patch_source, ext, mtime, include_directories=True, **kwargs):
     """Yield scan candidates."""
     for el in iterate(patch_source):
 
diff --git a/dascore/utils/duck.py b/dascore/utils/duck.py
new file mode 100644
index 00000000..0f11a1cc
--- /dev/null
+++ b/dascore/utils/duck.py
@@ -0,0 +1,127 @@
+"""
+DuckDB utils.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from contextlib import contextmanager
+
+import duckdb
+import pandas as pd
+
+import dascore as dc
+from dascore.utils.pd import _patch_summary_to_dataframes
+
+
+def make_schema_str(schema_list):
+    """Build the SQL column specification string for a schema."""
+    columns = (" ".join(x) for x in schema_list)
+    return f"({', '.join(columns)})"
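+
+# For example, a hypothetical two-column schema renders as:
+#
+#     make_schema_str((("patch_key", "INTEGER"), ("name", "VARCHAR")))
+#     # -> "(patch_key INTEGER, name VARCHAR)"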
+
+
+class DuckIndexer:
+    """
+    A class to encapsulate DuckDB interactions for spool indexing.
+    """
+
+    # The schema for most data values; a union type flexible enough to
+    # avoid making multiple tables per dtype.
+    # Note: as of duckdb v1.1.3 intervals with ns precision are not supported.
+    _flexible_vals = "UNION(str VARCHAR, int LONG, float DOUBLE, dt TIMESTAMP_NS)"
+
+    # Primary key columns for the patch table; reference keys for the others.
+    _patch_keys = ("patch_key", "spool_key")
+
+    # The names of the source tables and their keys/schema.
+    _schema = {
+        "patch_source": (
+            ("patch_key", "INTEGER"),
+            ("spool_key", "INTEGER"),
+            ("data_units", "VARCHAR"),
+            ("ndims", "INTEGER"),
+            ("data_shape", "INTEGER[]"),
+            ("data_dtype", "VARCHAR"),
+            ("dims", "VARCHAR[]"),
+            ("coords", "VARCHAR[]"),
+            ("format_name", "VARCHAR"),
+            ("format_version", "VARCHAR"),
+            ("path", "VARCHAR"),
+            ("acquisition_id", "VARCHAR"),
+            ("tag", "VARCHAR"),
+        ),
+        "coord_source": (
+            ("patch_key", "INTEGER"),
+            ("spool_key", "INTEGER"),
+            ("name", "VARCHAR"),
+            ("shape", "INTEGER[]"),
+            ("dtype", "VARCHAR"),
+            ("ndims", "INTEGER"),
+            ("units", "VARCHAR"),
+            ("dims", "VARCHAR[]"),
+            ("start", _flexible_vals),
+            ("stop", _flexible_vals),
+            ("step", _flexible_vals),
+        ),
+        "attr_source": (
+            ("patch_key", "INTEGER"),
+            ("spool_key", "INTEGER"),
+            ("name", "VARCHAR"),
+            ("value", _flexible_vals),
+        ),
+    }
+
+    # SQL type Schema for patch table.
+    _patch_schema = ()
+
+    # SQL type Schema for attribute table.
+    _attr_schema = ()
+
+    # SQL type Schema for coordinate table.
+    _coord_schema = ()
+
+    def __init__(self, connection="", **kwargs):
+        self._connection = connection
+        self._kwargs = kwargs
+
+    def __repr__(self):
+        return f"DuckIndexer({self._connection}, {self._kwargs})"
+
+    @contextmanager
+    def connection(self):
+        """A context manager to create (and close) the connection."""
+        with duckdb.connect(self._connection, **self._kwargs) as conn:
+            # Ensure the tables exist. This runs for every connection
+            # because in-memory databases start empty each time, and
+            # CREATE TABLE IF NOT EXISTS makes it idempotent.
+            self._add_spool_tables(conn)
+            yield conn
+
+    def _add_spool_tables(self, conn):
+        """Add the tables (schema) to the database if they aren't defined."""
+        for table_name, schema in self._schema.items():
+            schema_str = make_schema_str(schema)
+            conn.sql(f"CREATE TABLE IF NOT EXISTS {table_name} {schema_str};")
+
+    def get_table(self, name) -> pd.DataFrame:
+        """Retrieve a table as a dataframe."""
+        with self.connection() as conn:
+            out = conn.sql(f"SELECT * FROM {name}").df()
+        return out
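+
+    # Sketch of querying these tables directly (hypothetical usage; assumes
+    # summaries have already been inserted):
+    #
+    #     with indexer.connection() as conn:
+    #         df = conn.sql(
+    #             "SELECT p.path, a.name, a.value FROM patch_source p "
+    #             "JOIN attr_source a USING (patch_key, spool_key)"
+    #         ).df()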
" "ON CONFLICT DO UPDATE SET " + ) + conn.sql(cmd_str, [df]) + + def insert_summaries(self, summaries: Sequence[dc.PatchSummary]): + """Insert the Patch Summaries into the duck index.""" + patch, coord, attr = _patch_summary_to_dataframes(summaries) + breakpoint() diff --git a/dascore/utils/misc.py b/dascore/utils/misc.py index 126fa2e8..903f3cd6 100644 --- a/dascore/utils/misc.py +++ b/dascore/utils/misc.py @@ -208,7 +208,7 @@ def iterate(obj): class CacheDescriptor: - """A descriptor for storing infor in an instance cache (mapping).""" + """A descriptor for storing info in an instance cache (mapping).""" def __init__(self, cache_name, func_name, args=None, kwargs=None): self._cache_name = cache_name diff --git a/dascore/utils/pd.py b/dascore/utils/pd.py index 1bef0757..ad1015b6 100644 --- a/dascore/utils/pd.py +++ b/dascore/utils/pd.py @@ -5,8 +5,9 @@ import fnmatch import os from collections import defaultdict -from collections.abc import Collection, Mapping, Sequence +from collections.abc import Collection, Iterable, Mapping, Sequence from functools import cache +from itertools import chain import numpy as np import pandas as pd @@ -16,7 +17,7 @@ from dascore.constants import PatchType from dascore.core.attrs import PatchAttrs from dascore.exceptions import ParameterError -from dascore.utils.misc import sanitize_range_param +from dascore.utils.misc import iterate, sanitize_range_param from dascore.utils.time import to_datetime64, to_timedelta64 @@ -554,3 +555,21 @@ def rolling_df(df, window, step=None, axis=0, center=False): """ df = df if not axis else df.T # silly deprecated axis argument. return df.rolling(window=window, step=step, center=center) + + +def _patch_summary_to_dataframes( + patch_summaries: Iterable[dc.PatchSummary], +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Convert a sequence of Patch Summaries to dataframes.""" + patch_list, coord_list, attr_list = [], [], [] + for num, summary in enumerate(iterate(patch_summaries)): + summary = summary.to_summary() # Ensure we have a summary + patch_in, coord_in, attr_in = summary.to_patch_coords_attrs_info(num) + patch_list.append(patch_in) + coord_list.append(coord_in) + attr_list.append(attr_in) + patch_df = pd.DataFrame(patch_list).set_index("patch_key") + # The coords and attrs are nested lists so we need to flatten them. + coord_df = pd.DataFrame(list(chain.from_iterable(coord_list))) + attr_df = pd.DataFrame(list(chain.from_iterable(attr_list))) + return patch_df, coord_df, attr_df diff --git a/tests/test_utils/test_duck.py b/tests/test_utils/test_duck.py new file mode 100644 index 00000000..c8f0d7ef --- /dev/null +++ b/tests/test_utils/test_duck.py @@ -0,0 +1,31 @@ +""" +Tests for duckdb. +""" + +import pandas as pd +import pytest + +from dascore.utils.duck import DuckIndexer + + +@pytest.fixture(scope="class") +def duck_indexer(): + """The default duck indexer fixture.""" + return DuckIndexer() + + +class TestDuckIndexer: + """Basic tests for DuckIndexer.""" + + def test_schema_created(self, duck_indexer): + """Iterate the expected tables and ensure they exist.""" + for table_name, schema in duck_indexer._schema.items(): + cols = [x[0] for x in schema] + df = duck_indexer.get_table(table_name) + assert isinstance(df, pd.DataFrame) + assert set(df.columns) == set(cols) + + def test_insert_summary(self, duck_indexer, random_patch): + """Ensure we can insert a summary""" + duck_indexer.insert_summaries(random_patch) + breakpoint()