Skip to content

Commit

Permalink
start using the types
Browse files Browse the repository at this point in the history
  • Loading branch information
dromer committed Oct 20, 2024
1 parent 75a31ba commit e69286d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 56 deletions.
29 changes: 12 additions & 17 deletions hvcc/core/hv2ir/types/IR.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
from typing import List, Optional, Union, Literal


ConnectionType = Literal["-->", "~i>", "~f>", "signal"]
IRConnectionType = Literal["-->", "~i>", "~f>", "signal"]


class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""
def __getitem__(self, item):
return getattr(self, item)


class Arg(IndexableBaseModel):
class IRArg(BaseModel):
name: str
value_type: str
description: str = ""
Expand All @@ -31,11 +25,11 @@ class Perf(BaseModel):
neon: float = 0


class IRNode(IndexableBaseModel):
inlets: List[ConnectionType]
class IRNode(BaseModel):
inlets: List[IRConnectionType]
ir: IR
outlets: List[ConnectionType]
args: List[Arg] = []
outlets: List[IRConnectionType]
args: List[IRArg] = []
perf: Optional[Perf] = Perf()
# perf: Perf
description: Optional[str] = None
Expand All @@ -48,8 +42,9 @@ class HeavyIRType(RootModel):
root: dict[str, IRNode]


# import json
# with open('heavy.ir.json') as f:
# data = json.load(f)
# heavy_ir = HeavyIR(root=data)
# print(heavy_ir.root.keys())
if __name__ == "__main__":
import json
with open('../../json/heavy.ir.json') as f:
data = json.load(f)
heavy_ir = HeavyIRType(root=data)
print(heavy_ir.root.keys())
30 changes: 12 additions & 18 deletions hvcc/core/hv2ir/types/Lang.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from ast import Index
from pydantic import BaseModel, RootModel
from typing import List, Optional, Dict, Literal, Union


ConnectionType = Literal["-->", "-~>", "~f>"]
LangConnectionType = Literal["-->", "-~>", "~f>"]


class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""
def __getitem__(self, item):
return getattr(self, item)


class Arg(IndexableBaseModel):
class LangArg(BaseModel):
name: str
value_type: Optional[str]
description: str
Expand All @@ -22,21 +15,21 @@ class Arg(IndexableBaseModel):

class Inlet(BaseModel):
name: str
connectionType: ConnectionType
connectionType: LangConnectionType
description: str


class Outlet(BaseModel):
name: str
connectionType: ConnectionType
connectionType: LangConnectionType
description: str


class LangNode(IndexableBaseModel):
class LangNode(BaseModel):
description: str
inlets: List[Inlet]
outlets: List[Outlet]
args: List[Arg]
args: List[LangArg]
alias: List[str]
tags: List[str]

Expand All @@ -45,8 +38,9 @@ class HeavyLangType(RootModel):
root: dict[str, LangNode]


# import json
# with open('heavy.lang.json') as f:
# data = json.load(f)
# heavy_lang = HeavyLang(root=data)
# print(heavy_lang.root.keys())
if __name__ == "__main__":
import json
with open('../../json/heavy.lang.json') as f:
data = json.load(f)
heavy_lang = HeavyLangType(root=data)
print(heavy_lang.root.keys())
4 changes: 2 additions & 2 deletions hvcc/core/hv2ir/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .IR import HeavyIRType, IRNode # noqa
from .Lang import HeavyLangType, LangNode # noqa
from .IR import HeavyIRType, IRNode, IRArg # noqa
from .Lang import HeavyLangType, LangNode, LangArg # noqa
40 changes: 21 additions & 19 deletions hvcc/interpreters/pd2hv/HeavyObject.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import decimal
import json
import importlib_resources
from typing import Optional, List, Dict, Any, Union
from typing import Optional, List, Dict, Any, Union, cast

