Skip to content

Commit 7ad5ef2

Browse files
Support rerank
Signed-off-by: junjiejiangjjj <junjie.jiang@zilliz.com>
1 parent 7a5e3f4 commit 7ad5ef2

File tree

9 files changed

+163
-12
lines changed

9 files changed

+163
-12
lines changed

examples/simple_rerank.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import time
2+
import numpy as np
3+
from pymilvus import (
4+
MilvusClient,
5+
DataType,
6+
Function,
7+
FunctionType,
8+
AnnSearchRequest,
9+
)
10+
11+
fmt = "\n=== {:30} ===\n"
12+
dim = 8
13+
collection_name = "hello_milvus"
14+
milvus_client = MilvusClient("http://localhost:19530")
15+
16+
has_collection = milvus_client.has_collection(collection_name, timeout=5)
17+
if has_collection:
18+
milvus_client.drop_collection(collection_name)
19+
20+
schema = milvus_client.create_schema(enable_dynamic_field=False, auto_id=True)
21+
schema.add_field("id", DataType.INT64, is_primary=True)
22+
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
23+
schema.add_field("ts", DataType.INT64)
24+
25+
26+
index_params = milvus_client.prepare_index_params()
27+
index_params.add_index(field_name = "embeddings", index_type="FLAT", metric_type="L2")
28+
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
29+
30+
print(fmt.format(" all collections "))
31+
print(milvus_client.list_collections())
32+
33+
print(fmt.format(f"schema of collection {collection_name}"))
34+
print(milvus_client.describe_collection(collection_name))
35+
36+
rng = np.random.default_rng(seed=19530)
37+
rows = [
38+
{"embeddings": rng.random((1, dim))[0], "ts": 100},
39+
{"embeddings": rng.random((1, dim))[0], "ts": 200},
40+
{"embeddings": rng.random((1, dim))[0], "ts": 300},
41+
{"embeddings": rng.random((1, dim))[0], "ts": 400},
42+
{"embeddings": rng.random((1, dim))[0], "ts": 500},
43+
{"embeddings": rng.random((1, dim))[0], "ts": 600},
44+
]
45+
46+
print(fmt.format("Start inserting entities"))
47+
insert_result = milvus_client.insert(collection_name, rows)
48+
print(fmt.format("Inserting entities done"))
49+
print(insert_result)
50+
51+
52+
print(fmt.format("Start load collection "))
53+
milvus_client.load_collection(collection_name)
54+
55+
rng = np.random.default_rng(seed=19530)
56+
vectors_to_search = rng.random((1, dim))
57+
58+
ranker = Function(
59+
name="rerank_fn",
60+
input_field_names=["ts"],
61+
function_type=FunctionType.RERANK,
62+
params={
63+
"reranker": "decay",
64+
"function": "exp",
65+
"origin": 0,
66+
"offset": 200,
67+
"decay": 0.9,
68+
"scale": 100
69+
}
70+
)
71+
72+
print(fmt.format(f"Start search with retrieve serveral fields."))
73+
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["*"], ranker=ranker)
74+
for hits in result:
75+
for hit in hits:
76+
print(f"hit: {hit}")
77+
78+
vectors_to_search = rng.random((1, dim))
79+
search_param = {
80+
"data": vectors_to_search,
81+
"anns_field": "embeddings",
82+
"param": {"metric_type": "L2"},
83+
"limit": 3,
84+
}
85+
req = AnnSearchRequest(**search_param)
86+
87+
hybrid_res = milvus_client.hybrid_search(collection_name, [req, req], ranker=ranker, limit=3, output_fields=["ts"])
88+
for hits in hybrid_res:
89+
for hit in hits:
90+
print(f" hybrid search hit: {hit}")
91+
92+
milvus_client.drop_collection(collection_name)

pymilvus/client/async_grpc_handler.py

+6
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ async def search(
572572
output_fields: Optional[List[str]] = None,
573573
round_decimal: int = -1,
574574
timeout: Optional[float] = None,
575+
ranker: Optional["Function"] = None,
575576
**kwargs,
576577
):
577578
await self.ensure_channel_ready()
@@ -584,6 +585,7 @@ async def search(
584585
output_fields=output_fields,
585586
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
586587
timeout=timeout,
588+
ranker=ranker,
587589
)
588590
request = Prepare.search_requests_with_expr(
589591
collection_name,
@@ -595,6 +597,7 @@ async def search(
595597
partition_names,
596598
output_fields,
597599
round_decimal,
600+
ranker=ranker,
598601
**kwargs,
599602
)
600603
return await self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)
@@ -610,6 +613,7 @@ async def hybrid_search(
610613
output_fields: Optional[List[str]] = None,
611614
round_decimal: int = -1,
612615
timeout: Optional[float] = None,
616+
ranker: Optional["Function"] = None,
613617
**kwargs,
614618
):
615619
await self.ensure_channel_ready()
@@ -620,6 +624,7 @@ async def hybrid_search(
620624
output_fields=output_fields,
621625
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
622626
timeout=timeout,
627+
ranker=ranker,
623628
)
624629

