-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b3847db
commit b7ef2bb
Showing
7 changed files
with
195 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
""" | ||
DuckDB utils. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from contextlib import contextmanager | ||
|
||
import duckdb | ||
import pandas as pd | ||
|
||
import dascore as dc | ||
from dascore.utils.pd import _patch_summary_to_dataframes | ||
|
||
|
||
def make_schema_str(schema_list): | ||
"""Make the string of schema.""" | ||
commas = (" ".join(x) for x in schema_list) | ||
return f"({', '.join(commas)})" | ||
|
||
|
||
class DuckIndexer: | ||
""" | ||
A class to encapsulate DuckDB interactions for spool indexing. | ||
""" | ||
|
||
# If the tables have been tried to be created. | ||
_tried_create_tables = False | ||
|
||
# The schema for most data values; optimally flexible to avoid making | ||
# multiple tables per dtype. | ||
# Note: as of duckdb v1.1.3 intervals with ns precision are not supported. | ||
_flexible_vals = "UNION(str VARCHAR, int LONG, float DOUBLE, dt TIMESTAMP_NS)" | ||
|
||
# Primary keys for patch table, secondary for others. | ||
_patch_keys = ("patch_key", "spool_key") | ||
|
||
# The names of the source tables and their keys/schema. | ||
_schema = { | ||
"patch_source": ( | ||
("patch_key", "INTEGER"), | ||
("spool_key", "INTEGER"), | ||
("data_units", "VARCHAR"), | ||
("ndims", "INTEGER"), | ||
("data_shape", "INTEGER[]"), | ||
("data_dtype", "VARCHAR"), | ||
("dims", "VARCHAR[]"), | ||
("coords", "VARCHAR[]"), | ||
("format_name", "VARCHAR"), | ||
("format_version", "VARCHAR"), | ||
("path", "VARCHAR"), | ||
("acquisition_id", "VARCHAR"), | ||
("tag", "VARCHAR"), | ||
), | ||
"coord_source": ( | ||
("patch_key", "INTEGER"), | ||
("spool_key", "INTEGER"), | ||
("name", "VARCHAR"), | ||
("shape", "INTEGER[]"), | ||
("dtype", "VARCHAR"), | ||
("ndims", "INTEGER"), | ||
("units", "VARCHAR"), | ||
("dims", "VARCHAR[]"), | ||
("start", _flexible_vals), | ||
("stop", _flexible_vals), | ||
("step", _flexible_vals), | ||
), | ||
"attr_source": ( | ||
("patch_key", "INTEGER"), | ||
("spool_key", "INTEGER"), | ||
("name", "VARCHAR"), | ||
("value", _flexible_vals), | ||
), | ||
} | ||
|
||
# SQL type Schema for patch table. | ||
_patch_schema = () | ||
|
||
# SQL type Schema for attribute table. | ||
_attr_schema = () | ||
|
||
# SQL type Schema for coordinate table. | ||
_coord_schema = () | ||
|
||
def __init__(self, connection="", **kwargs): | ||
self._connection = connection | ||
self._kwargs = kwargs | ||
|
||
def __repr__(self): | ||
out = f"DuckDB indexer ({self._connection}, {self._kwargs})) " | ||
return out | ||
|
||
@contextmanager | ||
def connection(self): | ||
"""A context manager to create (and close) the connection.""" | ||
with duckdb.connect(self._connection, **self._kwargs) as conn: | ||
# Ensure tables have been created. | ||
if not self._tried_create_tables: | ||
self._add_spool_tables(conn) | ||
yield conn | ||
conn.close() | ||
|
||
def _add_spool_tables(self, conn): | ||
"""Add the tables (schema) to database if they aren't defined.""" | ||
for table_name, schema in self._schema.items(): | ||
schema_str = make_schema_str(schema) | ||
conn.sql(f"CREATE TABLE IF NOT EXISTS {table_name} {schema_str};") | ||
|
||
def get_table(self, name) -> pd.DataFrame: | ||
"""Retrieve a table name as a dataframe.""" | ||
with self.connection() as conn: | ||
out = conn.sql(f"SELECT * FROM {name}").df() | ||
return out | ||
|
||
def upsert_table(self, df, table): | ||
"""Insert or update a dataframe.""" | ||
with self.connection() as conn: | ||
cmd_str = ( | ||
f"INSERT INTO {table} " "SELECT * FROM ? " "ON CONFLICT DO UPDATE SET " | ||
) | ||
conn.sql(cmd_str, [df]) | ||
|
||
def insert_summaries(self, summaries: Sequence[dc.PatchSummary]): | ||
"""Insert the Patch Summaries into the duck index.""" | ||
patch, coord, attr = _patch_summary_to_dataframes(summaries) | ||
breakpoint() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
""" | ||
Tests for duckdb. | ||
""" | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from dascore.utils.duck import DuckIndexer | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def duck_indexer(): | ||
"""The default duck indexer fixture.""" | ||
return DuckIndexer() | ||
|
||
|
||
class TestDuckIndexer: | ||
"""Basic tests for DuckIndexer.""" | ||
|
||
def test_schema_created(self, duck_indexer): | ||
"""Iterate the expected tables and ensure they exist.""" | ||
for table_name, schema in duck_indexer._schema.items(): | ||
cols = [x[0] for x in schema] | ||
df = duck_indexer.get_table(table_name) | ||
assert isinstance(df, pd.DataFrame) | ||
assert set(df.columns) == set(cols) | ||
|
||
def test_insert_summary(self, duck_indexer, random_patch): | ||
"""Ensure we can insert a summary""" | ||
duck_indexer.insert_summaries(random_patch) | ||
breakpoint() |