feat: support namedtuple for struct types #462

Merged
merged 6 commits into from
May 12, 2025
40 changes: 31 additions & 9 deletions docs/docs/core/data_types.mdx
Expand Up @@ -53,23 +53,41 @@ The native Python type is always more permissive and can represent a superset of
you can choose whatever to use.
The native Python type is usually simpler.

### Struct Type
### Struct Types

A Struct has a bunch of fields, each with a name and a type.

In Python, a Struct type is represented by a [dataclass](https://docs.python.org/3/library/dataclasses.html),
and all fields must be annotated with a specific type. For example:
In Python, a Struct type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html)
or a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), with all fields annotated with a specific type.
Both options define a structured type with named fields, but they differ slightly:

- **Dataclass**: A flexible class-based structure, mutable by default, defined using the `@dataclass` decorator.
- **NamedTuple**: An immutable tuple-based structure, defined using `typing.NamedTuple`.

For example:

```python
from dataclasses import dataclass
from typing import NamedTuple
import datetime

# Using dataclass
@dataclass
class Person:
    first_name: str
    last_name
    last_name: str
    dob: datetime.date

# Using NamedTuple
class PersonTuple(NamedTuple):
    first_name: str
    last_name: str
    dob: datetime.date
```

Both `Person` and `PersonTuple` are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
Choose `dataclass` for mutable objects or when you need additional methods, and `NamedTuple` for immutable, lightweight structures.
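
The mutability difference between the two options can be illustrated with a short sketch. It exercises only standard Python behaviour, not CocoIndex itself, and the sample values are made up for illustration:

```python
from dataclasses import dataclass
from typing import NamedTuple
import datetime

@dataclass
class Person:
    first_name: str
    last_name: str
    dob: datetime.date

class PersonTuple(NamedTuple):
    first_name: str
    last_name: str
    dob: datetime.date

# A plain dataclass instance can be modified in place.
person = Person("Ada", "Lovelace", datetime.date(1815, 12, 10))
person.last_name = "King"

# A NamedTuple instance rejects attribute assignment.
person_nt = PersonTuple("Ada", "Lovelace", datetime.date(1815, 12, 10))
try:
    person_nt.last_name = "King"
    nt_mutation_failed = False
except AttributeError:
    nt_mutation_failed = True

# To "change" a NamedTuple, build a new instance with _replace().
updated_nt = person_nt._replace(last_name="King")
```

`_replace` returns a new tuple and leaves the original untouched, which is the usual way to derive a modified copy of an immutable record.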

### Table Types

A Table type models a collection of rows, each with multiple columns.
Expand All @@ -84,20 +102,24 @@ The row order of a KTable is not preserved.
Type of the first column (key column) must be a [key type](#key-types).

In Python, a KTable type is represented by `dict[K, V]`.
The `V` should be a dataclass, representing the value fields of each row.
For example, you can use `dict[str, Person]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date).
The `V` should be a struct type, either a `dataclass` or `NamedTuple`, representing the value fields of each row.
For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date).

Note that if you want to use a struct as the key, you need to annotate the struct with `@dataclass(frozen=True)`, so the values are immutable.
Note that if you want to use a struct as the key, you need to ensure the struct is immutable. For `dataclass`, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in.
For example:

```python
@dataclass(frozen=True)
class PersonKey:
    id_kind: str
    id: str

class PersonKeyTuple(NamedTuple):
    id_kind: str
    id: str
```

Then you can use `dict[PersonKey, Person]` to represent a KTable keyed by `PersonKey`.
Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by `PersonKey` or `PersonKeyTuple`.
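
One practical reason a key struct has to be immutable is that Python `dict` keys must be hashable. The sketch below shows plain Python behaviour only, not CocoIndex internals. `MutableKey` is a hypothetical class added for contrast, and the sample values are invented:

```python
from dataclasses import dataclass
from typing import NamedTuple

@dataclass
class MutableKey:  # hypothetical: a default (non-frozen) dataclass
    id_kind: str
    id: str

@dataclass(frozen=True)
class PersonKey:
    id_kind: str
    id: str

class PersonKeyTuple(NamedTuple):
    id_kind: str
    id: str

# A default dataclass defines __eq__ but not __hash__, so it cannot be a dict key.
try:
    {MutableKey("passport", "X1"): "Alice"}
    mutable_key_hashable = True
except TypeError:
    mutable_key_hashable = False

# Frozen dataclasses and NamedTuples are hashable and compare by value.
by_key = {PersonKey("passport", "X1"): "Alice"}
by_key_nt = {PersonKeyTuple("passport", "X1"): "Alice"}
```

Lookups with a freshly constructed, equal key succeed in both dictionaries, because hashing and equality are derived from the field values.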


#### LTable
Expand All @@ -118,4 +140,4 @@ Currently, the following types are key types
- Range
- Uuid
- Date
- Struct with all fields being key types
- Struct with all fields being key types (using `@dataclass(frozen=True)` or `NamedTuple`)
44 changes: 33 additions & 11 deletions python/cocoindex/convert.py
Expand Up @@ -8,13 +8,15 @@

from enum import Enum
from typing import Any, Callable, get_origin
from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME
from .typing import analyze_type_info, encode_enriched_type, is_namedtuple_type, TABLE_TYPES, KEY_FIELD_NAME


def encode_engine_value(value: Any) -> Any:
    """Encode a Python value to an engine value."""
    if dataclasses.is_dataclass(value):
        return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
    if is_namedtuple_type(type(value)):
        return [encode_engine_value(getattr(value, name)) for name in value._fields]
    if isinstance(value, (list, tuple)):
        return [encode_engine_value(v) for v in value]
    if isinstance(value, dict):
Expand Down Expand Up @@ -55,16 +57,16 @@ def make_engine_value_decoder(
f"Type mismatch for `{''.join(field_path)}`: "
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")

    if dst_type_info.dataclass_type is not None:
    if dst_type_info.struct_type is not None:
        return _make_engine_struct_value_decoder(
            field_path, src_type['fields'], dst_type_info.dataclass_type)
            field_path, src_type['fields'], dst_type_info.struct_type)

    if src_type_kind in TABLE_TYPES:
        field_path.append('[*]')
        elem_type_info = analyze_type_info(dst_type_info.elem_type)
        if elem_type_info.dataclass_type is None:
        if elem_type_info.struct_type is None:
            raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
                             f"declared `{dst_type_info.kind}`, a dataclass type expected")
                             f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected")
        engine_fields_schema = src_type['row']['fields']
        if elem_type_info.key_type is not None:
            key_field_schema = engine_fields_schema[0]
Expand All @@ -73,14 +75,14 @@ def make_engine_value_decoder(
                field_path, key_field_schema['type'], elem_type_info.key_type)
            field_path.pop()
            value_decoder = _make_engine_struct_value_decoder(
                field_path, engine_fields_schema[1:], elem_type_info.dataclass_type)
                field_path, engine_fields_schema[1:], elem_type_info.struct_type)
            def decode(value):
                if value is None:
                    return None
                return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
        else:
            elem_decoder = _make_engine_struct_value_decoder(
                field_path, engine_fields_schema, elem_type_info.dataclass_type)
                field_path, engine_fields_schema, elem_type_info.struct_type)
            def decode(value):
                if value is None:
                    return None
Expand All @@ -96,19 +98,39 @@ def decode(value):
def _make_engine_struct_value_decoder(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    dst_dataclass_type: type,
    dst_struct_type: type,
) -> Callable[[list], Any]:
    """Make a decoder from an engine field values to a Python value."""

    src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}

    is_dataclass = dataclasses.is_dataclass(dst_struct_type)
    is_namedtuple = is_namedtuple_type(dst_struct_type)

    if is_dataclass:
        parameters = inspect.signature(dst_struct_type).parameters
    elif is_namedtuple:
        defaults = getattr(dst_struct_type, '_field_defaults', {})
        parameters = {
            name: inspect.Parameter(
                name=name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=defaults.get(name, inspect.Parameter.empty),
                annotation=dst_struct_type.__annotations__.get(name, inspect.Parameter.empty)
            )
            for name in dst_struct_type._fields
        }
    else:
        raise ValueError(f"Unsupported struct type: {dst_struct_type}")

    def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
        src_idx = src_name_to_idx.get(name)
        if src_idx is not None:
            field_path.append(f'.{name}')
            field_decoder = make_engine_value_decoder(
                field_path, src_fields[src_idx]['type'], param.annotation)
            field_path.pop()
            return lambda values: field_decoder(values[src_idx])
            return lambda values: field_decoder(values[src_idx]) if len(values) > src_idx else param.default

        default_value = param.default
        if default_value is inspect.Parameter.empty:
Expand All @@ -119,9 +141,9 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis

    field_value_decoder = [
        make_closure_for_value(name, param)
        for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
        for (name, param) in parameters.items()]

    return lambda values: dst_dataclass_type(
    return lambda values: dst_struct_type(
        *(decoder(values) for decoder in field_value_decoder))

def dump_engine_object(v: Any) -> Any:
Expand Down
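
The diff imports `is_namedtuple_type` from `.typing` but does not show its body. The following is a minimal sketch of how such a check and the parameter-building branch above could work in isolation. `is_namedtuple_type_sketch`, `namedtuple_parameters`, and `OrderNT` are hypothetical names chosen for illustration, and the real helper in `cocoindex.typing` may be implemented differently:

```python
import inspect
from typing import Any, NamedTuple

def is_namedtuple_type_sketch(t: Any) -> bool:
    # Assumption: a NamedTuple class is a tuple subclass that exposes `_fields`.
    return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")

def namedtuple_parameters(t: type) -> dict[str, inspect.Parameter]:
    # Mirror the elif branch in the diff: one Parameter per field, in field order,
    # carrying the declared default (if any) and the type annotation.
    defaults = getattr(t, "_field_defaults", {})
    annotations = getattr(t, "__annotations__", {})
    return {
        name: inspect.Parameter(
            name=name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=defaults.get(name, inspect.Parameter.empty),
            annotation=annotations.get(name, inspect.Parameter.empty),
        )
        for name in t._fields
    }

class OrderNT(NamedTuple):  # hypothetical example type
    order_id: str
    price: float
    extra_field: str = "default_extra"

params = namedtuple_parameters(OrderNT)
```

Fields without a declared default carry `inspect.Parameter.empty`, which lets the decoder distinguish between "no default" and a default of `None`.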
73 changes: 70 additions & 3 deletions python/cocoindex/tests/test_convert.py
@@ -1,11 +1,12 @@
import uuid
import datetime
from dataclasses import dataclass, make_dataclass
from typing import NamedTuple, Literal
import pytest
import cocoindex
from cocoindex.typing import encode_enriched_type
from cocoindex.convert import encode_engine_value, make_engine_value_decoder
from typing import Literal

@dataclass
class Order:
    order_id: str
Expand Down Expand Up @@ -33,6 +34,17 @@ class NestedStruct:
    orders: list[Order]
    count: int = 0

class OrderNamedTuple(NamedTuple):
    order_id: str
    name: str
    price: float
    extra_field: str = "default_extra"

class CustomerNamedTuple(NamedTuple):
    name: str
    order: OrderNamedTuple
    tags: list[Tag] | None = None

def build_engine_value_decoder(engine_type_in_py, python_type=None):
    """
    Helper to build a converter for the given engine-side type (as represented in Python).
Expand Down Expand Up @@ -62,10 +74,16 @@ def test_encode_engine_value_date_time_types():
def test_encode_engine_value_struct():
    order = Order(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]

    order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
    assert encode_engine_value(order_nt) == ["O123", "mixed nuts", 25.0, "default_extra"]

def test_encode_engine_value_list_of_structs():
    orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
    assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]

    orders_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)]
    assert encode_engine_value(orders_nt) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]

def test_encode_engine_value_struct_with_list():
    basket = Basket(items=["apple", "banana"])
Expand All @@ -74,6 +92,9 @@ def test_encode_engine_value_struct_with_list():
def test_encode_engine_value_nested_struct():
    customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
    assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

    customer_nt = CustomerNamedTuple(name="Alice", order=OrderNamedTuple("O1", "item1", 10.0))
    assert encode_engine_value(customer_nt) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]

def test_encode_engine_value_empty_list():
    assert encode_engine_value([]) == []
Expand Down Expand Up @@ -103,38 +124,62 @@ def test_make_engine_value_decoder_basic_types():
@pytest.mark.parametrize(
    "data_type, engine_val, expected",
    [
        # All fields match
        # All fields match (dataclass)
        (Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
        # All fields match (NamedTuple)
        (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
        # Extra field in engine value (should ignore extra)
        (Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")),
        (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
        # Fewer fields in engine value (should fill with default)
        (Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")),
        (OrderNamedTuple, ["O123", "mixed nuts", 0.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra")),
        # More fields in engine value (should ignore extra)
        (Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")),
        (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected")),
        # Truly extra field (should ignore the fifth field)
        (Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")),
        (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
        # Missing optional field in engine value (tags=None)
        (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
        (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)),
        # Extra field in engine value for Customer (should ignore)
        (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
        (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
        # Missing optional field with default
        (Order, ["O123", "mixed nuts", 25.0], Order("O123", "mixed nuts", 25.0, "default_extra")),
        (OrderNamedTuple, ["O123", "mixed nuts", 25.0], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
        # Partial optional fields
        (Customer, ["Alice", ["O1", "item1", 10.0]], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
        (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0]], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)),
    ]
)
def test_struct_decoder_cases(data_type, engine_val, expected):
    decoder = build_engine_value_decoder(data_type)
    assert decoder(engine_val) == expected

def test_make_engine_value_decoder_collections():
    # List of structs
    # List of structs (dataclass)
    decoder = build_engine_value_decoder(list[Order])
    engine_val = [
        ["O1", "item1", 10.0, "default_extra"],
        ["O2", "item2", 20.0, "default_extra"]
    ]
    assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]

    # List of structs (NamedTuple)
    decoder = build_engine_value_decoder(list[OrderNamedTuple])
    assert decoder(engine_val) == [OrderNamedTuple("O1", "item1", 10.0, "default_extra"), OrderNamedTuple("O2", "item2", 20.0, "default_extra")]

    # Struct with list field
    decoder = build_engine_value_decoder(Customer)
    engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
    assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])

    # NamedTuple with list field
    decoder = build_engine_value_decoder(CustomerNamedTuple)
    assert decoder(engine_val) == CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])

    # Struct with struct field
    decoder = build_engine_value_decoder(NestedStruct)
    engine_val = [
Expand Down Expand Up @@ -239,6 +284,13 @@ def test_roundtrip_ltable():
    assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t)(encoded)
    assert decoded == value

    t_nt = list[OrderNamedTuple]
    value_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)]
    encoded = encode_engine_value(value_nt)
    assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t_nt)(encoded)
    assert decoded == value_nt

def test_roundtrip_ktable_str_key():
    t = dict[str, Order]
Expand All @@ -247,6 +299,13 @@ def test_roundtrip_ktable_str_key():
    assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t)(encoded)
    assert decoded == value

    t_nt = dict[str, OrderNamedTuple]
    value_nt = {"K1": OrderNamedTuple("O1", "item1", 10.0), "K2": OrderNamedTuple("O2", "item2", 20.0)}
    encoded = encode_engine_value(value_nt)
    assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t_nt)(encoded)
    assert decoded == value_nt

def test_roundtrip_ktable_struct_key():
    @dataclass(frozen=True)
Expand All @@ -261,6 +320,14 @@ class OrderKey:
                       [["B", 4], "O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t)(encoded)
    assert decoded == value

    t_nt = dict[OrderKey, OrderNamedTuple]
    value_nt = {OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0), OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0)}
    encoded = encode_engine_value(value_nt)
    assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"],
                       [["B", 4], "O2", "item2", 20.0, "default_extra"]]
    decoded = build_engine_value_decoder(t_nt)(encoded)
    assert decoded == value_nt

IntVectorType = cocoindex.Vector[int, Literal[5]]
def test_vector_as_vector() -> None:
Expand Down