625630
requests = []
@@ -645,6 +650,7 @@ async def hybrid_search(
645650
partition_names,
646651
output_fields,
647652
round_decimal,
653+
ranker=ranker,
648654
**kwargs,
649655
)
650656
return await self._execute_hybrid_search(

pymilvus/client/check.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pymilvus.exceptions import ParamError
66
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
7+
from pymilvus.orm.schema import Function
78

89
from . import entity_helper
910
from .singleton_utils import Singleton
@@ -315,6 +316,12 @@ def is_legal_operate_privilege_group_type(operate_privilege_group_type: Any) ->
315316
milvus_types.OperatePrivilegeGroupType.RemovePrivilegesFromGroup,
316317
)
317318

319+
def is_legal_ranker(ranker: Any) -> bool:
320+
if ranker is None:
321+
return True
322+
if not isinstance(ranker, Function):
323+
return False
324+
return True
318325

319326
class ParamChecker(metaclass=Singleton):
320327
def __init__(self) -> None:
@@ -363,6 +370,7 @@ def __init__(self) -> None:
363370
"privilege_group": is_legal_privilege_group,
364371
"privileges": is_legal_privileges,
365372
"operate_privilege_group_type": is_legal_operate_privilege_group_type,
373+
"ranker": is_legal_ranker,
366374
}
367375

368376
def check(self, key: str, value: Callable):

