Skip to content

Commit 470b41a

Browse files
Merge pull request #1294 from syucream/fix/attr-hint-type
Refine AttrHint type
2 parents 150349a + 75a76c9 commit 470b41a

19 files changed

+293
-255
lines changed

airone/lib/elasticsearch.py

+26-28
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ class FilterKey(enum.IntEnum):
5151
DUPLICATED = 5
5252

5353

54-
class AttrHint(TypedDict):
54+
class AttrHint(BaseModel):
5555
name: str
56-
is_readable: NotRequired[bool]
57-
filter_key: NotRequired[FilterKey]
58-
keyword: NotRequired[str]
59-
exact_match: NotRequired[bool]
56+
is_readable: bool | None = None
57+
filter_key: FilterKey | None = None
58+
keyword: str | None = None
59+
exact_match: bool | None = None
6060

6161

6262
class AttributeDocument(TypedDict):
@@ -267,23 +267,23 @@ def make_query(
267267

268268
# Conversion processing from "filter_key" to "keyword" for each hint_attrs
269269
for hint_attr in hint_attrs:
270-
match hint_attr.get("filter_key", None):
270+
match hint_attr.filter_key:
271271
case FilterKey.CLEARED:
272272
# remove "keyword" parameter
273-
hint_attr.pop("keyword", None)
273+
hint_attr.keyword = None
274274
case FilterKey.EMPTY:
275-
hint_attr["keyword"] = "\\"
275+
hint_attr.keyword = "\\"
276276
case FilterKey.NON_EMPTY:
277-
hint_attr["keyword"] = "*"
277+
hint_attr.keyword = "*"
278278
case FilterKey.DUPLICATED:
279-
aggs_query = _make_aggs_query(hint_attr["name"])
279+
aggs_query = _make_aggs_query(hint_attr.name)
280280
# TODO Set to 1 for convenience
281281
resp = execute_query(aggs_query, 1)
282282
keyword_infos = resp["aggregations"]["attr_aggs"]["attr_name_aggs"][
283283
"attr_value_aggs"
284284
]["buckets"]
285285
keyword_list = [x["key"] for x in keyword_infos]
286-
hint_attr["keyword"] = CONFIG.OR_SEARCH_CHARACTER.join(
286+
hint_attr.keyword = CONFIG.OR_SEARCH_CHARACTER.join(
287287
["^" + x + "$" for x in keyword_list]
288288
)
289289

@@ -324,9 +324,7 @@ def make_query(
324324
"query": {
325325
"bool": {
326326
"should": [
327-
{"term": {"attr.name": x["name"]}}
328-
for x in hint_attrs
329-
if "name" in x
327+
{"term": {"attr.name": x.name}} for x in hint_attrs if x.name
330328
]
331329
}
332330
},
@@ -337,7 +335,7 @@ def make_query(
337335
attr_query: dict[str, dict] = {}
338336

339337
# filter attribute by keywords
340-
for hint in [hint for hint in hint_attrs if "name" in hint and hint.get("keyword")]:
338+
for hint in [hint for hint in hint_attrs if hint.name and hint.keyword]:
341339
attr_query.update(_parse_or_search(hint))
342340

343341
# Build queries along keywords
@@ -649,7 +647,7 @@ def _parse_or_search(hint: AttrHint) -> dict[str, dict]:
649647
duplicate_keys: list = []
650648

651649
# Split and process keywords with 'or'
652-
for keyword_divided_or in hint["keyword"].split(CONFIG.OR_SEARCH_CHARACTER):
650+
for keyword_divided_or in (hint.keyword or "").split(CONFIG.OR_SEARCH_CHARACTER):
653651
parsed_query = _parse_and_search(hint, keyword_divided_or, duplicate_keys)
654652
attr_query.update(parsed_query)
655653

@@ -696,7 +694,7 @@ def _parse_and_search(
696694

697695
# Keyword divided by 'or' is processed by dividing by 'and'
698696
for keyword in keyword_divided_or.split(CONFIG.AND_SEARCH_CHARACTER):
699-
key = f"{keyword}_{hint['name']}"
697+
key = f"{keyword}_{hint.name}"
700698

701699
# Skip if keywords overlap
702700
if key in duplicate_keys:
@@ -742,29 +740,29 @@ def _build_queries_along_keywords(
742740
"""
743741

744742
# Get the keyword.
745-
hints = [x for x in hint_attrs if "keyword" in x and x["keyword"]]
743+
hints = [x for x in hint_attrs if x.keyword]
746744
res_query: dict[str, Any] = {}
747745

748746
for hint in hints:
749747
and_query: dict[str, Any] = {}
750748
or_query: dict[str, Any] = {}
751749

752750
# Split keyword by 'or'
753-
for keyword_divided_or in hint["keyword"].split(CONFIG.OR_SEARCH_CHARACTER):
751+
for keyword_divided_or in (hint.keyword or "").split(CONFIG.OR_SEARCH_CHARACTER):
754752
if CONFIG.AND_SEARCH_CHARACTER in keyword_divided_or:
755753
# If 'AND' is included in the keyword divided by 'OR', add it to 'filter'
756754
for keyword in keyword_divided_or.split(CONFIG.AND_SEARCH_CHARACTER):
757755
if keyword_divided_or not in and_query:
758756
and_query[keyword_divided_or] = {"bool": {"filter": []}}
759757

760758
and_query[keyword_divided_or]["bool"]["filter"].append(
761-
attr_query[keyword + "_" + hint["name"]]
759+
attr_query[keyword + "_" + hint.name]
762760
)
763761

764762
else:
765-
and_query[keyword_divided_or] = attr_query[keyword_divided_or + "_" + hint["name"]]
763+
and_query[keyword_divided_or] = attr_query[keyword_divided_or + "_" + hint.name]
766764

767-
if CONFIG.OR_SEARCH_CHARACTER in hint["keyword"]:
765+
if CONFIG.OR_SEARCH_CHARACTER in (hint.keyword or ""):
768766
# If the keyword contains 'or', concatenate with 'should'
769767
if not or_query:
770768
or_query = {"bool": {"should": []}}
@@ -817,7 +815,7 @@ def _make_an_attribute_filter(hint: AttrHint, keyword: str) -> dict[str, dict]:
817815
dict[str, str]: Created attribute filter
818816
819817
"""
820-
cond_attr: list[dict] = [{"term": {"attr.name": hint["name"]}}]
818+
cond_attr: list[dict] = [{"term": {"attr.name": hint.name}}]
821819

822820
date_results = _is_date(keyword)
823821
if date_results:
@@ -840,7 +838,7 @@ def _make_an_attribute_filter(hint: AttrHint, keyword: str) -> dict[str, dict]:
840838

841839
str_cond = {"regexp": {"attr.value": _get_regex_pattern(keyword)}}
842840

843-
if hint.get("filter_key") == FilterKey.TEXT_NOT_CONTAINED:
841+
if hint.filter_key == FilterKey.TEXT_NOT_CONTAINED:
844842
cond_attr.append({"bool": {"must_not": [date_cond, str_cond]}})
845843
else:
846844
cond_attr.append({"bool": {"should": [date_cond, str_cond]}})
@@ -866,10 +864,10 @@ def _make_an_attribute_filter(hint: AttrHint, keyword: str) -> dict[str, dict]:
866864
)
867865

868866
elif hint_keyword_val:
869-
if "exact_match" not in hint:
867+
if hint.exact_match is None:
870868
cond_val.append({"regexp": {"attr.value": _get_regex_pattern(hint_keyword_val)}})
871869

872-
if hint.get("filter_key") == FilterKey.TEXT_NOT_CONTAINED:
870+
if hint.filter_key == FilterKey.TEXT_NOT_CONTAINED:
873871
cond_attr.append({"bool": {"must_not": cond_val}})
874872
else:
875873
cond_attr.append({"bool": {"should": cond_val}})
@@ -1004,7 +1002,7 @@ def make_search_results(
10041002
# formalize attribute values according to the type
10051003
for attrinfo in entry_info["attr"]:
10061004
# Skip other than the target Attribute
1007-
if attrinfo["name"] not in [x["name"] for x in hint_attrs]:
1005+
if attrinfo["name"] not in [x.name for x in hint_attrs]:
10081006
continue
10091007

10101008
ret_attrinfo: AdvancedSearchResultRecordAttr = {}
@@ -1024,7 +1022,7 @@ def make_search_results(
10241022
record.attrs[attrinfo["name"]] = ret_attrinfo
10251023

10261024
# Check for has permission to EntityAttr
1027-
if attrinfo["name"] not in [x["name"] for x in hint_attrs if x["is_readable"]]:
1025+
if attrinfo["name"] not in [x.name for x in hint_attrs if x.is_readable]:
10281026
ret_attrinfo["is_readable"] = False
10291027
continue
10301028

airone/tests/test_elasticsearch.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from django.test import TestCase
22

33
from airone.lib import elasticsearch
4-
from airone.lib.elasticsearch import AdvancedSearchResultRecord
4+
from airone.lib.elasticsearch import AdvancedSearchResultRecord, AttrHint
55
from airone.lib.types import AttrType
66
from entity.models import Entity, EntityAttr
77
from entry.models import Attribute, AttributeValue, Entry
@@ -42,8 +42,8 @@ def test_make_query(self):
4242
query = elasticsearch.make_query(
4343
hint_entity=self._entity,
4444
hint_attrs=[
45-
{"name": "a1", "keyword": "hoge|fu&ga"},
46-
{"name": "a2", "keyword": ""},
45+
AttrHint(name="a1", keyword="hoge|fu&ga"),
46+
AttrHint(name="a2", keyword=""),
4747
],
4848
entry_name="entry1",
4949
)
@@ -359,7 +359,9 @@ def test_make_search_results(self):
359359
}
360360
}
361361

362-
hint_attrs = [{"name": "test_attr", "keyword": "", "is_readable": True}]
362+
hint_attrs = [
363+
AttrHint(name="test_attr", keyword="", is_readable=True),
364+
]
363365
hint_referral = ""
364366
results = elasticsearch.make_search_results(self._user, res, hint_attrs, hint_referral, 100)
365367

api_v1/entry/serializer.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from rest_framework.exceptions import ValidationError
33

44
from airone.exceptions import ElasticsearchException
5+
from airone.lib.elasticsearch import AttrHint
56
from airone.lib.log import Logger
67
from entity.models import Entity, EntityAttr
78
from entry.models import Entry
@@ -294,10 +295,10 @@ def _do_forward_search(sub_query, sub_query_result):
294295

295296
# Query for forward search
296297
hint_attrs = [
297-
{
298-
"name": sub_query["name"],
299-
"keyword": search_keyword,
300-
}
298+
AttrHint(
299+
name=sub_query["name"],
300+
keyword=search_keyword,
301+
)
301302
]
302303

303304
# get Entry informations from result

api_v1/entry/views.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import pytz
44
from django.conf import settings
55
from django.db.models import Q
6+
from pydantic import ValidationError
67
from rest_framework import status
78
from rest_framework.response import Response
89
from rest_framework.views import APIView
910

1011
from airone.exceptions import ElasticsearchException
1112
from airone.lib.acl import ACLType
13+
from airone.lib.elasticsearch import AttrHint
1214
from api_v1.entry.serializer import EntrySearchChainSerializer
1315
from entity.models import Entity
1416
from entry.models import Entry
@@ -62,15 +64,15 @@ def post(self, request, format=None):
6264

6365
hint_entities = request.data.get("entities")
6466
hint_entry_name = request.data.get("entry_name", "")
65-
hint_attrs = request.data.get("attrinfo")
6667
hint_referral = request.data.get("referral")
68+
attrinfo = request.data.get("attrinfo", [])
6769
is_output_all = request.data.get("is_output_all", True)
6870
entry_limit = request.data.get("entry_limit", CONFIG_ENTRY.MAX_LIST_ENTRIES)
6971

7072
if (
7173
not isinstance(hint_entities, list)
7274
or not isinstance(hint_entry_name, str)
73-
or not isinstance(hint_attrs, list)
75+
or not isinstance(attrinfo, list)
7476
or not isinstance(is_output_all, bool)
7577
or (hint_referral and not isinstance(hint_referral, str))
7678
or not isinstance(entry_limit, int)
@@ -79,21 +81,31 @@ def post(self, request, format=None):
7981
"The type of parameter is incorrect", status=status.HTTP_400_BAD_REQUEST
8082
)
8183

84+
try:
85+
hint_attrs = [
86+
AttrHint(
87+
name=x.get("name"),
88+
keyword=x.get("keyword"),
89+
filter_key=x.get("filter_key"),
90+
exact_match=x.get("exact_match"),
91+
)
92+
for x in attrinfo
93+
]
94+
except (TypeError, ValidationError):
95+
return Response(
96+
"The type of parameter 'attrinfo' is incorrect",
97+
status=status.HTTP_400_BAD_REQUEST,
98+
)
99+
82100
# forbid to input large size request
83101
if len(hint_entry_name) > CONFIG_ENTRY.MAX_QUERY_SIZE:
84102
return Response("Sending parameter is too large", status=400)
85103

86104
# check attribute params
87105
for hint_attr in hint_attrs:
88-
if "name" not in hint_attr:
89-
return Response("The name key is required for attrinfo parameter", status=400)
90-
if not isinstance(hint_attr["name"], str):
91-
return Response("Invalid value for attrinfo parameter", status=400)
92-
if hint_attr.get("keyword"):
93-
if not isinstance(hint_attr["keyword"], str):
94-
return Response("Invalid value for attrinfo parameter", status=400)
106+
if hint_attr.keyword:
95107
# forbid to input large size request
96-
if len(hint_attr["keyword"]) > CONFIG_ENTRY.MAX_QUERY_SIZE:
108+
if len(hint_attr.keyword) > CONFIG_ENTRY.MAX_QUERY_SIZE:
97109
return Response("Sending parameter is too large", status=400)
98110

99111
# check entities params

api_v1/tests/entry/test_api.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,24 @@ def test_search_invalid_param(self):
2929
params = {**valid_params, **invalid_param}
3030
resp = self.client.post("/api/v1/entry/search", json.dumps(params), "application/json")
3131
self.assertEqual(resp.status_code, 400)
32-
self.assertEqual(resp.content, b'"The type of parameter is incorrect"')
3332

3433
params = {**valid_params, **{"attrinfo": [{"hoge": "value"}]}}
3534
resp = self.client.post("/api/v1/entry/search", json.dumps(params), "application/json")
3635
self.assertEqual(resp.status_code, 400)
37-
self.assertEqual(resp.content, b'"The name key is required for attrinfo parameter"')
36+
self.assertEqual(resp.content, b"\"The type of parameter 'attrinfo' is incorrect\"")
3837

3938
params = {**valid_params, **{"attrinfo": [{"name": ["hoge"]}]}}
4039
resp = self.client.post("/api/v1/entry/search", json.dumps(params), "application/json")
4140
self.assertEqual(resp.status_code, 400)
42-
self.assertEqual(resp.content, b'"Invalid value for attrinfo parameter"')
41+
self.assertEqual(resp.content, b"\"The type of parameter 'attrinfo' is incorrect\"")
4342

4443
params = {
4544
**valid_params,
4645
**{"attrinfo": [{"name": "value", "keyword": ["hoge"]}]},
4746
}
4847
resp = self.client.post("/api/v1/entry/search", json.dumps(params), "application/json")
4948
self.assertEqual(resp.status_code, 400)
50-
self.assertEqual(resp.content, b'"Invalid value for attrinfo parameter"')
49+
self.assertEqual(resp.content, b"\"The type of parameter 'attrinfo' is incorrect\"")
5150

5251
def test_narrow_down_advanced_search_results(self):
5352
user = self.admin_login()

dashboard/tasks.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from natsort import natsorted
99

1010
from airone.celery import app
11-
from airone.lib.elasticsearch import AdvancedSearchResultRecord
11+
from airone.lib.elasticsearch import AdvancedSearchResultRecord, AttrHint
1212
from airone.lib.job import may_schedule_until_job_is_ready
1313
from airone.lib.types import AttrType
1414
from entry.services import AdvancedSearchService
@@ -174,13 +174,14 @@ def export_search_result(self, job: Job):
174174
has_referral: bool = recv_data.get("has_referral", False)
175175
referral_name: str | None = recv_data.get("referral_name")
176176
entry_name: str | None = recv_data.get("entry_name")
177+
hint_attrs = [AttrHint.model_validate(attr) for attr in recv_data["attrinfo"]]
177178
if has_referral and referral_name is None:
178179
referral_name = ""
179180

180181
resp = AdvancedSearchService.search_entries(
181182
user,
182183
recv_data["entities"],
183-
recv_data["attrinfo"],
184+
hint_attrs,
184185
settings.ES_CONFIG["MAXIMUM_RESULTS_NUM"],
185186
entry_name,
186187
referral_name,

dashboard/tests/test_view.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,16 @@ def test_show_advanced_search_results(self):
213213
attr = entry.attrs.first()
214214
self.assertEqual(resp.status_code, 200)
215215
self.assertEqual(
216-
resp.context["hint_attrs"], [{"name": "attr", "is_readable": attr.is_public}]
216+
resp.context["hint_attrs"],
217+
[
218+
{
219+
"name": "attr",
220+
"is_readable": attr.is_public,
221+
"filter_key": None,
222+
"keyword": None,
223+
"exact_match": None,
224+
}
225+
],
217226
)
218227
self.assertEqual(resp.context["results"]["ret_count"], 20)
219228
self.assertEqual(len(resp.context["results"]["ret_values"]), 20)
@@ -320,7 +329,7 @@ def test_show_advanced_search_results(self):
320329
self.assertEqual(resp.status_code, 400)
321330
self.assertEqual(
322331
resp.content.decode("utf-8"),
323-
"The name key is required for attrinfo parameter",
332+
"Invalid value for attrinfo parameter",
324333
)
325334

326335
resp = self.client.get(

0 commit comments

Comments
 (0)