Skip to content

Commit a5ee014

Browse files
committed
Added tests for new default scorer
1 parent 64018bd commit a5ee014

File tree

3 files changed

+165
-2
lines changed

3 files changed

+165
-2
lines changed

redis/commands/search/query.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def scorer(self, scorer: str) -> "Query":
179179
Use a different scoring function to evaluate document relevance.
180180
Default is `TFIDF` on Redis versions before 8.0.
181181
182+
Since Redis 8.0, the default scorer is `BM25STD`.
183+
182184
:param scorer: The scoring function to use
183185
(e.g. `TFIDF.DOCNORM` or `BM25`)
184186
"""

tests/test_asyncio/test_search.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ async def test_client(decoded_r: redis.Redis):
341341

342342
@pytest.mark.redismod
343343
@pytest.mark.onlynoncluster
344+
@skip_if_server_version_gte("7.9.0")
344345
async def test_scores(decoded_r: redis.Redis):
345346
await decoded_r.ft().create_index((TextField("txt"),))
346347

@@ -361,6 +362,29 @@ async def test_scores(decoded_r: redis.Redis):
361362
assert "doc1" == res["results"][1]["id"]
362363

363364

365+
@pytest.mark.redismod
366+
@pytest.mark.onlynoncluster
367+
@skip_if_server_version_lt("7.9.0")
368+
async def test_scores_with_new_default_scorer(decoded_r: redis.Redis):
369+
await decoded_r.ft().create_index((TextField("txt"),))
370+
371+
await decoded_r.hset("doc1", mapping={"txt": "foo baz"})
372+
await decoded_r.hset("doc2", mapping={"txt": "foo bar"})
373+
374+
q = Query("foo ~bar").with_scores()
375+
res = await decoded_r.ft().search(q)
376+
if is_resp2_connection(decoded_r):
377+
assert 2 == res.total
378+
assert "doc2" == res.docs[0].id
379+
assert 0.87 == pytest.approx(res.docs[0].score, 0.01)
380+
assert "doc1" == res.docs[1].id
381+
else:
382+
assert 2 == res["total_results"]
383+
assert "doc2" == res["results"][0]["id"]
384+
assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01)
385+
assert "doc1" == res["results"][1]["id"]
386+
387+
364388
@pytest.mark.redismod
365389
async def test_stopwords(decoded_r: redis.Redis):
366390
stopwords = ["foo", "bar", "baz"]
@@ -1029,6 +1053,7 @@ async def test_phonetic_matcher(decoded_r: redis.Redis):
10291053
@pytest.mark.onlynoncluster
10301054
# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+
10311055
@skip_ifmodversion_lt("2.8.0", "search")
1056+
@skip_if_server_version_gte("7.9.0")
10321057
async def test_scorer(decoded_r: redis.Redis):
10331058
await decoded_r.ft().create_index((TextField("description"),))
10341059

@@ -1087,6 +1112,69 @@ async def test_scorer(decoded_r: redis.Redis):
10871112
assert 0.0 == res["results"][0]["score"]
10881113

10891114

1115+
@pytest.mark.redismod
1116+
@pytest.mark.onlynoncluster
1117+
# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+
1118+
@skip_ifmodversion_lt("2.8.0", "search")
1119+
@skip_if_server_version_lt("7.9.0")
1120+
async def test_scorer_with_new_default_scorer(decoded_r: redis.Redis):
1121+
await decoded_r.ft().create_index((TextField("description"),))
1122+
1123+
await decoded_r.hset(
1124+
"doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"}
1125+
)
1126+
await decoded_r.hset(
1127+
"doc2",
1128+
mapping={
1129+
"description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa
1130+
},
1131+
)
1132+
1133+
if is_resp2_connection(decoded_r):
1134+
# default scorer is BM25STD
1135+
res = await decoded_r.ft().search(Query("quick").with_scores())
1136+
assert 0.23 == pytest.approx(res.docs[0].score, 0.05)
1137+
res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores())
1138+
assert 1.0 == res.docs[0].score
1139+
res = await decoded_r.ft().search(
1140+
Query("quick").scorer("TFIDF.DOCNORM").with_scores()
1141+
)
1142+
assert 0.14285714285714285 == res.docs[0].score
1143+
res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores())
1144+
assert 0.22471909420069797 == res.docs[0].score
1145+
res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores())
1146+
assert 2.0 == res.docs[0].score
1147+
res = await decoded_r.ft().search(
1148+
Query("quick").scorer("DOCSCORE").with_scores()
1149+
)
1150+
assert 1.0 == res.docs[0].score
1151+
res = await decoded_r.ft().search(
1152+
Query("quick").scorer("HAMMING").with_scores()
1153+
)
1154+
assert 0.0 == res.docs[0].score
1155+
else:
1156+
res = await decoded_r.ft().search(Query("quick").with_scores())
1157+
assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05)
1158+
res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores())
1159+
assert 1.0 == res["results"][0]["score"]
1160+
res = await decoded_r.ft().search(
1161+
Query("quick").scorer("TFIDF.DOCNORM").with_scores()
1162+
)
1163+
assert 0.14285714285714285 == res["results"][0]["score"]
1164+
res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores())
1165+
assert 0.22471909420069797 == res["results"][0]["score"]
1166+
res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores())
1167+
assert 2.0 == res["results"][0]["score"]
1168+
res = await decoded_r.ft().search(
1169+
Query("quick").scorer("DOCSCORE").with_scores()
1170+
)
1171+
assert 1.0 == res["results"][0]["score"]
1172+
res = await decoded_r.ft().search(
1173+
Query("quick").scorer("HAMMING").with_scores()
1174+
)
1175+
assert 0.0 == res["results"][0]["score"]
1176+
1177+
10901178
@pytest.mark.redismod
10911179
async def test_get(decoded_r: redis.Redis):
10921180
await decoded_r.ft().create_index((TextField("f1"), TextField("f2")))

tests/test_search.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def test_client(client):
314314

315315
@pytest.mark.redismod
316316
@pytest.mark.onlynoncluster
317+
@skip_if_server_version_gte("7.9.0")
317318
def test_scores(client):
318319
client.ft().create_index((TextField("txt"),))
319320

@@ -334,6 +335,29 @@ def test_scores(client):
334335
assert "doc1" == res["results"][1]["id"]
335336

336337

338+
@pytest.mark.redismod
339+
@pytest.mark.onlynoncluster
340+
@skip_if_server_version_lt("7.9.0")
341+
def test_scores_with_new_default_scorer(client):
342+
client.ft().create_index((TextField("txt"),))
343+
344+
client.hset("doc1", mapping={"txt": "foo baz"})
345+
client.hset("doc2", mapping={"txt": "foo bar"})
346+
347+
q = Query("foo ~bar").with_scores()
348+
res = client.ft().search(q)
349+
if is_resp2_connection(client):
350+
assert 2 == res.total
351+
assert "doc2" == res.docs[0].id
352+
assert 0.87 == pytest.approx(res.docs[0].score, 0.01)
353+
assert "doc1" == res.docs[1].id
354+
else:
355+
assert 2 == res["total_results"]
356+
assert "doc2" == res["results"][0]["id"]
357+
assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01)
358+
assert "doc1" == res["results"][1]["id"]
359+
360+
337361
@pytest.mark.redismod
338362
def test_stopwords(client):
339363
client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"])
@@ -936,6 +960,7 @@ def test_phonetic_matcher(client):
936960
@pytest.mark.onlynoncluster
937961
# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+
938962
@skip_ifmodversion_lt("2.8.0", "search")
963+
@skip_if_server_version_gte("7.9.0")
939964
def test_scorer(client):
940965
client.ft().create_index((TextField("description"),))
941966

@@ -982,6 +1007,55 @@ def test_scorer(client):
9821007
assert 0.0 == res["results"][0]["score"]
9831008

9841009

1010+
@pytest.mark.redismod
1011+
@pytest.mark.onlynoncluster
1012+
@skip_if_server_version_lt("7.9.0")
1013+
def test_scorer_with_new_default_scorer(client):
1014+
client.ft().create_index((TextField("description"),))
1015+
1016+
client.hset(
1017+
"doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"}
1018+
)
1019+
client.hset(
1020+
"doc2",
1021+
mapping={
1022+
"description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa
1023+
},
1024+
)
1025+
1026+
# default scorer is BM25STD
1027+
if is_resp2_connection(client):
1028+
res = client.ft().search(Query("quick").with_scores())
1029+
assert 0.23 == pytest.approx(res.docs[0].score, 0.05)
1030+
res = client.ft().search(Query("quick").scorer("TFIDF").with_scores())
1031+
assert 1.0 == res.docs[0].score
1032+
res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores())
1033+
assert 0.14285714285714285 == res.docs[0].score
1034+
res = client.ft().search(Query("quick").scorer("BM25").with_scores())
1035+
assert 0.22471909420069797 == res.docs[0].score
1036+
res = client.ft().search(Query("quick").scorer("DISMAX").with_scores())
1037+
assert 2.0 == res.docs[0].score
1038+
res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores())
1039+
assert 1.0 == res.docs[0].score
1040+
res = client.ft().search(Query("quick").scorer("HAMMING").with_scores())
1041+
assert 0.0 == res.docs[0].score
1042+
else:
1043+
res = client.ft().search(Query("quick").with_scores())
1044+
assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05)
1045+
res = client.ft().search(Query("quick").scorer("TFIDF").with_scores())
1046+
assert 1.0 == res["results"][0]["score"]
1047+
res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores())
1048+
assert 0.14285714285714285 == res["results"][0]["score"]
1049+
res = client.ft().search(Query("quick").scorer("BM25").with_scores())
1050+
assert 0.22471909420069797 == res["results"][0]["score"]
1051+
res = client.ft().search(Query("quick").scorer("DISMAX").with_scores())
1052+
assert 2.0 == res["results"][0]["score"]
1053+
res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores())
1054+
assert 1.0 == res["results"][0]["score"]
1055+
res = client.ft().search(Query("quick").scorer("HAMMING").with_scores())
1056+
assert 0.0 == res["results"][0]["score"]
1057+
1058+
9851059
@pytest.mark.redismod
9861060
def test_get(client):
9871061
client.ft().create_index((TextField("f1"), TextField("f2")))
@@ -2605,9 +2679,8 @@ def test_search_missing_fields(client):
26052679
},
26062680
)
26072681

2608-
with pytest.raises(redis.exceptions.ResponseError) as e:
2682+
with pytest.raises(redis.exceptions.ResponseError):
26092683
client.ft().search(Query("ismissing(@title)").return_field("id").no_content())
2610-
assert "to be defined with 'INDEXMISSING'" in e.value.args[0]
26112684

26122685
res = client.ft().search(
26132686
Query("ismissing(@features)").return_field("id").no_content()

0 commit comments

Comments
 (0)