Skip to content

Commit

Permalink
try and make use of the pydantic models
Browse files Browse the repository at this point in the history
  • Loading branch information
dromer committed Oct 19, 2024
1 parent 8e4d700 commit 75a31ba
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 16 deletions.
12 changes: 9 additions & 3 deletions hvcc/core/json/heavy_ir.py → hvcc/core/hv2ir/types/IR.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
ConnectionType = Literal["-->", "~i>", "~f>", "signal"]


class Arg(BaseModel):
class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""
def __getitem__(self, item):
return getattr(self, item)


class Arg(IndexableBaseModel):
name: str
value_type: str
description: str = ""
Expand All @@ -25,7 +31,7 @@ class Perf(BaseModel):
neon: float = 0


class IRNode(BaseModel):
class IRNode(IndexableBaseModel):
inlets: List[ConnectionType]
ir: IR
outlets: List[ConnectionType]
Expand All @@ -38,7 +44,7 @@ class IRNode(BaseModel):
keywords: List[str] = []


class HeavyIR(RootModel):
class HeavyIRType(RootModel):
root: dict[str, IRNode]


Expand Down
13 changes: 10 additions & 3 deletions hvcc/core/json/heavy_lang.py → hvcc/core/hv2ir/types/Lang.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from ast import Index
from pydantic import BaseModel, RootModel
from typing import List, Optional, Dict, Literal, Union


ConnectionType = Literal["-->", "-~>", "~f>"]


class Arg(BaseModel):
class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing"""
def __getitem__(self, item):
return getattr(self, item)


class Arg(IndexableBaseModel):
name: str
value_type: Optional[str]
description: str
Expand All @@ -25,7 +32,7 @@ class Outlet(BaseModel):
description: str


class LangNode(BaseModel):
class LangNode(IndexableBaseModel):
description: str
inlets: List[Inlet]
outlets: List[Outlet]
Expand All @@ -34,7 +41,7 @@ class LangNode(BaseModel):
tags: List[str]


class HeavyLang(RootModel):
class HeavyLangType(RootModel):
root: dict[str, LangNode]


Expand Down
2 changes: 2 additions & 0 deletions hvcc/core/hv2ir/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .IR import HeavyIRType, IRNode # noqa
from .Lang import HeavyLangType, LangNode # noqa
23 changes: 13 additions & 10 deletions hvcc/interpreters/pd2hv/HeavyObject.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import decimal
import json
import importlib_resources
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union

from hvcc.core.hv2ir.types import HeavyIRType, HeavyLangType, IRNode, LangNode
from .Connection import Connection
from .NotificationEnum import NotificationEnum
from .PdObject import PdObject
Expand All @@ -28,11 +29,11 @@ class HeavyObject(PdObject):

heavy_lang_json = importlib_resources.files('hvcc') / 'core/json/heavy.lang.json'
with open(heavy_lang_json, "r") as f:
__HEAVY_LANG_OBJS = json.load(f)
__HEAVY_LANG_OBJS = HeavyLangType(json.load(f))

heavy_ir_json = importlib_resources.files('hvcc') / 'core/json/heavy.ir.json'
with open(heavy_ir_json, "r") as f:
__HEAVY_IR_OBJS = json.load(f)
__HEAVY_IR_OBJS = HeavyIRType(json.load(f))

def __init__(
self,
Expand All @@ -43,18 +44,20 @@ def __init__(
) -> None:
super().__init__(obj_type, obj_args, pos_x, pos_y)

self.__obj_dict: Union[IRNode, LangNode]

# get the object dictionary (note that it is NOT a copy)
if self.is_hvlang:
self.__obj_dict = self.__HEAVY_LANG_OBJS[obj_type]
self.__obj_dict = self.__HEAVY_LANG_OBJS.root[obj_type]
elif self.is_hvir:
self.__obj_dict = self.__HEAVY_IR_OBJS[obj_type]
self.__obj_dict = self.__HEAVY_IR_OBJS.root[obj_type]
else:
assert False, f"{obj_type} is not a Heavy Lang or IR object."

# resolve arguments
obj_args = obj_args or []
self.obj_dict = {}
for i, a in enumerate(self.__obj_dict["args"]):
for i, a in enumerate(self.__obj_dict.args):
# if the argument exists (and has been correctly resolved)
if i < len(obj_args) and obj_args[i] is not None:
# force the Heavy argument type
Expand Down Expand Up @@ -135,19 +138,19 @@ def force_arg_type(cls, value: str, value_type: str) -> Any:

@property
def is_hvlang(self) -> bool:
return self.obj_type in self.__HEAVY_LANG_OBJS
return self.obj_type in self.__HEAVY_LANG_OBJS.root.keys()

@property
def is_hvir(self) -> bool:
return self.obj_type in self.__HEAVY_IR_OBJS
return self.obj_type in self.__HEAVY_IR_OBJS.root.keys()

def get_inlet_connection_type(self, inlet_index: int) -> Optional[str]:
""" Returns the inlet connection type, None if the inlet does not exist.
"""
# TODO(mhroth): it's stupid that hvlang and hvir json have different data formats here
if self.is_hvlang:
if len(self.__obj_dict["inlets"]) > inlet_index:
return self.__obj_dict["inlets"][inlet_index]["connectionType"]
if len(self.__obj_dict.inlets) > inlet_index:
return self.__obj_dict.inlets[inlet_index].connectionType
else:
return None
elif self.is_hvir:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,8 @@ Heavy = { source = "hvcc/__init__.py", type = "onefile", bundle = true, arch = "
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.mypy]
plugins = [
"pydantic.mypy"
]

0 comments on commit 75a31ba

Please sign in to comment.