
Commit aa9ab60

feat: support namedtuple for struct types (#462)
* feat: add support for defining struct types via namedtuples
* test: add unit tests for checking namedtuples' support
* docs: add corresponding instructions for using NamedTuple to define a struct
* docs: make immutability requirements for struct keys clearer
* refactor: merge dataclass and namedtuple as one uniform `struct_type`
* feat: handle default values for namedtuple params in decoder
1 parent 61c759c commit aa9ab60


4 files changed: +169 -40 lines changed


docs/docs/core/data_types.mdx

Lines changed: 31 additions & 9 deletions
@@ -53,23 +53,41 @@ The native Python type is always more permissive and can represent a superset of
5353
you can choose whatever to use.
5454
The native Python type is usually simpler.
5555

56-
### Struct Type
56+
### Struct Types
5757

5858
A Struct has a bunch of fields, each with a name and a type.
5959

60-
In Python, a Struct type is represented by a [dataclass](https://docs.python.org/3/library/dataclasses.html),
61-
and all fields must be annotated with a specific type. For example:
60+
In Python, a Struct type is represented by either a [dataclass](https://docs.python.org/3/library/dataclasses.html)
61+
or a [NamedTuple](https://docs.python.org/3/library/typing.html#typing.NamedTuple), with all fields annotated with a specific type.
62+
Both options define a structured type with named fields, but they differ slightly:
63+
64+
- **Dataclass**: A flexible class-based structure, mutable by default, defined using the `@dataclass` decorator.
65+
- **NamedTuple**: An immutable tuple-based structure, defined using `typing.NamedTuple`.
66+
67+
For example:
6268

6369
```python
6470
from dataclasses import dataclass
71+
from typing import NamedTuple
72+
import datetime
6573

74+
# Using dataclass
6675
@dataclass
6776
class Person:
6877
first_name: str
69-
last_name
78+
last_name: str
79+
dob: datetime.date
80+
81+
# Using NamedTuple
82+
class PersonTuple(NamedTuple):
83+
first_name: str
84+
last_name: str
7085
dob: datetime.date
7186
```
7287

88+
Both `Person` and `PersonTuple` are valid Struct types in CocoIndex, with identical schemas (three fields: `first_name` (Str), `last_name` (Str), `dob` (Date)).
89+
Choose `dataclass` for mutable objects or when you need additional methods, and `NamedTuple` for immutable, lightweight structures.
90+
7391
### Table Types
7492

7593
A Table type models a collection of rows, each with multiple columns.
@@ -84,20 +102,24 @@ The row order of a KTable is not preserved.
84102
Type of the first column (key column) must be a [key type](#key-types).
85103

86104
In Python, a KTable type is represented by `dict[K, V]`.
87-
The `V` should be a dataclass, representing the value fields of each row.
88-
For example, you can use `dict[str, Person]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date).
105+
The `V` should be a struct type, either a `dataclass` or `NamedTuple`, representing the value fields of each row.
106+
For example, you can use `dict[str, Person]` or `dict[str, PersonTuple]` to represent a KTable, with 4 columns: key (Str), `first_name` (Str), `last_name` (Str), `dob` (Date).
89107

90-
Note that if you want to use a struct as the key, you need to annotate the struct with `@dataclass(frozen=True)`, so the values are immutable.
108+
Note that if you want to use a struct as the key, you need to ensure the struct is immutable. For `dataclass`, annotate it with `@dataclass(frozen=True)`. For `NamedTuple`, immutability is built-in.
91109
For example:
92110

93111
```python
94112
@dataclass(frozen=True)
95113
class PersonKey:
96114
id_kind: str
97115
id: str
116+
117+
class PersonKeyTuple(NamedTuple):
118+
id_kind: str
119+
id: str
98120
```
99121

100-
Then you can use `dict[PersonKey, Person]` to represent a KTable keyed by `PersonKey`.
122+
Then you can use `dict[PersonKey, Person]` or `dict[PersonKeyTuple, PersonTuple]` to represent a KTable keyed by `PersonKey` or `PersonKeyTuple`.
101123

102124

103125
#### LTable
@@ -118,4 +140,4 @@ Currently, the following types are key types
118140
- Range
119141
- Uuid
120142
- Date
121-
- Struct with all fields being key types
143+
- Struct with all fields being key types (using `@dataclass(frozen=True)` or `NamedTuple`)
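The immutability requirement for struct keys follows from how Python dictionaries work: a key must be hashable. The sketch below is plain Python and does not use CocoIndex. It reuses `PersonKey` and `PersonKeyTuple` from the docs above, and adds a hypothetical `MutableKey` class and `is_hashable` helper purely for illustration.

```python
from dataclasses import dataclass
from typing import Any, NamedTuple


@dataclass(frozen=True)
class PersonKey:
    id_kind: str
    id: str


class PersonKeyTuple(NamedTuple):
    id_kind: str
    id: str


# Hypothetical counter-example: a regular (mutable) dataclass.
@dataclass
class MutableKey:
    id_kind: str
    id: str


def is_hashable(obj: Any) -> bool:
    """Illustrative helper: report whether `obj` can be used as a dict key."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


# Both immutable struct flavours work as dictionary keys.
by_dataclass = {PersonKey("passport", "X1"): "Alice"}
by_namedtuple = {PersonKeyTuple("passport", "X1"): "Alice"}

# A default dataclass defines __eq__ but not __hash__, so it is unhashable.
mutable_ok = is_hashable(MutableKey("passport", "X1"))
```

With the default `@dataclass` settings, `mutable_ok` is `False`, which is why the docs ask for `frozen=True` on dataclass keys while `NamedTuple` needs no extra annotation.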

python/cocoindex/convert.py

Lines changed: 33 additions & 11 deletions
@@ -8,13 +8,15 @@
88

99
from enum import Enum
1010
from typing import Any, Callable, get_origin
11-
from .typing import analyze_type_info, encode_enriched_type, TABLE_TYPES, KEY_FIELD_NAME
11+
from .typing import analyze_type_info, encode_enriched_type, is_namedtuple_type, TABLE_TYPES, KEY_FIELD_NAME
1212

1313

1414
def encode_engine_value(value: Any) -> Any:
1515
"""Encode a Python value to an engine value."""
1616
if dataclasses.is_dataclass(value):
1717
return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
18+
if is_namedtuple_type(type(value)):
19+
return [encode_engine_value(getattr(value, name)) for name in value._fields]
1820
if isinstance(value, (list, tuple)):
1921
return [encode_engine_value(v) for v in value]
2022
if isinstance(value, dict):
@@ -55,16 +57,16 @@ def make_engine_value_decoder(
5557
f"Type mismatch for `{''.join(field_path)}`: "
5658
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
5759

58-
if dst_type_info.dataclass_type is not None:
60+
if dst_type_info.struct_type is not None:
5961
return _make_engine_struct_value_decoder(
60-
field_path, src_type['fields'], dst_type_info.dataclass_type)
62+
field_path, src_type['fields'], dst_type_info.struct_type)
6163

6264
if src_type_kind in TABLE_TYPES:
6365
field_path.append('[*]')
6466
elem_type_info = analyze_type_info(dst_type_info.elem_type)
65-
if elem_type_info.dataclass_type is None:
67+
if elem_type_info.struct_type is None:
6668
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
67-
f"declared `{dst_type_info.kind}`, a dataclass type expected")
69+
f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected")
6870
engine_fields_schema = src_type['row']['fields']
6971
if elem_type_info.key_type is not None:
7072
key_field_schema = engine_fields_schema[0]
@@ -73,14 +75,14 @@ def make_engine_value_decoder(
7375
field_path, key_field_schema['type'], elem_type_info.key_type)
7476
field_path.pop()
7577
value_decoder = _make_engine_struct_value_decoder(
76-
field_path, engine_fields_schema[1:], elem_type_info.dataclass_type)
78+
field_path, engine_fields_schema[1:], elem_type_info.struct_type)
7779
def decode(value):
7880
if value is None:
7981
return None
8082
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
8183
else:
8284
elem_decoder = _make_engine_struct_value_decoder(
83-
field_path, engine_fields_schema, elem_type_info.dataclass_type)
85+
field_path, engine_fields_schema, elem_type_info.struct_type)
8486
def decode(value):
8587
if value is None:
8688
return None
@@ -96,19 +98,39 @@ def decode(value):
9698
def _make_engine_struct_value_decoder(
9799
field_path: list[str],
98100
src_fields: list[dict[str, Any]],
99-
dst_dataclass_type: type,
101+
dst_struct_type: type,
100102
) -> Callable[[list], Any]:
101103
"""Make a decoder from an engine field values to a Python value."""
102104

103105
src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
106+
107+
is_dataclass = dataclasses.is_dataclass(dst_struct_type)
108+
is_namedtuple = is_namedtuple_type(dst_struct_type)
109+
110+
if is_dataclass:
111+
parameters = inspect.signature(dst_struct_type).parameters
112+
elif is_namedtuple:
113+
defaults = getattr(dst_struct_type, '_field_defaults', {})
114+
parameters = {
115+
name: inspect.Parameter(
116+
name=name,
117+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
118+
default=defaults.get(name, inspect.Parameter.empty),
119+
annotation=dst_struct_type.__annotations__.get(name, inspect.Parameter.empty)
120+
)
121+
for name in dst_struct_type._fields
122+
}
123+
else:
124+
raise ValueError(f"Unsupported struct type: {dst_struct_type}")
125+
104126
def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
105127
src_idx = src_name_to_idx.get(name)
106128
if src_idx is not None:
107129
field_path.append(f'.{name}')
108130
field_decoder = make_engine_value_decoder(
109131
field_path, src_fields[src_idx]['type'], param.annotation)
110132
field_path.pop()
111-
return lambda values: field_decoder(values[src_idx])
133+
return lambda values: field_decoder(values[src_idx]) if len(values) > src_idx else param.default
112134

113135
default_value = param.default
114136
if default_value is inspect.Parameter.empty:
@@ -119,9 +141,9 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis
119141

120142
field_value_decoder = [
121143
make_closure_for_value(name, param)
122-
for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
144+
for (name, param) in parameters.items()]
123145

124-
return lambda values: dst_dataclass_type(
146+
return lambda values: dst_struct_type(
125147
*(decoder(values) for decoder in field_value_decoder))
126148

127149
def dump_engine_object(v: Any) -> Any:
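The NamedTuple branch of `_make_engine_struct_value_decoder` builds `inspect.Parameter` objects by hand from `_fields`, `_field_defaults`, and `__annotations__`. The standalone sketch below illustrates that idea. Note that `is_namedtuple_type` is imported from `cocoindex.typing` in the commit and its body is not shown in this diff, so the version here is only an assumed stand-in, and `namedtuple_parameters` is a hypothetical helper name.

```python
import inspect
from typing import Any, NamedTuple


def is_namedtuple_type(t: Any) -> bool:
    # Assumed stand-in: the real helper's implementation is not shown in the diff.
    return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields")


def namedtuple_parameters(t: type) -> dict[str, inspect.Parameter]:
    """Hypothetical helper mirroring the NamedTuple branch of the decoder."""
    defaults = getattr(t, "_field_defaults", {})
    annotations = getattr(t, "__annotations__", {})
    return {
        name: inspect.Parameter(
            name=name,
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            # Fields without a default are marked with Parameter.empty.
            default=defaults.get(name, inspect.Parameter.empty),
            annotation=annotations.get(name, inspect.Parameter.empty),
        )
        for name in t._fields
    }


class OrderNamedTuple(NamedTuple):
    order_id: str
    name: str
    price: float
    extra_field: str = "default_extra"


params = namedtuple_parameters(OrderNamedTuple)
required = [n for n, p in params.items() if p.default is inspect.Parameter.empty]
```

Here `required` holds the fields the engine value must supply, while `extra_field` carries its declared default, which is what lets the decoder fall back to `param.default` when the engine value is shorter than the field list.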

python/cocoindex/tests/test_convert.py

Lines changed: 70 additions & 3 deletions
@@ -1,11 +1,12 @@
11
import uuid
22
import datetime
33
from dataclasses import dataclass, make_dataclass
4+
from typing import NamedTuple, Literal
45
import pytest
56
import cocoindex
67
from cocoindex.typing import encode_enriched_type
78
from cocoindex.convert import encode_engine_value, make_engine_value_decoder
8-
from typing import Literal
9+
910
@dataclass
1011
class Order:
1112
order_id: str
@@ -33,6 +34,17 @@ class NestedStruct:
3334
orders: list[Order]
3435
count: int = 0
3536

37+
class OrderNamedTuple(NamedTuple):
38+
order_id: str
39+
name: str
40+
price: float
41+
extra_field: str = "default_extra"
42+
43+
class CustomerNamedTuple(NamedTuple):
44+
name: str
45+
order: OrderNamedTuple
46+
tags: list[Tag] | None = None
47+
3648
def build_engine_value_decoder(engine_type_in_py, python_type=None):
3749
"""
3850
Helper to build a converter for the given engine-side type (as represented in Python).
@@ -62,10 +74,16 @@ def test_encode_engine_value_date_time_types():
6274
def test_encode_engine_value_struct():
6375
order = Order(order_id="O123", name="mixed nuts", price=25.0)
6476
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
77+
78+
order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0)
79+
assert encode_engine_value(order_nt) == ["O123", "mixed nuts", 25.0, "default_extra"]
6580

6681
def test_encode_engine_value_list_of_structs():
6782
orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
6883
assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
84+
85+
orders_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)]
86+
assert encode_engine_value(orders_nt) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
6987

7088
def test_encode_engine_value_struct_with_list():
7189
basket = Basket(items=["apple", "banana"])
@@ -74,6 +92,9 @@ def test_encode_engine_value_struct_with_list():
7492
def test_encode_engine_value_nested_struct():
7593
customer = Customer(name="Alice", order=Order("O1", "item1", 10.0))
7694
assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
95+
96+
customer_nt = CustomerNamedTuple(name="Alice", order=OrderNamedTuple("O1", "item1", 10.0))
97+
assert encode_engine_value(customer_nt) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None]
7798

7899
def test_encode_engine_value_empty_list():
79100
assert encode_engine_value([]) == []
@@ -103,38 +124,62 @@ def test_make_engine_value_decoder_basic_types():
103124
@pytest.mark.parametrize(
104125
"data_type, engine_val, expected",
105126
[
106-
# All fields match
127+
# All fields match (dataclass)
107128
(Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")),
129+
# All fields match (NamedTuple)
130+
(OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
108131
# Extra field in engine value (should ignore extra)
109132
(Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")),
133+
(OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
110134
# Fewer fields in engine value (should fill with default)
111135
(Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")),
136+
(OrderNamedTuple, ["O123", "mixed nuts", 0.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra")),
112137
# More fields in engine value (should ignore extra)
113138
(Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")),
139+
(OrderNamedTuple, ["O123", "mixed nuts", 25.0, "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected")),
114140
# Truly extra field (should ignore the fifth field)
115141
(Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")),
142+
(OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
116143
# Missing optional field in engine value (tags=None)
117144
(Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
145+
(CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)),
118146
# Extra field in engine value for Customer (should ignore)
119147
(Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
148+
(CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip")])),
149+
# Missing optional field with default
150+
(Order, ["O123", "mixed nuts", 25.0], Order("O123", "mixed nuts", 25.0, "default_extra")),
151+
(OrderNamedTuple, ["O123", "mixed nuts", 25.0], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")),
152+
# Partial optional fields
153+
(Customer, ["Alice", ["O1", "item1", 10.0]], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)),
154+
(CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0]], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)),
120155
]
121156
)
122157
def test_struct_decoder_cases(data_type, engine_val, expected):
123158
decoder = build_engine_value_decoder(data_type)
124159
assert decoder(engine_val) == expected
125160

126161
def test_make_engine_value_decoder_collections():
127-
# List of structs
162+
# List of structs (dataclass)
128163
decoder = build_engine_value_decoder(list[Order])
129164
engine_val = [
130165
["O1", "item1", 10.0, "default_extra"],
131166
["O2", "item2", 20.0, "default_extra"]
132167
]
133168
assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")]
169+
170+
# List of structs (NamedTuple)
171+
decoder = build_engine_value_decoder(list[OrderNamedTuple])
172+
assert decoder(engine_val) == [OrderNamedTuple("O1", "item1", 10.0, "default_extra"), OrderNamedTuple("O2", "item2", 20.0, "default_extra")]
173+
134174
# Struct with list field
135175
decoder = build_engine_value_decoder(Customer)
136176
engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]]
137177
assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
178+
179+
# NamedTuple with list field
180+
decoder = build_engine_value_decoder(CustomerNamedTuple)
181+
assert decoder(engine_val) == CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")])
182+
138183
# Struct with struct field
139184
decoder = build_engine_value_decoder(NestedStruct)
140185
engine_val = [
@@ -239,6 +284,13 @@ def test_roundtrip_ltable():
239284
assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
240285
decoded = build_engine_value_decoder(t)(encoded)
241286
assert decoded == value
287+
288+
t_nt = list[OrderNamedTuple]
289+
value_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)]
290+
encoded = encode_engine_value(value_nt)
291+
assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]]
292+
decoded = build_engine_value_decoder(t_nt)(encoded)
293+
assert decoded == value_nt
242294

243295
def test_roundtrip_ktable_str_key():
244296
t = dict[str, Order]
@@ -247,6 +299,13 @@ def test_roundtrip_ktable_str_key():
247299
assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
248300
decoded = build_engine_value_decoder(t)(encoded)
249301
assert decoded == value
302+
303+
t_nt = dict[str, OrderNamedTuple]
304+
value_nt = {"K1": OrderNamedTuple("O1", "item1", 10.0), "K2": OrderNamedTuple("O2", "item2", 20.0)}
305+
encoded = encode_engine_value(value_nt)
306+
assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]]
307+
decoded = build_engine_value_decoder(t_nt)(encoded)
308+
assert decoded == value_nt
250309

251310
def test_roundtrip_ktable_struct_key():
252311
@dataclass(frozen=True)
@@ -261,6 +320,14 @@ class OrderKey:
261320
[["B", 4], "O2", "item2", 20.0, "default_extra"]]
262321
decoded = build_engine_value_decoder(t)(encoded)
263322
assert decoded == value
323+
324+
t_nt = dict[OrderKey, OrderNamedTuple]
325+
value_nt = {OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0), OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0)}
326+
encoded = encode_engine_value(value_nt)
327+
assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"],
328+
[["B", 4], "O2", "item2", 20.0, "default_extra"]]
329+
decoded = build_engine_value_decoder(t_nt)(encoded)
330+
assert decoded == value_nt
264331

265332
IntVectorType = cocoindex.Vector[int, Literal[5]]
266333
def test_vector_as_vector() -> None:
