Skip to content

Commit d6a4c25

Browse files
authored
Support struct and table types in Python->Rust type encoding. #19 (#51)
* Code cleanups for basic type encoding in Python SDK. * Support struct and table types in Python->Rust type encoding. #19
1 parent 4ddb948 commit d6a4c25

File tree

1 file changed

+56
-36
lines changed

1 file changed

+56
-36
lines changed

python/cocoindex/typing.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import typing
22
import collections
3+
import dataclasses
34
from typing import Annotated, NamedTuple, Any
45

56
class Vector(NamedTuple):
67
dim: int | None
78

8-
class NumBits(NamedTuple):
9-
bits: int
10-
9+
class TypeKind(NamedTuple):
10+
kind: str
1111
class TypeAttr:
1212
key: str
1313
value: Any
@@ -16,10 +16,10 @@ def __init__(self, key: str, value: Any):
1616
self.key = key
1717
self.value = value
1818

19-
Float32 = Annotated[float, NumBits(32)]
20-
Float64 = Annotated[float, NumBits(64)]
21-
Range = Annotated[tuple[int, int], 'range']
22-
Json = Annotated[Any, 'json']
19+
Float32 = Annotated[float, TypeKind('Float32')]
20+
Float64 = Annotated[float, TypeKind('Float64')]
21+
Range = Annotated[tuple[int, int], TypeKind('Range')]
22+
Json = Annotated[Any, TypeKind('Json')]
2323

2424
def _find_annotation(metadata, cls):
2525
for m in iter(metadata):
@@ -32,45 +32,65 @@ def _get_origin_type_and_metadata(t):
3232
return (t.__origin__, t.__metadata__)
3333
return (t, ())
3434

35-
def _basic_type_to_json_value(t, metadata):
35+
def _dump_fields_schema(cls: type) -> list[dict[str, Any]]:
36+
return [
37+
{
38+
'name': field.name,
39+
'value_type': _dump_enriched_type(field.type),
40+
}
41+
for field in dataclasses.fields(cls)
42+
]
43+
44+
def _dump_type(t, metadata):
3645
origin_type = typing.get_origin(t)
3746
if origin_type is collections.abc.Sequence or origin_type is list:
38-
dim = _find_annotation(metadata, Vector)
39-
if dim is None:
40-
raise ValueError(f"Vector dimension not found for {t}")
4147
args = typing.get_args(t)
42-
type_json = {
43-
'kind': 'Vector',
44-
'element_type': _basic_type_to_json_value(*_get_origin_type_and_metadata(args[0])),
45-
'dimension': dim.dim,
48+
elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0])
49+
vector_annot = _find_annotation(metadata, Vector)
50+
if vector_annot is not None:
51+
encoded_type = {
52+
'kind': 'Vector',
53+
'element_type': _dump_type(elem_type, elem_type_metadata),
54+
'dimension': vector_annot.dim,
55+
}
56+
elif dataclasses.is_dataclass(elem_type):
57+
encoded_type = {
58+
'kind': 'Table',
59+
'row': _dump_fields_schema(elem_type),
60+
}
61+
else:
62+
raise ValueError(f"Unsupported type: {t}")
63+
elif dataclasses.is_dataclass(t):
64+
encoded_type = {
65+
'kind': 'Struct',
66+
'fields': _dump_fields_schema(t),
4667
}
4768
else:
48-
if t is bytes:
49-
kind = 'Bytes'
50-
elif t is str:
51-
kind = 'Str'
52-
elif t is bool:
53-
kind = 'Bool'
54-
elif t is int:
55-
kind = 'Int64'
56-
elif t is float:
57-
num_bits = _find_annotation(metadata, NumBits)
58-
kind = 'Float32' if num_bits is not None and num_bits.bits <= 32 else 'Float64'
59-
elif t is Range:
60-
kind = 'Range'
61-
elif t is Json:
62-
kind = 'Json'
69+
type_kind = _find_annotation(metadata, TypeKind)
70+
if type_kind is not None:
71+
kind = type_kind.kind
6372
else:
64-
raise ValueError(f"type unsupported yet: {t}")
65-
type_json = { 'kind': kind }
73+
if t is bytes:
74+
kind = 'Bytes'
75+
elif t is str:
76+
kind = 'Str'
77+
elif t is bool:
78+
kind = 'Bool'
79+
elif t is int:
80+
kind = 'Int64'
81+
elif t is float:
82+
kind = 'Float64'
83+
else:
84+
raise ValueError(f"type unsupported yet: {t}")
85+
encoded_type = { 'kind': kind }
6686

67-
return type_json
87+
return encoded_type
6888

69-
def _enriched_type_to_json_value(t) -> dict[str, Any] | None:
89+
def _dump_enriched_type(t) -> dict[str, Any] | None:
7090
if t is None:
7191
return None
7292
t, metadata = _get_origin_type_and_metadata(t)
73-
enriched_type_json = {'type': _basic_type_to_json_value(t, metadata)}
93+
enriched_type_json = {'type': _dump_type(t, metadata)}
7494
attrs = None
7595
for attr in metadata:
7696
if isinstance(attr, TypeAttr):
@@ -86,4 +106,4 @@ def dump_type(t) -> dict[str, Any] | None:
86106
"""
87107
Convert a Python type to a CocoIndex's type in JSON.
88108
"""
89-
return _enriched_type_to_json_value(t)
109+
return _dump_enriched_type(t)

0 commit comments

Comments
 (0)