diff --git a/hvcc/core/hv2ir/types/IR.py b/hvcc/core/hv2ir/types/IR.py index d0c1dd8f..dc2b9d2b 100644 --- a/hvcc/core/hv2ir/types/IR.py +++ b/hvcc/core/hv2ir/types/IR.py @@ -2,16 +2,10 @@ from typing import List, Optional, Union, Literal -ConnectionType = Literal["-->", "~i>", "~f>", "signal"] +IRConnectionType = Literal["-->", "~i>", "~f>", "signal"] -class IndexableBaseModel(BaseModel): - """Allows a BaseModel to return its fields by string variable indexing""" - def __getitem__(self, item): - return getattr(self, item) - - -class Arg(IndexableBaseModel): +class IRArg(BaseModel): name: str value_type: str description: str = "" @@ -31,11 +25,11 @@ class Perf(BaseModel): neon: float = 0 -class IRNode(IndexableBaseModel): - inlets: List[ConnectionType] +class IRNode(BaseModel): + inlets: List[IRConnectionType] ir: IR - outlets: List[ConnectionType] - args: List[Arg] = [] + outlets: List[IRConnectionType] + args: List[IRArg] = [] perf: Optional[Perf] = Perf() # perf: Perf description: Optional[str] = None @@ -48,8 +42,9 @@ class HeavyIRType(RootModel): root: dict[str, IRNode] -# import json -# with open('heavy.ir.json') as f: -# data = json.load(f) -# heavy_ir = HeavyIR(root=data) -# print(heavy_ir.root.keys()) +if __name__ == "__main__": + import json + with open('../../json/heavy.ir.json') as f: + data = json.load(f) + heavy_ir = HeavyIRType(root=data) + print(heavy_ir.root.keys()) diff --git a/hvcc/core/hv2ir/types/Lang.py b/hvcc/core/hv2ir/types/Lang.py index c8fe5ca2..0f790cbc 100644 --- a/hvcc/core/hv2ir/types/Lang.py +++ b/hvcc/core/hv2ir/types/Lang.py @@ -1,18 +1,11 @@ -from ast import Index from pydantic import BaseModel, RootModel from typing import List, Optional, Dict, Literal, Union -ConnectionType = Literal["-->", "-~>", "~f>"] +LangConnectionType = Literal["-->", "-~>", "~f>"] -class IndexableBaseModel(BaseModel): - """Allows a BaseModel to return its fields by string variable indexing""" - def __getitem__(self, item): - return getattr(self, item) - - -class Arg(IndexableBaseModel): +class LangArg(BaseModel): name: str value_type: Optional[str] description: str @@ -22,21 +15,21 @@ class Arg(IndexableBaseModel): class Inlet(BaseModel): name: str - connectionType: ConnectionType + connectionType: LangConnectionType description: str class Outlet(BaseModel): name: str - connectionType: ConnectionType + connectionType: LangConnectionType description: str -class LangNode(IndexableBaseModel): +class LangNode(BaseModel): description: str inlets: List[Inlet] outlets: List[Outlet] - args: List[Arg] + args: List[LangArg] alias: List[str] tags: List[str] @@ -45,8 +38,9 @@ class HeavyLangType(RootModel): root: dict[str, LangNode] -# import json -# with open('heavy.lang.json') as f: -# data = json.load(f) -# heavy_lang = HeavyLang(root=data) -# print(heavy_lang.root.keys()) +if __name__ == "__main__": + import json + with open('../../json/heavy.lang.json') as f: + data = json.load(f) + heavy_lang = HeavyLangType(root=data) + print(heavy_lang.root.keys()) diff --git a/hvcc/core/hv2ir/types/__init__.py b/hvcc/core/hv2ir/types/__init__.py index 7b656a93..70d71a1f 100644 --- a/hvcc/core/hv2ir/types/__init__.py +++ b/hvcc/core/hv2ir/types/__init__.py @@ -1,2 +1,2 @@ -from .IR import HeavyIRType, IRNode # noqa -from .Lang import HeavyLangType, LangNode # noqa +from .IR import HeavyIRType, IRNode, IRArg # noqa +from .Lang import HeavyLangType, LangNode, LangArg # noqa diff --git a/hvcc/interpreters/pd2hv/HeavyObject.py b/hvcc/interpreters/pd2hv/HeavyObject.py index abb10b2c..bfbd0b7f 100644 --- a/hvcc/interpreters/pd2hv/HeavyObject.py +++ b/hvcc/interpreters/pd2hv/HeavyObject.py @@ -17,9 +17,9 @@ import decimal import json import importlib_resources -from typing import Optional, List, Dict, Any, Union +from typing import Optional, List, Dict, Any, Union, cast -from hvcc.core.hv2ir.types import HeavyIRType, HeavyLangType, IRNode, LangNode +from hvcc.core.hv2ir.types import HeavyIRType, HeavyLangType, IRNode, LangNode, IRArg, LangArg from .Connection import Connection from .NotificationEnum import NotificationEnum from .PdObject import PdObject @@ -57,24 +57,26 @@ def __init__( # resolve arguments obj_args = obj_args or [] self.obj_dict = {} + for i, a in enumerate(self.__obj_dict.args): + arg = cast(Union[IRArg, LangArg], a) # if the argument exists (and has been correctly resolved) if i < len(obj_args) and obj_args[i] is not None: # force the Heavy argument type # Catch type errors as early as possible try: - self.obj_dict[a["name"]] = self.force_arg_type( + self.obj_dict[arg.name] = self.force_arg_type( obj_args[i], - a["value_type"]) + arg.value_type) except Exception as e: self.add_error( - f"Heavy {obj_type} cannot convert argument \"{a['name']}\"" - f" with value \"{obj_args[i]}\" to type {a['value_type']}: {e}") + f"Heavy {obj_type} cannot convert argument \"{arg.name}\"" + f" with value \"{obj_args[i]}\" to type {arg.value_type}: {e}") else: # the default argument is required - if a["required"]: + if arg.required: self.add_error( - f"Required argument \"{a['name']}\" to object {obj_type} not present: {obj_args}") + f"Required argument \"{arg.name}\" to object {obj_type} not present: {obj_args}") else: # don't worry about supplying a default, # let hv2ir take care of it. pd2hv only passes on the @@ -90,7 +92,7 @@ def __init__( self.__annotations["scope"] = "public" @classmethod - def force_arg_type(cls, value: str, value_type: str) -> Any: + def force_arg_type(cls, value: str, value_type: Optional[str] = None) -> Any: # TODO(mhroth): add support for mixedarray? if value_type == "auto": try: @@ -148,14 +150,14 @@ def get_inlet_connection_type(self, inlet_index: int) -> Optional[str]: """ Returns the inlet connection type, None if the inlet does not exist. """ # TODO(mhroth): it's stupid that hvlang and hvir json have different data formats here - if self.is_hvlang: + if self.is_hvlang and isinstance(self.__obj_dict, LangNode): if len(self.__obj_dict.inlets) > inlet_index: return self.__obj_dict.inlets[inlet_index].connectionType else: return None - elif self.is_hvir: - if len(self.__obj_dict["inlets"]) > inlet_index: - return self.__obj_dict["inlets"][inlet_index] + elif self.is_hvir and isinstance(self.__obj_dict, IRNode): + if len(self.__obj_dict.inlets) > inlet_index: + return self.__obj_dict.inlets[inlet_index] else: return None else: @@ -165,14 +167,14 @@ def get_outlet_connection_type(self, outlet_index: int) -> Optional[str]: """ Returns the outlet connection type, None if the inlet does not exist. """ # TODO(mhroth): it's stupid that hvlang and hvir json have different data formats here - if self.is_hvlang: - if len(self.__obj_dict["outlets"]) > outlet_index: - return self.__obj_dict["outlets"][outlet_index]["connectionType"] + if self.is_hvlang and isinstance(self.__obj_dict, LangNode): + if len(self.__obj_dict.outlets) > outlet_index: + return self.__obj_dict.outlets[outlet_index].connectionType else: return None - elif self.is_hvir: - if len(self.__obj_dict["outlets"]) > outlet_index: - return self.__obj_dict["outlets"][outlet_index] + elif self.is_hvir and isinstance(self.__obj_dict, IRNode): + if len(self.__obj_dict.outlets) > outlet_index: + return self.__obj_dict.outlets[outlet_index] else: return None else: