From 07ecb06a8b43948cca345522abd81d748353b0e3 Mon Sep 17 00:00:00 2001 From: ryche Date: Wed, 20 Mar 2024 10:53:48 +0100 Subject: [PATCH] fix text expansion issues --- blockkit/components.py | 86 ++++++++++++++++++++-------------------- tests/test_blocks.py | 4 +- tests/test_components.py | 31 ++++++++++++++- tests/test_surfaces.py | 2 +- 4 files changed, 76 insertions(+), 47 deletions(-) diff --git a/blockkit/components.py b/blockkit/components.py index 6c7158b..1813d06 100644 --- a/blockkit/components.py +++ b/blockkit/components.py @@ -1,51 +1,53 @@ import json -from typing import TYPE_CHECKING, Any, List, Type, Union +from typing import Any, List, Type +from typing import get_origin, get_args, get_type_hints -from pydantic import BaseModel - -if TYPE_CHECKING: - from blockkit.objects import MarkdownText, PlainText +from pydantic import BaseModel, model_validator class Component(BaseModel): - def __init__(self, *args, **kwargs): - for name, field in self.model_fields.items(): - value = kwargs.get(name) - origin = getattr(field.annotation, "__origin__", None) - if ( - value - and type(value) in (str, list) - and origin is Union - and self.__class__.__name__ != "Message" - ): - types = field.annotation.__args__ - # types = field.type_.__args__ - if type(value) is str: - value = self._expand_str(value, types) - elif type(value) is list: - items = [] - for v in value: - if type(v) is str: - v = self._expand_str(v, types) - items.append(v) - value = items - - kwargs[name] = value - super().__init__(*args, **kwargs) - - def _expand_str( - self, value: str, types: List[Type[Any]] - ) -> Union["PlainText", "MarkdownText", str]: - literal_types = [getattr(t, "__name__", None) for t in types] - - if "MarkdownText" in literal_types: - idx = literal_types.index("MarkdownText") - return types[idx](text=value) - elif "PlainText" in literal_types: - idx = literal_types.index("PlainText") - return types[idx](text=value, emoji=True) - + @model_validator(mode="after") + def expand_strings(self) -> Any: + hints = get_type_hints(self) + for field_name, types in hints.items(): + inner_types = self._get_inner_types(types) + if not inner_types: + continue + + value = getattr(self, field_name) + type_names = [t.__name__ for t in inner_types] + + expandable = "MarkdownText" in type_names or "PlainText" in type_names + if not expandable: + continue + + if isinstance(value, str): + value = self._expand_str(value, inner_types, type_names) + if isinstance(value, list): + value = [ + self._expand_str(v, inner_types, type_names) + if isinstance(v, str) + else v + for v in value + ] + setattr(self, field_name, value) + return self + + @classmethod + def _expand_str(cls, value: str, types: List[Type[Any]], type_names: List[str]): + if "MarkdownText" in type_names: + return types[type_names.index("MarkdownText")](text=value) + elif "PlainText" in type_names: + return types[type_names.index("PlainText")](text=value, emoji=True) return value + @classmethod + def _get_inner_types(cls, types, parent_types=None): + origin = get_origin(types) + if not origin: + return parent_types + args = get_args(types) + return cls._get_inner_types(args[0], parent_types=args) + def build(self) -> dict: return json.loads(self.model_dump_json(by_alias=True, exclude_none=True)) diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 1e0ce70..cb235fc 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -20,7 +20,7 @@ RichTextQuote, RichTextSection, ) -from blockkit.objects import Emoji, MarkdownText, Style, Text +from blockkit.objects import MarkdownText, Text from pydantic import ValidationError @@ -411,7 +411,7 @@ def test_builds_section(): assert Section( text=MarkdownText(text="*markdown* text"), block_id="block_id", - fields=[MarkdownText(text="field 1"), MarkdownText(text="field 2")], + fields=["field 1", MarkdownText(text="field 2")], accessory=Button(text=PlainText(text="button"), action_id="action_id"), ).build() == { "type": "section", diff --git a/tests/test_components.py b/tests/test_components.py index a02dd2a..f14868a 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -1,7 +1,7 @@ -from blockkit.objects import Confirm +from blockkit import Confirm, Context, Image -def test_converts_str(): +def test_expands_str(): assert Confirm( title="title", text="*markdown* text", @@ -15,3 +15,30 @@ def test_converts_str(): "deny": {"type": "plain_text", "text": "deny", "emoji": True}, "style": "primary", } + + +def test_expands_str_list(): + assert Context( + elements=[ + Image(image_url="http://placekitten.com/300/200", alt_text="kitten"), + "element 1", + "element 2", + ] + ).build() == { + "type": "context", + "elements": [ + { + "type": "image", + "image_url": "http://placekitten.com/300/200", + "alt_text": "kitten", + }, + { + "type": "mrkdwn", + "text": "element 1", + }, + { + "type": "mrkdwn", + "text": "element 2", + }, + ], + } diff --git a/tests/test_surfaces.py b/tests/test_surfaces.py index dd08b04..faf3c86 100644 --- a/tests/test_surfaces.py +++ b/tests/test_surfaces.py @@ -197,7 +197,7 @@ def test_builds_message(): text="message text", blocks=[ Section( - text=MarkdownText(text="*markdown* text"), + text="*markdown* text", accessory=Button(text=PlainText(text="button"), action_id="action_id"), ), RichText(