diff --git a/aioalice/__init__.py b/aioalice/__init__.py index 6551fc3..5e74c07 100644 --- a/aioalice/__init__.py +++ b/aioalice/__init__.py @@ -11,4 +11,4 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -__version__ = '1.1.7' +__version__ = '1.2.0' diff --git a/aioalice/types/__init__.py b/aioalice/types/__init__.py index ca6ae89..86bc52a 100644 --- a/aioalice/types/__init__.py +++ b/aioalice/types/__init__.py @@ -1,6 +1,10 @@ from .base import AliceObject from .meta import Meta from .markup import Markup +from .entity_tokens import EntityTokens +from .entity_value import EntityValue +from .entity import Entity, EntityType +from .natural_language_understanding import NaturalLanguageUnderstanding from .request import Request, RequestType from .session import BaseSession, Session diff --git a/aioalice/types/alice_request.py b/aioalice/types/alice_request.py index bbcbcdd..7811e67 100644 --- a/aioalice/types/alice_request.py +++ b/aioalice/types/alice_request.py @@ -1,7 +1,7 @@ from attr import attrs, attrib +from aioalice.utils import ensure_cls from . import AliceObject, Meta, Session, \ Card, Request, Response, AliceResponse -from aioalice.utils import ensure_cls @attrs diff --git a/aioalice/types/alice_response.py b/aioalice/types/alice_response.py index 5cfb2cf..e4e28aa 100644 --- a/aioalice/types/alice_response.py +++ b/aioalice/types/alice_response.py @@ -1,6 +1,6 @@ from attr import attrs, attrib -from . import AliceObject, BaseSession, Response from aioalice.utils import ensure_cls +from . import AliceObject, BaseSession, Response @attrs diff --git a/aioalice/types/card.py b/aioalice/types/card.py index 888d041..e086699 100644 --- a/aioalice/types/card.py +++ b/aioalice/types/card.py @@ -1,8 +1,8 @@ from attr import attrs, attrib -from . import AliceObject, MediaButton, Image, CardHeader, CardFooter from aioalice.utils import ensure_cls from aioalice.utils.helper import Helper, HelperMode, Item +from . import AliceObject, MediaButton, Image, CardHeader, CardFooter @attrs diff --git a/aioalice/types/card_footer.py b/aioalice/types/card_footer.py index a0947d4..8d16066 100644 --- a/aioalice/types/card_footer.py +++ b/aioalice/types/card_footer.py @@ -1,7 +1,7 @@ from attr import attrs, attrib -from . import AliceObject, MediaButton from aioalice.utils import ensure_cls +from . import AliceObject, MediaButton @attrs diff --git a/aioalice/types/entity.py b/aioalice/types/entity.py new file mode 100644 index 0000000..925d847 --- /dev/null +++ b/aioalice/types/entity.py @@ -0,0 +1,36 @@ +import logging +from attr import attrs, attrib + +from aioalice.utils import ensure_cls +from aioalice.utils.helper import Helper, HelperMode, Item +from . import AliceObject, EntityTokens, EntityValue + +log = logging.getLogger(__name__) + + +@attrs +class Entity(AliceObject): + """Entity object""" + type = attrib(type=str) + tokens = attrib(convert=ensure_cls(EntityTokens)) + value = attrib(factory=dict) + + @type.validator + def check(self, attribute, value): + """Report unknown type""" + if value not in EntityType.all(): + log.error('Unknown Entity type! `%r`', value) + + def __attrs_post_init__(self): + """If entity type not number, convert to EntityValue""" + if self.value and self.type != EntityType.YANDEX_NUMBER: + self.value = EntityValue(**self.value) + + +class EntityType(Helper): + mode = HelperMode.UPPER_DOT_SEPARATED + + YANDEX_GEO = Item() + YANDEX_FIO = Item() + YANDEX_NUMBER = Item() + YANDEX_DATETIME = Item() diff --git a/aioalice/types/entity_tokens.py b/aioalice/types/entity_tokens.py new file mode 100644 index 0000000..edf9184 --- /dev/null +++ b/aioalice/types/entity_tokens.py @@ -0,0 +1,9 @@ +from attr import attrs, attrib +from . import AliceObject + + +@attrs +class EntityTokens(AliceObject): + """EntityTokens object""" + start = attrib(type=int) + end = attrib(type=int) diff --git a/aioalice/types/entity_value.py b/aioalice/types/entity_value.py new file mode 100644 index 0000000..1509ef5 --- /dev/null +++ b/aioalice/types/entity_value.py @@ -0,0 +1,31 @@ +from attr import attrs, attrib +from . import AliceObject + + +@attrs +class EntityValue(AliceObject): + """EntityValue object""" + + # YANDEX.FIO + first_name = attrib(default=None, type=str) + patronymic_name = attrib(default=None, type=str) + last_name = attrib(default=None, type=str) + + # YANDEX.GEO + country = attrib(default=None, type=str) + city = attrib(default=None, type=str) + street = attrib(default=None, type=str) + house_number = attrib(default=None, type=str) + airport = attrib(default=None, type=str) + + # YANDEX.DATETIME + year = attrib(default=None, type=str) + year_is_relative = attrib(default=False, type=bool) + month = attrib(default=None, type=str) + month_is_relative = attrib(default=False, type=bool) + day = attrib(default=None, type=str) + day_is_relative = attrib(default=False, type=bool) + hour = attrib(default=None, type=str) + hour_is_relative = attrib(default=False, type=bool) + minute = attrib(default=None, type=str) + minute_is_relative = attrib(default=False, type=bool) diff --git a/aioalice/types/image.py b/aioalice/types/image.py index ece7daa..fa4a0da 100644 --- a/aioalice/types/image.py +++ b/aioalice/types/image.py @@ -1,7 +1,7 @@ from attr import attrs, attrib -from . import AliceObject, MediaButton from aioalice.utils import ensure_cls +from . import AliceObject, MediaButton @attrs diff --git a/aioalice/types/meta.py b/aioalice/types/meta.py index b7f02db..19d9ab7 100644 --- a/aioalice/types/meta.py +++ b/aioalice/types/meta.py @@ -1,7 +1,9 @@ from attr import attrs, attrib +from aioalice.utils import safe_kwargs from . import AliceObject +@safe_kwargs @attrs class Meta(AliceObject): """Meta object""" diff --git a/aioalice/types/natural_language_understanding.py b/aioalice/types/natural_language_understanding.py new file mode 100644 index 0000000..274d972 --- /dev/null +++ b/aioalice/types/natural_language_understanding.py @@ -0,0 +1,11 @@ +# Natural Language Understanding: https://medium.com/@lola.com/nlp-vs-nlu-whats-the-difference-d91c06780992 +from attr import attrs, attrib +from aioalice.utils import ensure_cls +from . import AliceObject, Entity + + +@attrs +class NaturalLanguageUnderstanding(AliceObject): + """Natural Language Understanding object""" + tokens = attrib(factory=list) + entities = attrib(factory=list, convert=ensure_cls(Entity)) diff --git a/aioalice/types/request.py b/aioalice/types/request.py index a332839..dd092fd 100644 --- a/aioalice/types/request.py +++ b/aioalice/types/request.py @@ -1,9 +1,11 @@ from attr import attrs, attrib +from aioalice.utils import safe_kwargs, ensure_cls from aioalice.utils.helper import Helper, HelperMode, Item -from . import AliceObject, Markup +from . import AliceObject, Markup, NaturalLanguageUnderstanding +@safe_kwargs @attrs class Request(AliceObject): """Request object""" @@ -12,6 +14,7 @@ class Request(AliceObject): original_utterance = attrib(default='', type=str) # Can be none if payload passed markup = attrib(default=None) payload = attrib(default=None) + nlu = attrib(default=None, convert=ensure_cls(NaturalLanguageUnderstanding)) @type.validator def check(self, attribute, value): diff --git a/aioalice/types/response.py b/aioalice/types/response.py index e3cf260..ff59ba1 100644 --- a/aioalice/types/response.py +++ b/aioalice/types/response.py @@ -1,6 +1,6 @@ from attr import attrs, attrib -from . import AliceObject, Card, Button from aioalice.utils import ensure_cls +from . import AliceObject, Card, Button @attrs diff --git a/aioalice/utils/__init__.py b/aioalice/utils/__init__.py index 9a51ecb..3e612d0 100644 --- a/aioalice/utils/__init__.py +++ b/aioalice/utils/__init__.py @@ -1,6 +1,7 @@ from . import exceptions from .json import json from .payload import generate_json_payload +from .safe_kwargs import safe_kwargs def ensure_cls(klass): diff --git a/aioalice/utils/helper.py b/aioalice/utils/helper.py index 53a6cd3..098ace0 100644 --- a/aioalice/utils/helper.py +++ b/aioalice/utils/helper.py @@ -46,6 +46,7 @@ class HelperMode(Helper): CamelCase = 'CamelCase' snake_case = 'snake_case' lowercase = 'lowercase' + UPPER_DOT_SEPARATED = 'UPPER.DOT.SEPARATED' @classmethod def all(cls): @@ -120,16 +121,18 @@ def apply(cls, text, mode): :param mode: :return: """ - if mode == cls.SCREAMING_SNAKE_CASE: + if mode == cls.UPPER_DOT_SEPARATED: + return cls._screaming_snake_case(text).replace('_', '.') + elif mode == cls.SCREAMING_SNAKE_CASE: return cls._screaming_snake_case(text) elif mode == cls.snake_case: return cls._snake_case(text) - elif mode == cls.lowercase: - return cls._snake_case(text).replace('_', '') elif mode == cls.lowerCamelCase: return cls._camel_case(text) elif mode == cls.CamelCase: return cls._camel_case(text, True) + elif mode == cls.lowercase: + return cls._snake_case(text).replace('_', '') elif callable(mode): return mode(text) return text diff --git a/aioalice/utils/safe_kwargs.py b/aioalice/utils/safe_kwargs.py new file mode 100644 index 0000000..42d46c4 --- /dev/null +++ b/aioalice/utils/safe_kwargs.py @@ -0,0 +1,16 @@ +# https://gist.github.com/surik00/a6c2804a2d18a2ab75630bb5d93693c8 + +import inspect +import functools + + +def safe_kwargs(func_or_class): + spec = inspect.getfullargspec(func_or_class) + all_args = spec.args + + @functools.wraps(func_or_class) + def wrap(*args, **kwargs): + accepted_kwargs = {k: v for k, v in kwargs.items() if k in all_args} + return func_or_class(*args, **accepted_kwargs) + + return wrap diff --git a/setup.py b/setup.py index e1993a5..76c6876 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ if sys.version_info < MINIMAL_PY_VERSION: raise RuntimeError('aioAlice works only with Python {}+'.format('.'.join(map(str, MINIMAL_PY_VERSION)))) -__version__ = '1.1.7' +__version__ = '1.2.0' def get_description(): diff --git a/tests/_dataset.py b/tests/_dataset.py index efe9265..14eb49c 100644 --- a/tests/_dataset.py +++ b/tests/_dataset.py @@ -226,3 +226,93 @@ }, 'version': '1.0' } + +ENTITY_TOKEN = { + "start": 2, + "end": 6 +} + +ENTITY_VALUE = { + "house_number": "16", + "street": "льва толстого" +} + +ENTITY = { + "tokens": ENTITY_TOKEN, + "type": "YANDEX.GEO", + "value": ENTITY_VALUE +} + +ENTITY_INTEGER = { + "tokens": { + "start": 5, + "end": 6 + }, + "type": "YANDEX.NUMBER", + "value": 16 +} + +NLU = { + "tokens": [ + "закажи", + "пиццу", + "на", + "льва", + "толстого", + "16", + "на", + "завтра" + ], + "entities": [ + ENTITY, + { + "tokens": { + "start": 3, + "end": 5 + }, + "type": "YANDEX.FIO", + "value": { + "first_name": "лев", + "last_name": "толстой" + } + }, + ENTITY_INTEGER, + { + "tokens": { + "start": 6, + "end": 8 + }, + "type": "YANDEX.DATETIME", + "value": { + "day": 1, + "day_is_relative": True + } + } + ] +} + +REQUEST_WITH_NLU = { + "meta": { + "locale": "ru-RU", + "timezone": "Europe/Moscow", + "client_id": "ru.yandex.searchplugin/5.80 (Samsung Galaxy; Android 4.4)" + }, + "request": { + "command": "закажи пиццу на улицу льва толстого, 16 на завтра", + "original_utterance": "закажи пиццу на улицу льва толстого, 16 на завтра", + "type": "SimpleUtterance", + "markup": { + "dangerous_context": True + }, + "payload": {}, + "nlu": NLU, + }, + "session": { + "new": True, + "message_id": 4, + "session_id": "2eac4854-fce721f3-b845abba-20d60", + "skill_id": "3ad36498-f5rd-4079-a14b-788652932056", + "user_id": "AC9WC3DF6FCE052E45A4566A48E6B7193774B84814CE49A922E163B8B29881DC" + }, + "version": "1.0" +} diff --git a/tests/test_types.py b/tests/test_types.py index 9efe745..34c806e 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -16,7 +16,8 @@ RESPONSE_BUTTON, \ EXPECTED_ALICE_RESPONSE_BIG_IMAGE_WITH_BUTTON, \ EXPECTED_ALICE_RESPONSE_ITEMS_LIST_WITH_BUTTON, \ - DATA_FROM_STATION + DATA_FROM_STATION, REQUEST_WITH_NLU, ENTITY_TOKEN, \ + ENTITY_VALUE, ENTITY, ENTITY_INTEGER, NLU class TestAliceTypes(unittest.TestCase): @@ -43,6 +44,64 @@ def test_markup(self): markup = types.Markup(**MARKUP) self._test_markup(markup, MARKUP) + def _test_entity_tokens(self, et, dct): + self.assertEqual(et.start, dct['start']) + self.assertEqual(et.end, dct['end']) + + def test_entity_tokens(self): + et = types.EntityTokens(**ENTITY_TOKEN) + self._test_entity_tokens(et, ENTITY_TOKEN) + + def _test_entity_value(self, ev, dct): + for key in ( + 'first_name', + 'patronymic_name', + 'last_name', + 'country', + 'city', + 'street', + 'house_number', + 'airport', + 'year', + 'year_is_relative', + 'month', + 'month_is_relative', + 'day', + 'day_is_relative', + 'hour', + 'hour_is_relative', + 'minute', + 'minute_is_relative', + ): + if key in dct: + self.assertEqual(getattr(ev, key), dct[key]) + + def test_entity_value(self): + ev = types.EntityValue(**ENTITY_VALUE) + self._test_entity_value(ev, ENTITY_VALUE) + + def _test_entity(self, entity, dct): + self._test_entity_tokens(entity.tokens, dct['tokens']) + if entity.type == types.EntityType.YANDEX_NUMBER: + entity.value == dct['value'] + else: + self._test_entity_value(entity.value, dct['value']) + + def test_entity(self): + entity = types.Entity(**ENTITY) + self._test_entity(entity, ENTITY) + entity_int = types.Entity(**ENTITY_INTEGER) + self._test_entity(entity_int, ENTITY_INTEGER) + + def _test_nlu(self, nlu, dct): + self.assertEqual(nlu.tokens, dct['tokens']) + for entity, _dct in zip(nlu.entities, dct['entities']): + self._test_entity(entity, _dct) + + def test_nlu(self): + nlu = types.NaturalLanguageUnderstanding(**NLU) + self._test_nlu(nlu, NLU) + def _test_request(self, req, dct): self.assertEqual(req.command, dct['command']) self.assertEqual(req.original_utterance, dct['original_utterance']) @@ -51,6 +110,8 @@ def _test_request(self, req, dct): self.assertEqual(req.payload, dct['payload']) if 'markup' in dct: self._test_markup(req.markup, dct['markup']) + if 'nlu' in dct: + self._test_nlu(req.nlu, dct['nlu']) def test_request(self): request = types.Request(**REQUEST) @@ -260,6 +321,9 @@ def test_card_items_list_card_method(self): ) self._assert_payload(card_items_list, EXPECTED_CARD_ITEMS_LIST_JSON) + def test_request_with_nlu(self): + self._test_alice_request_from_dct(REQUEST_WITH_NLU) + if __name__ == '__main__': unittest.main()