Skip to content

Commit 57f689e

Browse files
committed
refactor: merge dataclass and namedtuple as one uniform struct_type
1 parent 06091cb commit 57f689e

File tree

2 files changed: 16 additions and 29 deletions

2 files changed

+16
-29
lines changed

python/cocoindex/convert.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,16 @@ def make_engine_value_decoder(
5757
f"Type mismatch for `{''.join(field_path)}`: "
5858
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
5959

60-
if dst_type_info.dataclass_type is not None or dst_type_info.namedtuple_type is not None:
61-
struct_type = dst_type_info.dataclass_type or dst_type_info.namedtuple_type
60+
if dst_type_info.struct_type is not None:
6261
return _make_engine_struct_value_decoder(
63-
field_path, src_type['fields'], struct_type)
62+
field_path, src_type['fields'], dst_type_info.struct_type)
6463

6564
if src_type_kind in TABLE_TYPES:
6665
field_path.append('[*]')
6766
elem_type_info = analyze_type_info(dst_type_info.elem_type)
68-
if elem_type_info.dataclass_type is None and elem_type_info.namedtuple_type is None:
67+
if elem_type_info.struct_type is None:
6968
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
70-
f"declared `{dst_type_info.kind}`, a dataclass or namedtuple type expected")
69+
f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected")
7170
engine_fields_schema = src_type['row']['fields']
7271
if elem_type_info.key_type is not None:
7372
key_field_schema = engine_fields_schema[0]
@@ -76,16 +75,14 @@ def make_engine_value_decoder(
7675
field_path, key_field_schema['type'], elem_type_info.key_type)
7776
field_path.pop()
7877
value_decoder = _make_engine_struct_value_decoder(
79-
field_path, engine_fields_schema[1:],
80-
elem_type_info.dataclass_type or elem_type_info.namedtuple_type)
78+
field_path, engine_fields_schema[1:], elem_type_info.struct_type)
8179
def decode(value):
8280
if value is None:
8381
return None
8482
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
8583
else:
8684
elem_decoder = _make_engine_struct_value_decoder(
87-
field_path, engine_fields_schema,
88-
elem_type_info.dataclass_type or elem_type_info.namedtuple_type)
85+
field_path, engine_fields_schema, elem_type_info.struct_type)
8986
def decode(value):
9087
if value is None:
9188
return None
@@ -144,12 +141,8 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis
144141
make_closure_for_value(name, param)
145142
for (name, param) in parameters.items()]
146143

147-
if is_dataclass:
148-
return lambda values: dst_struct_type(
149-
*(decoder(values) for decoder in field_value_decoder))
150-
else: # namedtuple
151-
return lambda values: dst_struct_type(
152-
*(decoder(values) for decoder in field_value_decoder))
144+
return lambda values: dst_struct_type(
145+
*(decoder(values) for decoder in field_value_decoder))
153146

154147
def dump_engine_object(v: Any) -> Any:
155148
"""Recursively dump an object for engine. Engine side uses `Pythonized` to catch."""

python/cocoindex/typing.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ class AnalyzedTypeInfo:
7272
elem_type: ElementType | None # For Vector and Table
7373

7474
key_type: type | None # For element of KTable
75-
dataclass_type: type | None # For Struct
76-
namedtuple_type: type | None # For Struct
75+
struct_type: type | None # For Struct, a dataclass or namedtuple
7776

7877
attrs: dict[str, Any] | None
7978
nullable: bool = False
@@ -121,19 +120,16 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
121120
elif isinstance(attr, TypeKind):
122121
kind = attr.kind
123122

124-
dataclass_type = None
125-
namedtuple_type = None
123+
struct_type = None
126124
elem_type = None
127125
key_type = None
128126
if _is_struct_type(t):
127+
struct_type = t
128+
129129
if kind is None:
130130
kind = 'Struct'
131131
elif kind != 'Struct':
132132
raise ValueError(f"Unexpected type kind for struct: {kind}")
133-
if dataclasses.is_dataclass(t):
134-
dataclass_type = t
135-
elif is_namedtuple_type(t):
136-
namedtuple_type = t
137133
elif base_type is collections.abc.Sequence or base_type is list:
138134
args = typing.get_args(t)
139135
elem_type = args[0]
@@ -180,8 +176,7 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
180176
vector_info=vector_info,
181177
elem_type=elem_type,
182178
key_type=key_type,
183-
dataclass_type=dataclass_type,
184-
namedtuple_type=namedtuple_type,
179+
struct_type=struct_type,
185180
attrs=attrs,
186181
nullable=nullable,
187182
)
@@ -216,11 +211,10 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
216211
encoded_type: dict[str, Any] = { 'kind': type_info.kind }
217212

218213
if type_info.kind == 'Struct':
219-
struct_type = type_info.dataclass_type or type_info.namedtuple_type
220-
if struct_type is None:
214+
if type_info.struct_type is None:
221215
raise ValueError("Struct type must have a dataclass or namedtuple type")
222-
encoded_type['fields'] = _encode_fields_schema(struct_type, type_info.key_type)
223-
if doc := inspect.getdoc(struct_type):
216+
encoded_type['fields'] = _encode_fields_schema(type_info.struct_type, type_info.key_type)
217+
if doc := inspect.getdoc(type_info.struct_type):
224218
encoded_type['description'] = doc
225219

226220
elif type_info.kind == 'Vector':

0 commit comments

Comments
 (0)