pymilvus/client/grpc_handler.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ def search(
888888
output_fields: Optional[List[str]] = None,
889889
round_decimal: int = -1,
890890
timeout: Optional[float] = None,
891+
ranker: Optional["Function"] = None,
891892
**kwargs,
892893
):
893894
check_pass_param(
@@ -899,6 +900,7 @@ def search(
899900
output_fields=output_fields,
900901
guarantee_timestamp=kwargs.get("guarantee_timestamp"),
901902
timeout=timeout,
903+
ranker=ranker,
902904
)
903905

904906
request = Prepare.search_requests_with_expr(
@@ -911,6 +913,7 @@ def search(
911913
partition_names,
912914
output_fields,
913915
round_decimal,
916+
ranker=ranker,
914917
**kwargs,
915918
)
916919
return self._execute_search(request, timeout, round_decimal=round_decimal, **kwargs)
@@ -920,7 +923,7 @@ def hybrid_search(
920923
self,
921924
collection_name: str,
922925
reqs: List[AnnSearchRequest],
923-
rerank: BaseRanker,
926+
rerank: Union[BaseRanker, "Function"],
924927
limit: int,
925928
partition_names: Optional[List[str]] = None,
926929
output_fields: Optional[List[str]] = None,
@@ -956,7 +959,7 @@ def hybrid_search(
956959
hybrid_search_request = Prepare.hybrid_search_request_with_ranker(
957960
collection_name,
958961
requests,
959-
rerank.dict(),
962+
rerank,
960963
limit,
961964
partition_names,
962965
output_fields,

pymilvus/client/prepare.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pymilvus.grpc_gen import common_pb2 as common_types
1010
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
1111
from pymilvus.grpc_gen import schema_pb2 as schema_types
12-
from pymilvus.orm.schema import CollectionSchema, FieldSchema
12+
from pymilvus.orm.schema import CollectionSchema, FieldSchema, Function
1313
from pymilvus.orm.types import infer_dtype_by_scalar_data
1414

1515
from . import __version__, blob, check, entity_helper, ts_utils, utils
@@ -938,6 +938,7 @@ def search_requests_with_expr(
938938
partition_names: Optional[List[str]] = None,
939939
output_fields: Optional[List[str]] = None,
940940
round_decimal: int = -1,
941+
ranker: Optional[Function] = None,
941942
**kwargs,
942943
) -> milvus_types.SearchRequest:
943944
use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
@@ -1049,24 +1050,31 @@ def search_requests_with_expr(
10491050
if expr is not None:
10501051
request.dsl = expr
10511052

1053+
if isinstance(ranker, Function):
1054+
request.function_score.CopyFrom(Prepare.ranker_to_function_score(ranker))
1055+
10521056
return request
10531057

10541058
@classmethod
10551059
def hybrid_search_request_with_ranker(
10561060
cls,
10571061
collection_name: str,
10581062
reqs: List,
1059-
rerank_param: Dict,
1063+
rerank: Union[Dict, Function],
10601064
limit: int,
10611065
partition_names: Optional[List[str]] = None,
10621066
output_fields: Optional[List[str]] = None,
10631067
round_decimal: int = -1,
10641068
**kwargs,
10651069
) -> milvus_types.HybridSearchRequest:
10661070
use_default_consistency = ts_utils.construct_guarantee_ts(collection_name, kwargs)
1067-
rerank_param["limit"] = limit
1068-
rerank_param["round_decimal"] = round_decimal
1069-
rerank_param["offset"] = kwargs.get("offset", 0)
1071+
rerank_param = {}
1072+
if isinstance(rerank, Dict):
1073+
rerank_param = rerank
1074+
else:
1075+
rerank_param["limit"] = limit
1076+
rerank_param["round_decimal"] = round_decimal
1077+
rerank_param["offset"] = kwargs.get("offset", 0)
10701078

10711079
request = milvus_types.HybridSearchRequest(
10721080
collection_name=collection_name,
@@ -1121,7 +1129,26 @@ def hybrid_search_request_with_ranker(
11211129
]
11221130
)
11231131

1132+
if isinstance(rerank, Function):
1133+
request.function_score.CopyFrom(Prepare.ranker_to_function_score(rerank))
11241134
return request
1135+
1136+
@staticmethod
1137+
def ranker_to_function_score(ranker: Function) -> schema_types.FunctionScore:
1138+
function_score = schema_types.FunctionScore(
1139+
functions=[
1140+
schema_types.FunctionSchema(
1141+
name=ranker.name,
1142+
type=ranker.type,
1143+
description=ranker.description,
1144+
input_field_names=ranker.input_field_names,
1145+
)
1146+
],
1147+
)
1148+
for k, v in ranker.params.items():
1149+
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
1150+
function_score.functions[0].params.append(kv_pair)
1151+
return function_score
11251152

11261153
@classmethod
11271154
def create_alias_request(cls, collection_name: str, alias: str):

pymilvus/client/types.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class FunctionType(IntEnum):
121121
UNKNOWN = 0
122122
BM25 = 1
123123
TEXTEMBEDDING = 2
124-
124+
RERANK = 3
125125

126126
class RangeType(IntEnum):
127127
LT = 0 # less than

pymilvus/milvus_client/milvus_client.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
ServerVersionIncompatibleException,
2525
)
2626
from pymilvus.orm import utility
27-
from pymilvus.orm.collection import CollectionSchema, FieldSchema
27+
from pymilvus.orm.collection import CollectionSchema, FieldSchema, Function
2828
from pymilvus.orm.connections import connections
2929
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
3030
from pymilvus.orm.iterator import QueryIterator, SearchIterator
@@ -297,7 +297,7 @@ def hybrid_search(
297297
self,
298298
collection_name: str,
299299
reqs: List[AnnSearchRequest],
300-
ranker: BaseRanker,
300+
ranker: Union[BaseRanker, Function],
301301
limit: int = 10,
302302
output_fields: Optional[List[str]] = None,
303303
timeout: Optional[float] = None,
@@ -309,7 +309,7 @@ def hybrid_search(
309309
Args:
310310
collection_name(``string``): The name of collection.
311311
reqs (``List[AnnSearchRequest]``): The vector search requests.
312-
ranker (``BaseRanker``): The ranker for rearrange nummer of limit results.
312+
ranker (``Union[BaseRanker, Function]``): The ranker for rearrange nummer of limit results.
313313
limit (``int``): The max number of returned record, also known as `topk`.
314314
315315
partition_names (``List[str]``, optional): The names of partitions to search on.
@@ -375,6 +375,7 @@ def search(
375375
timeout: Optional[float] = None,
376376
partition_names: Optional[List[str]] = None,
377377
anns_field: Optional[str] = None,
378+
ranker: Optional["Function"] = None,
378379
**kwargs,
379380
) -> List[List[dict]]:
380381
"""Search for a query vector/vectors.
@@ -389,6 +390,7 @@ def search(
389390
output_fields (List[str], optional): List of which field values to return. If None
390391
specified, only primary fields including distances will be returned.
391392
search_params (dict, optional): The search params to use for the search.
393+
ranker (Function, optional): The ranker to use for the search.
392394
timeout (float, optional): Timeout to use, overides the client level assigned at init.
393395
Defaults to None.
394396
@@ -412,6 +414,7 @@ def search(
412414
partition_names=partition_names,
413415
expr_params=kwargs.pop("filter_params", {}),
414416
timeout=timeout,
417+
ranker=ranker,
415418
**kwargs,
416419
)
417420
except Exception as ex:

0 commit comments

Comments
 (0)