Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support struct and table types in Python->Rust type encoding. #19 #51

Merged
merged 2 commits into from
Mar 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import typing
import collections
import dataclasses
from typing import Annotated, NamedTuple, Any

class Vector(NamedTuple):
dim: int | None

class NumBits(NamedTuple):
bits: int

class TypeKind(NamedTuple):
kind: str
class TypeAttr:
key: str
value: Any
Expand All @@ -16,10 +16,10 @@ def __init__(self, key: str, value: Any):
self.key = key
self.value = value

Float32 = Annotated[float, NumBits(32)]
Float64 = Annotated[float, NumBits(64)]
Range = Annotated[tuple[int, int], 'range']
Json = Annotated[Any, 'json']
Float32 = Annotated[float, TypeKind('Float32')]
Float64 = Annotated[float, TypeKind('Float64')]
Range = Annotated[tuple[int, int], TypeKind('Range')]
Json = Annotated[Any, TypeKind('Json')]

def _find_annotation(metadata, cls):
for m in iter(metadata):
Expand All @@ -32,45 +32,65 @@ def _get_origin_type_and_metadata(t):
return (t.__origin__, t.__metadata__)
return (t, ())

def _basic_type_to_json_value(t, metadata):
def _dump_fields_schema(cls: type) -> list[dict[str, Any]]:
return [
{
'name': field.name,
'value_type': _dump_enriched_type(field.type),
}
for field in dataclasses.fields(cls)
]

def _dump_type(t, metadata):
origin_type = typing.get_origin(t)
if origin_type is collections.abc.Sequence or origin_type is list:
dim = _find_annotation(metadata, Vector)
if dim is None:
raise ValueError(f"Vector dimension not found for {t}")
args = typing.get_args(t)
type_json = {
'kind': 'Vector',
'element_type': _basic_type_to_json_value(*_get_origin_type_and_metadata(args[0])),
'dimension': dim.dim,
elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0])
vector_annot = _find_annotation(metadata, Vector)
if vector_annot is not None:
encoded_type = {
'kind': 'Vector',
'element_type': _dump_type(elem_type, elem_type_metadata),
'dimension': vector_annot.dim,
}
elif dataclasses.is_dataclass(elem_type):
encoded_type = {
'kind': 'Table',
'row': _dump_fields_schema(elem_type),
}
else:
raise ValueError(f"Unsupported type: {t}")
elif dataclasses.is_dataclass(t):
encoded_type = {
'kind': 'Struct',
'fields': _dump_fields_schema(t),
}
else:
if t is bytes:
kind = 'Bytes'
elif t is str:
kind = 'Str'
elif t is bool:
kind = 'Bool'
elif t is int:
kind = 'Int64'
elif t is float:
num_bits = _find_annotation(metadata, NumBits)
kind = 'Float32' if num_bits is not None and num_bits.bits <= 32 else 'Float64'
elif t is Range:
kind = 'Range'
elif t is Json:
kind = 'Json'
type_kind = _find_annotation(metadata, TypeKind)
if type_kind is not None:
kind = type_kind.kind
else:
raise ValueError(f"type unsupported yet: {t}")
type_json = { 'kind': kind }
if t is bytes:
kind = 'Bytes'
elif t is str:
kind = 'Str'
elif t is bool:
kind = 'Bool'
elif t is int:
kind = 'Int64'
elif t is float:
kind = 'Float64'
else:
raise ValueError(f"type unsupported yet: {t}")
encoded_type = { 'kind': kind }

return type_json
return encoded_type

def _enriched_type_to_json_value(t) -> dict[str, Any] | None:
def _dump_enriched_type(t) -> dict[str, Any] | None:
if t is None:
return None
t, metadata = _get_origin_type_and_metadata(t)
enriched_type_json = {'type': _basic_type_to_json_value(t, metadata)}
enriched_type_json = {'type': _dump_type(t, metadata)}
attrs = None
for attr in metadata:
if isinstance(attr, TypeAttr):
Expand All @@ -86,4 +106,4 @@ def dump_type(t) -> dict[str, Any] | None:
"""
Convert a Python type to a CocoIndex's type in JSON.
"""
return _enriched_type_to_json_value(t)
return _dump_enriched_type(t)