from hvcc.core.hv2ir.types import HeavyIRType, HeavyLangType, IRNode, LangNode
from hvcc.core.hv2ir.types import HeavyIRType, HeavyLangType, IRNode, LangNode, IRArg, LangArg
from .Connection import Connection
from .NotificationEnum import NotificationEnum
from .PdObject import PdObject
Expand Down Expand Up @@ -57,24 +57,26 @@ def __init__(
# resolve arguments
obj_args = obj_args or []
self.obj_dict = {}

for i, a in enumerate(self.__obj_dict.args):
arg = cast(Union[IRArg, LangArg], a)
# if the argument exists (and has been correctly resolved)
if i < len(obj_args) and obj_args[i] is not None:
# force the Heavy argument type
# Catch type errors as early as possible
try:
self.obj_dict[a["name"]] = self.force_arg_type(
self.obj_dict[arg.name] = self.force_arg_type(
obj_args[i],
a["value_type"])
arg.value_type)
except Exception as e:
self.add_error(
f"Heavy {obj_type} cannot convert argument \"{a['name']}\""
f" with value \"{obj_args[i]}\" to type {a['value_type']}: {e}")
f"Heavy {obj_type} cannot convert argument \"{arg.name}\""
f" with value \"{obj_args[i]}\" to type {arg.value_type}: {e}")
else:
# the default argument is required
if a["required"]:
if arg.required:
self.add_error(
f"Required argument \"{a['name']}\" to object {obj_type} not present: {obj_args}")
f"Required argument \"{arg.name}\" to object {obj_type} not present: {obj_args}")
else:
# don't worry about supplying a default,
# let hv2ir take care of it. pd2hv only passes on the
Expand All @@ -90,7 +92,7 @@ def __init__(
self.__annotations["scope"] = "public"

@classmethod
def force_arg_type(cls, value: str, value_type: str) -> Any:
def force_arg_type(cls, value: str, value_type: Optional[str] = None) -> Any:
# TODO(mhroth): add support for mixedarray?
if value_type == "auto":
try:
Expand Down Expand Up @@ -148,14 +150,14 @@ def get_inlet_connection_type(self, inlet_index: int) -> Optional[str]:
""" Returns the inlet connection type, None if the inlet does not exist.
"""
# TODO(mhroth): it's stupid that hvlang and hvir json have different data formats here
if self.is_hvlang:
if self.is_hvlang and isinstance(self.__obj_dict, LangNode):
if len(self.__obj_dict.inlets) > inlet_index:
return self.__obj_dict.inlets[inlet_index].connectionType
else:
return None
elif self.is_hvir:
if len(self.__obj_dict["inlets"]) > inlet_index:
return self.__obj_dict["inlets"][inlet_index]
elif self.is_hvir and isinstance(self.__obj_dict, IRNode):
if len(self.__obj_dict.inlets) > inlet_index:
return self.__obj_dict.inlets[inlet_index]
else:
return None
else:
Expand All @@ -165,14 +167,14 @@ def get_outlet_connection_type(self, outlet_index: int) -> Optional[str]:
""" Returns the outlet connection type, None if the inlet does not exist.
"""
# TODO(mhroth): it's stupid that hvlang and hvir json have different data formats here
if self.is_hvlang:
if len(self.__obj_dict["outlets"]) > outlet_index:
return self.__obj_dict["outlets"][outlet_index]["connectionType"]
if self.is_hvlang and isinstance(self.__obj_dict, LangNode):
if len(self.__obj_dict.outlets) > outlet_index:
return self.__obj_dict.outlets[outlet_index].connectionType
else:
return None
elif self.is_hvir:
if len(self.__obj_dict["outlets"]) > outlet_index:
return self.__obj_dict["outlets"][outlet_index]
elif self.is_hvir and isinstance(self.__obj_dict, IRNode):
if len(self.__obj_dict.outlets) > outlet_index:
return self.__obj_dict.outlets[outlet_index]
else:
return None
else:
Expand Down

0 comments on commit e69286d

Please sign in to comment.