From d771f3022ae00f04934cfdb3d41434335ee55831 Mon Sep 17 00:00:00 2001 From: LJ Date: Thu, 6 Mar 2025 13:00:58 -0800 Subject: [PATCH] Support struct and table types in Python->Rust type encoding. #19 --- python/cocoindex/typing.py | 49 +++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index 2e34f58..c11ec84 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -1,5 +1,6 @@ import typing import collections +import dataclasses from typing import Annotated, NamedTuple, Any class Vector(NamedTuple): @@ -31,18 +32,38 @@ def _get_origin_type_and_metadata(t): return (t.__origin__, t.__metadata__) return (t, ()) -def _type_to_json_value(t, metadata): +def _dump_fields_schema(cls: type) -> list[dict[str, Any]]: + return [ + { + 'name': field.name, + 'value_type': _dump_enriched_type(field.type), + } + for field in dataclasses.fields(cls) + ] + +def _dump_type(t, metadata): origin_type = typing.get_origin(t) if origin_type is collections.abc.Sequence or origin_type is list: - dim = _find_annotation(metadata, Vector) - if dim is None: - raise ValueError(f"Vector dimension not found for {t}") args = typing.get_args(t) - origin_type, metadata = _get_origin_type_and_metadata(args[0]) - type_json = { - 'kind': 'Vector', - 'element_type': _type_to_json_value(origin_type, metadata), - 'dimension': dim.dim, + elem_type, elem_type_metadata = _get_origin_type_and_metadata(args[0]) + vector_annot = _find_annotation(metadata, Vector) + if vector_annot is not None: + encoded_type = { + 'kind': 'Vector', + 'element_type': _dump_type(elem_type, elem_type_metadata), + 'dimension': vector_annot.dim, + } + elif dataclasses.is_dataclass(elem_type): + encoded_type = { + 'kind': 'Table', + 'row': _dump_fields_schema(elem_type), + } + else: + raise ValueError(f"Unsupported type: {t}") + elif dataclasses.is_dataclass(t): + encoded_type = { + 'kind': 'Struct', + 'fields': _dump_fields_schema(t), } else: type_kind = _find_annotation(metadata, TypeKind) @@ -61,15 +82,15 @@ def _type_to_json_value(t, metadata): kind = 'Float64' else: raise ValueError(f"type unsupported yet: {t}") - type_json = { 'kind': kind } + encoded_type = { 'kind': kind } - return type_json + return encoded_type -def _enriched_type_to_json_value(t) -> dict[str, Any] | None: +def _dump_enriched_type(t) -> dict[str, Any] | None: if t is None: return None t, metadata = _get_origin_type_and_metadata(t) - enriched_type_json = {'type': _type_to_json_value(t, metadata)} + enriched_type_json = {'type': _dump_type(t, metadata)} attrs = None for attr in metadata: if isinstance(attr, TypeAttr): @@ -85,4 +106,4 @@ def dump_type(t) -> dict[str, Any] | None: """ Convert a Python type to a CocoIndex's type in JSON. """ - return _enriched_type_to_json_value(t) + return _dump_enriched_type(t)