Skip to content

Commit

Permalink
create duck utils
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Jan 18, 2025
1 parent b3847db commit b7ef2bb
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 34 deletions.
19 changes: 11 additions & 8 deletions dascore/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ class PatchSummary(DascoreBaseModel):
"path",
"format_version",
"format_name",
"acquistion_id",
"acquisition_id",
"tag",
)

Expand All @@ -425,28 +425,32 @@ def to_summary(
"""
return self

def _attrs_to_patch_info(self, attr_info, patch_info, patch_key):
def _attrs_to_patch_info(self, attr_info, patch_info, patch_key, spool_key):
"""Transfer some attrs to the patch info."""
out = []
for key in self._attrs_to_patch_keys:
if value := attr_info.pop(key, None):
patch_info[key] = value
# flatten remaining attrs
for item, value in attr_info.items():
out.append(dict(name=item, value=value, patch_key=patch_key))
out.append(
dict(name=item, value=value, patch_key=patch_key, spool_key=spool_key)
)
return out

def _reshape_coords(self, patch_info, coord_info, patch_key):
def _reshape_coords(self, patch_info, coord_info, patch_key, spool_key):
"""Move some coord info over to patch info."""
patch_info["dims"] = coord_info.pop("dims")
coord_list = list(coord_info["coord_map"].values())
for coord in coord_list:
coord["patch_key"] = patch_key # ensure patch key is in coord.
coord["spool_key"] = spool_key # ensure spool key is in coord.
return coord_list

def to_patch_coords_attrs_info(
self,
patch_key,
spool_key=0,
) -> tuple[list[dict], list[dict], list[dict]]:
"""
Convert the PatchSummary to three lists of dicts.
Expand All @@ -456,8 +460,7 @@ def to_patch_coords_attrs_info(
attrs = self.attrs.model_dump(exclude_unset=True)
coords = self.coords.model_dump(exclude_unset=True)
patch_info = self.data.model_dump(exclude_unset=True)

patch_info["patch_key"] = patch_key
attrs = self._attrs_to_patch_info(attrs, patch_info, patch_key)
coords = self._reshape_coords(patch_info, coords, patch_key)
patch_info["patch_key"], patch_info["spool_key"] = patch_key, spool_key
attrs = self._attrs_to_patch_info(attrs, patch_info, patch_key, spool_key)
coords = self._reshape_coords(patch_info, coords, patch_key, spool_key)
return patch_info, coords, attrs
6 changes: 2 additions & 4 deletions dascore/core/spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,9 @@ def __rich__(self):
text += Text(" Patches)") if patch_len != 1 else Text(" Patch)")
return text

def __str__(self):
def __repr__(self):
return str(self.__rich__())

__repr__ = __str__

def __eq__(self, other) -> bool:
"""Simple equality checks on spools."""

Expand Down Expand Up @@ -338,7 +336,7 @@ def viz(self):
"'Spool' has no 'viz' namespace. "
"Apply 'viz' on a Patch object. "
"(you can merge a subset of the spool into a single patch using "
"the Chunk function. i.e., spool.chunk(time=None)[0].viz.waterfall())"
"the Chunk function. e.g., spool.chunk(time=None)[0].viz.waterfall())"
)
raise AttributeError(msg)

Expand Down
21 changes: 2 additions & 19 deletions dascore/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import os.path
import warnings
from collections import defaultdict
from collections.abc import Generator, Iterable
from collections.abc import Generator
from functools import cache, cached_property, wraps
from importlib.metadata import entry_points
from itertools import chain
from pathlib import Path
from typing import Literal, get_type_hints

Expand All @@ -36,6 +35,7 @@
from dascore.utils.io import IOResourceManager, get_handle_from_resource
from dascore.utils.mapping import FrozenDict
from dascore.utils.misc import cached_method, iterate, warn_or_raise
from dascore.utils.pd import _patch_summary_to_dataframes
from dascore.utils.progress import track


Expand Down Expand Up @@ -662,23 +662,6 @@ def _assemble_summary_df(patch_df, coord_df, attr_df):
return out.assign(**defaults)


def _patch_summary_to_dataframes(
patch_summaries: Iterable[dc.PatchSummary],
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Convert a sequence of Patch Summaries to dataframes."""
patch_list, coord_list, attr_list = [], [], []
for num, summary in enumerate(iterate(patch_summaries)):
patch_in, coord_in, attr_in = summary.to_patch_coords_attrs_info(num)
patch_list.append(patch_in)
coord_list.append(coord_in)
attr_list.append(attr_in)
patch_df = pd.DataFrame(patch_list).set_index("patch_key")
# The coords and attrs are nested lists so we need to flatten them.
coord_df = pd.DataFrame(list(chain.from_iterable(coord_list)))
attr_df = pd.DataFrame(list(chain.from_iterable(attr_list)))
return patch_df, coord_df, attr_df


def _iterate_scan_inputs(patch_source, ext, mtime, include_directories=True, **kwargs):
"""Yield scan candidates."""
for el in iterate(patch_source):
Expand Down
127 changes: 127 additions & 0 deletions dascore/utils/duck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
DuckDB utils.
"""

from __future__ import annotations

from collections.abc import Sequence
from contextlib import contextmanager

import duckdb
import pandas as pd

import dascore as dc
from dascore.utils.pd import _patch_summary_to_dataframes


def make_schema_str(schema_list):
    """Build a parenthesized SQL column-definition string.

    Each entry of ``schema_list`` is a tuple of words (e.g. name and SQL
    type) joined by spaces; the entries are comma-joined and wrapped in
    parentheses, as expected by ``CREATE TABLE``.
    """
    columns = [" ".join(parts) for parts in schema_list]
    return "({})".format(", ".join(columns))


class DuckIndexer:
    """
    A class to encapsulate DuckDB interactions for spool indexing.

    Parameters
    ----------
    connection
        The connection string passed to ``duckdb.connect``. The default
        (an empty string) creates an in-memory database.
    **kwargs
        Additional keyword arguments passed to ``duckdb.connect``.
    """

    # If the tables have been tried to be created.
    # NOTE(review): this flag is currently never set to True because each
    # ``duckdb.connect("")`` yields a *fresh* in-memory database, so tables
    # must be (idempotently) re-created per connection. Setting it would
    # break the default in-memory usage — confirm intent before changing.
    _tried_create_tables = False

    # The schema for most data values; optimally flexible to avoid making
    # multiple tables per dtype.
    # Note: as of duckdb v1.1.3 intervals with ns precision are not supported.
    _flexible_vals = "UNION(str VARCHAR, int LONG, float DOUBLE, dt TIMESTAMP_NS)"

    # Primary keys for patch table, secondary for others.
    _patch_keys = ("patch_key", "spool_key")

    # The names of the source tables and their keys/schema.
    _schema = {
        "patch_source": (
            ("patch_key", "INTEGER"),
            ("spool_key", "INTEGER"),
            ("data_units", "VARCHAR"),
            ("ndims", "INTEGER"),
            ("data_shape", "INTEGER[]"),
            ("data_dtype", "VARCHAR"),
            ("dims", "VARCHAR[]"),
            ("coords", "VARCHAR[]"),
            ("format_name", "VARCHAR"),
            ("format_version", "VARCHAR"),
            ("path", "VARCHAR"),
            ("acquisition_id", "VARCHAR"),
            ("tag", "VARCHAR"),
        ),
        "coord_source": (
            ("patch_key", "INTEGER"),
            ("spool_key", "INTEGER"),
            ("name", "VARCHAR"),
            ("shape", "INTEGER[]"),
            ("dtype", "VARCHAR"),
            ("ndims", "INTEGER"),
            ("units", "VARCHAR"),
            ("dims", "VARCHAR[]"),
            ("start", _flexible_vals),
            ("stop", _flexible_vals),
            ("step", _flexible_vals),
        ),
        "attr_source": (
            ("patch_key", "INTEGER"),
            ("spool_key", "INTEGER"),
            ("name", "VARCHAR"),
            ("value", _flexible_vals),
        ),
    }

    # SQL type Schema for patch table.
    _patch_schema = ()

    # SQL type Schema for attribute table.
    _attr_schema = ()

    # SQL type Schema for coordinate table.
    _coord_schema = ()

    def __init__(self, connection="", **kwargs):
        self._connection = connection
        self._kwargs = kwargs

    def __repr__(self):
        # Balanced parentheses (the original had a stray trailing ")").
        return f"DuckDB indexer ({self._connection}, {self._kwargs})"

    @contextmanager
    def connection(self):
        """A context manager to create (and close) the connection.

        The ``with`` statement closes the connection on exit; no explicit
        ``close`` call is needed. Tables are (re-)created on each connection
        because an in-memory connection string yields a fresh database.
        """
        with duckdb.connect(self._connection, **self._kwargs) as conn:
            # Ensure tables have been created.
            if not self._tried_create_tables:
                self._add_spool_tables(conn)
            yield conn

    def _add_spool_tables(self, conn):
        """Add the tables (schema) to database if they aren't defined."""
        for table_name, schema in self._schema.items():
            schema_str = make_schema_str(schema)
            conn.sql(f"CREATE TABLE IF NOT EXISTS {table_name} {schema_str};")

    def get_table(self, name) -> pd.DataFrame:
        """Retrieve a table name as a dataframe."""
        with self.connection() as conn:
            out = conn.sql(f"SELECT * FROM {name}").df()
        return out

    def upsert_table(self, df, table):
        """Insert a dataframe into the named table.

        NOTE(review): a true upsert (``ON CONFLICT DO UPDATE``) requires the
        tables to declare PRIMARY KEY constraints, which the current schema
        does not; a plain insert is performed for now.
        """
        with self.connection() as conn:
            # Register the dataframe so SQL can reference it by name; a '?'
            # parameter placeholder cannot stand in for a relation.
            conn.register("_upsert_df", df)
            conn.sql(f"INSERT INTO {table} SELECT * FROM _upsert_df")
            conn.unregister("_upsert_df")

    def insert_summaries(self, summaries: Sequence[dc.PatchSummary]):
        """Insert the Patch Summaries into the duck index."""
        patch, coord, attr = _patch_summary_to_dataframes(summaries)
        # patch_key is the index of the patch dataframe; restore it as a
        # column so the frame lines up with the patch_source schema.
        # TODO confirm column order matches the declared table schemas.
        self.upsert_table(patch.reset_index(), "patch_source")
        self.upsert_table(coord, "coord_source")
        self.upsert_table(attr, "attr_source")
2 changes: 1 addition & 1 deletion dascore/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def iterate(obj):


class CacheDescriptor:
"""A descriptor for storing infor in an instance cache (mapping)."""
"""A descriptor for storing info in an instance cache (mapping)."""

def __init__(self, cache_name, func_name, args=None, kwargs=None):
self._cache_name = cache_name
Expand Down
23 changes: 21 additions & 2 deletions dascore/utils/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import fnmatch
import os
from collections import defaultdict
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Collection, Iterable, Mapping, Sequence
from functools import cache
from itertools import chain

import numpy as np
import pandas as pd
Expand All @@ -16,7 +17,7 @@
from dascore.constants import PatchType
from dascore.core.attrs import PatchAttrs
from dascore.exceptions import ParameterError
from dascore.utils.misc import sanitize_range_param
from dascore.utils.misc import iterate, sanitize_range_param
from dascore.utils.time import to_datetime64, to_timedelta64


Expand Down Expand Up @@ -554,3 +555,21 @@ def rolling_df(df, window, step=None, axis=0, center=False):
"""
df = df if not axis else df.T # silly deprecated axis argument.
return df.rolling(window=window, step=step, center=center)


def _patch_summary_to_dataframes(
    patch_summaries: Iterable[dc.PatchSummary],
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Convert a sequence of Patch Summaries to dataframes.

    Parameters
    ----------
    patch_summaries
        An iterable of objects with a ``to_summary`` method (e.g. patches
        or PatchSummary instances); a single object is also accepted.

    Returns
    -------
    A tuple ``(patch_df, coord_df, attr_df)`` where ``patch_df`` is indexed
    by ``patch_key`` and the coord/attr frames hold one row per
    coordinate/attribute.
    """
    patch_list, coord_list, attr_list = [], [], []
    for num, summary in enumerate(iterate(patch_summaries)):
        summary = summary.to_summary()  # Ensure we have a summary
        patch_in, coord_in, attr_in = summary.to_patch_coords_attrs_info(num)
        patch_list.append(patch_in)
        coord_list.append(coord_in)
        attr_list.append(attr_in)
    if patch_list:
        patch_df = pd.DataFrame(patch_list).set_index("patch_key")
    else:
        # Empty input: set_index on a column-less frame raises KeyError,
        # so build an empty frame with the expected index explicitly.
        patch_df = pd.DataFrame(columns=["patch_key"]).set_index("patch_key")
    # The coords and attrs are nested lists so we need to flatten them.
    coord_df = pd.DataFrame(list(chain.from_iterable(coord_list)))
    attr_df = pd.DataFrame(list(chain.from_iterable(attr_list)))
    return patch_df, coord_df, attr_df
31 changes: 31 additions & 0 deletions tests/test_utils/test_duck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Tests for duckdb.
"""

import pandas as pd
import pytest

from dascore.utils.duck import DuckIndexer


@pytest.fixture(scope="class")
def duck_indexer():
    """The default duck indexer fixture.

    Class-scoped so every test in a class shares one DuckIndexer
    instance (constructed with its default in-memory connection).
    """
    return DuckIndexer()


class TestDuckIndexer:
    """Basic tests for DuckIndexer."""

    def test_schema_created(self, duck_indexer):
        """Iterate the expected tables and ensure they exist."""
        for table_name, schema in duck_indexer._schema.items():
            cols = [x[0] for x in schema]
            df = duck_indexer.get_table(table_name)
            assert isinstance(df, pd.DataFrame)
            assert set(df.columns) == set(cols)

    def test_insert_summary(self, duck_indexer, random_patch):
        """Ensure we can insert a summary without error."""
        # Removed leftover breakpoint() which would hang the test run.
        duck_indexer.insert_summaries(random_patch)
        # Smoke check: the patch table is still readable after the insert.
        # NOTE(review): with the default in-memory connection each new
        # connection is a fresh database, so row contents are not asserted.
        df = duck_indexer.get_table("patch_source")
        assert isinstance(df, pd.DataFrame)

0 comments on commit b7ef2bb

Please sign in to comment.