Skip to content

Commit

Permalink
simpler SQLite querying with user-defined functions, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
xrotwang committed Jan 20, 2025
1 parent 6df05df commit 222da07
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
The `pycldf` package adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).


## Unreleased

- Added a utility function to query SQLite DBs using user-defined functions, aggregates or collations.


## [1.40.4] - 2025-01-15

- Fixed issue where validator reports invalid data sets as valid if a logger is supplied
Expand Down
32 changes: 32 additions & 0 deletions src/pycldf/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
);
"""
import typing
import inspect
import pathlib
import functools
import collections
Expand Down Expand Up @@ -369,3 +370,34 @@ def to_cldf(self, dest, mdname='cldf-metadata.json', coordinate_precision=4) ->
except KeyError:
assert table_type == self.source_table_name, table_type
return self.dataset.write_metadata(dest / mdname)


def query(conn,
sql: str,
params=None,
functions=None,
aggregates=None,
collations=None) -> typing.Generator[typing.Any, None, None]:
for func in functions or []:
if isinstance(func, tuple):
name, func = func
else:
name = func.__name__
conn.create_function(name, len(inspect.signature(func).parameters), func)

for cls in aggregates or []:
if isinstance(cls, tuple):
name, cls = cls
else:
name = cls.__name__
conn.create_aggregate(name, len(inspect.signature(cls.step).parameters) - 1, cls)

for func in collations or []:
if isinstance(func, tuple):
name, func = func
else:
name = func.__name__
assert len(inspect.signature(func).parameters) == 2
conn.create_collation(name, func)

return conn.execute(sql, params or ()).fetchall()
62 changes: 61 additions & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from pycldf.dataset import Dataset, Generic, StructureDataset
from pycldf.db import Database, translate, TableTranslation
from pycldf.db import Database, translate, TableTranslation, query


@pytest.fixture
Expand All @@ -18,6 +18,66 @@ def ds_sd(tmp_path):
return StructureDataset.in_dir(tmp_path)


@pytest.fixture
def conn():
con = sqlite3.connect(":memory:")
cur = con.execute("CREATE TABLE test(x, y)")
values = [
("a", 4),
("b", 5),
("c", 3),
("d", 8),
("e", 1),
]
cur.executemany("INSERT INTO test VALUES(?, ?)", values)
return con


def test_query(conn):
def double(x):
return x + x

res = list(query(conn, "SELECT double(x) FROM test", functions=[double]))
assert res[0][0] == 'aa'

res = list(query(conn, "SELECT doppel(y) FROM test", functions=[('doppel', double)]))
assert res[0][0] == 8

class strsum:
def __init__(self):
self.res = ''

def step(self, value):
self.res += value

def finalize(self):
return self.res

res = list(query(conn, "SELECT strsum(x) FROM test", aggregates=[strsum]))
assert res[0][0] == 'abcde'

res = list(query(conn, "SELECT concat(x) FROM test", aggregates=[('concat', strsum)]))
assert res[0][0] == 'abcde'

def mysort(x, y):
order = 'cdbae'
ox = order.index(x)
oy = order.index(y)
if ox == oy:
return 0 # pragma: no cover
if ox < oy:
return -1
return 1

res = list(query(conn, "SELECT x FROM test ORDER BY x COLLATE mysort", collations=[mysort]))
assert res[0][0] == 'c'

res = list(query(conn,
"SELECT x FROM test ORDER BY x COLLATE my",
collations=[('my', mysort)]))
assert res[0][0] == 'c'


def test_db_geocoords():
item = dict(cldf_latitude=decimal.Decimal(3.123456))
assert pytest.approx(
Expand Down

0 comments on commit 222da07

Please sign in to comment.