|
1 | 1 | import abc
|
2 |
| -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
| 2 | +from typing import Any, Dict, List, Optional, Tuple, Union |
3 | 3 |
|
4 | 4 | import ujson
|
5 | 5 |
|
@@ -351,19 +351,6 @@ def _pack(self, raw: Any):
|
351 | 351 | )
|
352 | 352 |
|
353 | 353 |
|
354 |
| -class SequenceIterator: |
355 |
| - def __init__(self, seq: Sequence[Any]): |
356 |
| - self._seq = seq |
357 |
| - self._idx = 0 |
358 |
| - |
359 |
| - def __next__(self) -> Any: |
360 |
| - if self._idx < len(self._seq): |
361 |
| - res = self._seq[self._idx] |
362 |
| - self._idx += 1 |
363 |
| - return res |
364 |
| - raise StopIteration |
365 |
| - |
366 |
| - |
367 | 354 | class BaseRanker:
|
368 | 355 | def __int__(self):
|
369 | 356 | return
|
@@ -394,16 +381,22 @@ def dict(self):
|
394 | 381 |
|
395 | 382 |
|
396 | 383 | class WeightedRanker(BaseRanker):
|
397 |
| - def __init__(self, *nums): |
| 384 | + def __init__(self, *nums, norm_score: bool = True): |
398 | 385 | self._strategy = RANKER_TYPE_WEIGHTED
|
399 | 386 | weights = []
|
400 | 387 | for num in nums:
|
| 388 | + # isinstance(True, int) is True, thus we need to check bool first |
| 389 | + if isinstance(num, bool) or not isinstance(num, (int, float)): |
| 390 | + error_msg = f"Weight must be a number, got {type(num)}" |
| 391 | + raise TypeError(error_msg) |
401 | 392 | weights.append(num)
|
402 | 393 | self._weights = weights
|
| 394 | + self._norm_score = norm_score |
403 | 395 |
|
404 | 396 | def dict(self):
|
405 | 397 | params = {
|
406 | 398 | "weights": self._weights,
|
| 399 | + "norm_score": self._norm_score, |
407 | 400 | }
|
408 | 401 | return {
|
409 | 402 | "strategy": self._strategy,
|
@@ -660,9 +653,6 @@ def get_fields_by_range(
|
660 | 653 | continue
|
661 | 654 | return field2data
|
662 | 655 |
|
663 |
| - def __iter__(self) -> SequenceIterator: |
664 |
| - return SequenceIterator(self) |
665 |
| - |
666 | 656 | def __str__(self) -> str:
|
667 | 657 | """Only print at most 10 query results"""
|
668 | 658 | reminder = f" ... and {len(self) - 10} results remaining" if len(self) > 10 else ""
|
@@ -743,9 +733,6 @@ def __init__(
|
743 | 733 |
|
744 | 734 | super().__init__(hits)
|
745 | 735 |
|
746 |
| - def __iter__(self) -> SequenceIterator: |
747 |
| - return SequenceIterator(self) |
748 |
| - |
749 | 736 | def __str__(self) -> str:
|
750 | 737 | """Only print at most 10 query results"""
|
751 | 738 | reminder = f" ... and {len(self) - 10} entities remaining" if len(self) > 10 else ""
|
|
0 commit comments