Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make join attr types more type-safe #1353

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion entry/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from django.db.models import Prefetch
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied, ValidationError
from typing_extensions import TypedDict
Expand Down Expand Up @@ -140,6 +140,21 @@ class EntryAttributeType(TypedDict):
schema: EntityAttributeType


class AdvancedSearchJoinAttrAttrInfo(BaseModel):
name: str
keyword: str | None = None
filter_key: FilterKey | None = None


class AdvancedSearchJoinAttrInfo(BaseModel):
name: str
offset: int = 0
attrinfo: list[AdvancedSearchJoinAttrAttrInfo] = []


AdvancedSearchJoinAttrInfoList = RootModel[list[AdvancedSearchJoinAttrInfo]]


class EntityAttributeTypeSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
Expand Down
41 changes: 20 additions & 21 deletions entry/api_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from entity.models import Entity, EntityAttr
from entry.api_v2.pagination import EntryReferralPagination
from entry.api_v2.serializers import (
AdvancedSearchJoinAttrInfo,
AdvancedSearchJoinAttrInfoList,
AdvancedSearchResultExportSerializer,
AdvancedSearchResultSerializer,
AdvancedSearchSerializer,
Expand Down Expand Up @@ -265,10 +267,12 @@ def post(self, request: Request) -> Response:
is_all_entities = serializer.validated_data["is_all_entities"]
entry_limit = serializer.validated_data["entry_limit"]
entry_offset = serializer.validated_data["entry_offset"]
join_attrs = serializer.validated_data.get("join_attrs", [])
join_attrs = AdvancedSearchJoinAttrInfoList.model_validate(
serializer.validated_data.get("join_attrs", [])
).root

def _get_joined_resp(
prev_results: list[AdvancedSearchResultRecord], join_attr: dict
prev_results: list[AdvancedSearchResultRecord], join_attr: AdvancedSearchJoinAttrInfo
) -> tuple[bool, dict]:
"""
This is a helper method for join_attrs that will get specified attr values
Expand All @@ -285,7 +289,7 @@ def _get_joined_resp(
Prefetch(
"attrs",
queryset=EntityAttr.objects.filter(
name=join_attr["name"], is_active=True
name=join_attr.name, is_active=True
).prefetch_related(
Prefetch(
"referral", queryset=Entity.objects.filter(is_active=True).only("id")
Expand All @@ -300,7 +304,7 @@ def _get_joined_resp(
if entity is None:
continue

attr = next((a for a in entity.attrs.all() if a.name == join_attr["name"]), None)
attr = next((a for a in entity.attrs.all() if a.name == join_attr.name), None)
if attr is None:
continue

Expand All @@ -309,7 +313,7 @@ def _get_joined_resp(
hint_entity_ids.extend([x.id for x in attr.referral.all()])

# set Item name
attrinfo = result.attrs[join_attr["name"]]
attrinfo = result.attrs[join_attr.name]

if attr.type == AttrType.OBJECT and attrinfo["value"]["name"] not in item_names:
item_names.append(attrinfo["value"]["name"])
Expand All @@ -332,24 +336,19 @@ def _get_joined_resp(

# set parameters to filter joining search results
hint_attrs: list[AttrHint] = []
for info in join_attr.get("attrinfo", []):
for info in join_attr.attrinfo:
hint_attrs.append(
AttrHint(
name=info["name"],
keyword=info.get("keyword"),
filter_key=info.get("filter_key"),
name=info.name,
keyword=info.keyword,
filter_key=info.filter_key,
)
)

# search Items from elasticsearch to join
return (
# This represents whether user want to narrow down results by keyword of joined attr
any(
[
x.get("keyword") or x.get("filter_key", 0) > 0
for x in join_attr.get("attrinfo", [])
]
),
any([x.keyword or (x.filter_key or 0) > 0 for x in join_attr.attrinfo]),
AdvancedSearchService.search_entries(
request.user,
hint_entity_ids=list(set(hint_entity_ids)), # this removes depulicated IDs
Expand All @@ -359,7 +358,7 @@ def _get_joined_resp(
hint_referral=None,
is_output_all=is_output_all,
hint_referral_entity_id=None,
offset=join_attr.get("offset", 0),
offset=join_attr.offset,
).dict(),
)

Expand Down Expand Up @@ -447,20 +446,20 @@ def _get_ref_id_from_es_result(attrinfo):
(will_filter_by_joined_attr, joined_resp) = _get_joined_resp(resp.ret_values, join_attr)
# This is needed to set result as blank value
blank_joining_info = {
"%s.%s" % (join_attr["name"], k["name"]): {
"%s.%s" % (join_attr.name, k.name): {
"is_readable": True,
"type": AttrType.STRING,
"value": "",
}
for k in join_attr["attrinfo"]
for k in join_attr.attrinfo
}

# convert search result to dict to be able to handle it without loop
joined_resp_info = {
x["entry"]["id"]: {
"%s.%s" % (join_attr["name"], k): v
"%s.%s" % (join_attr.name, k): v
for k, v in x["attrs"].items()
if any(_x["name"] == k for _x in join_attr["attrinfo"])
if any(_x.name == k for _x in join_attr.attrinfo)
}
for x in joined_resp["ret_values"]
}
Expand All @@ -470,7 +469,7 @@ def _get_ref_id_from_es_result(attrinfo):
joined_ret_values = []
for resp_result in resp.ret_values:
# joining search result to original one
ref_info = resp_result.attrs.get(join_attr["name"])
ref_info = resp_result.attrs.get(join_attr.name)

# This get referral Item-ID from joined search result
ref_list = _get_ref_id_from_es_result(ref_info)
Expand Down