Skip to content

Commit

Permalink
MB-61889: support search with params
Browse files Browse the repository at this point in the history
- update knn search request syntax, to enable users to supply search-time parameters,
  using which they can control the latency v/s recall tradeoff.
- Supported parameters are
  + ivf_nprobe_pct
  + ivf_max_codes_pct
- These parameters will be applied to those segments which have faissIVFIndex

- In future, to support other faiss Index classes (like hnsw), the list of
  supported knn search params can be extended.
  • Loading branch information
moshaad7 committed Jul 22, 2024
1 parent 64ab008 commit 359a033
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion index/scorch/optimize_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (o *OptimizeVR) Finish() error {
for _, vr := range vrs {
// for each VR, populate postings list and iterators
// by passing the obtained vector index and getting similar vectors.
pl, err := vecIndex.Search(vr.vector, vr.k)
pl, err := vecIndex.Search(vr.vector, vr.k, vr.searchParams)
if err != nil {
errorsM.Lock()
errors = append(errors, err)
Expand Down
5 changes: 4 additions & 1 deletion index/scorch/snapshot_index_vr.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package scorch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"

Expand Down Expand Up @@ -48,6 +49,8 @@ type IndexSnapshotVectorReader struct {
currPosting segment_api.VecPosting
currID index.IndexInternalID
ctx context.Context

searchParams json.RawMessage
}

func (i *IndexSnapshotVectorReader) Size() int {
Expand Down Expand Up @@ -103,7 +106,7 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
preAlloced *index.VectorDoc) (*index.VectorDoc, error) {

if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k)
i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k, i.searchParams)
if err != nil {
return nil, err
}
Expand Down
12 changes: 7 additions & 5 deletions index/scorch/snapshot_vector_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ package scorch

import (
"context"
"encoding/json"

index "github.com/blevesearch/bleve_index_api"
segment_api "github.com/blevesearch/scorch_segment_api/v2"
)

func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32,
field string, k int64) (
field string, k int64, searchParams json.RawMessage) (
index.VectorReader, error) {

rv := &IndexSnapshotVectorReader{
vector: vector,
field: field,
k: k,
snapshot: is,
vector: vector,
field: field,
k: k,
snapshot: is,
searchParams: searchParams,
}

if rv.postings == nil {
Expand Down
10 changes: 9 additions & 1 deletion search/query/knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package query

import (
"context"
"encoding/json"
"fmt"

"github.com/blevesearch/bleve/v2/mapping"
Expand All @@ -32,6 +33,9 @@ type KNNQuery struct {
Vector []float32 `json:"vector"`
K int64 `json:"k"`
BoostVal *Boost `json:"boost,omitempty"`

// see KNNRequest.Params for description
Params json.RawMessage `json:"params"`
}

func NewKNNQuery(vector []float32) *KNNQuery {
Expand Down Expand Up @@ -59,6 +63,10 @@ func (q *KNNQuery) Boost() float64 {
return q.BoostVal.Value()
}

func (q *KNNQuery) SetParams(params json.RawMessage) {
q.Params = params
}

func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
m mapping.IndexMapping, options search.SearcherOptions) (search.Searcher, error) {
fieldMapping := m.FieldMappingForPath(q.VectorField)
Expand All @@ -70,5 +78,5 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty")
}
return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField,
q.Vector, q.K, q.BoostVal.Value(), similarityMetric)
q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params)
}
7 changes: 5 additions & 2 deletions search/searcher/search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package searcher

import (
"context"
"encoding/json"
"reflect"

"github.com/blevesearch/bleve/v2/mapping"
Expand Down Expand Up @@ -48,9 +49,11 @@ type KNNSearcher struct {

func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMapping,
options search.SearcherOptions, field string, vector []float32, k int64,
boost float64, similarityMetric string) (search.Searcher, error) {
boost float64, similarityMetric string, searchParams json.RawMessage) (
search.Searcher, error) {

if vr, ok := i.(index.VectorIndexReader); ok {
vectorReader, err := vr.VectorReader(ctx, vector, field, k)
vectorReader, err := vr.VectorReader(ctx, vector, field, k, searchParams)
if err != nil {
return nil, err
}
Expand Down
11 changes: 11 additions & 0 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ type KNNRequest struct {
VectorBase64 string `json:"vector_base64"`
K int64 `json:"k"`
Boost *query.Boost `json:"boost,omitempty"`

// Search parameters for the field's vector index part of the segment.
// Value of it depends on the field's backing vector index implementation.
//
// For Faiss IVF index, supported search params are:
// - ivf_nprobe_pct : int // percentage of total clusters to search
// - ivf_max_codes_pct : float // percentage of total vectors to visit to do a query (across all clusters)
//
// Consult go-faiss to know all supported search params
Params json.RawMessage `json:"params"`
}

func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost float64) {
Expand Down Expand Up @@ -214,6 +224,7 @@ func createKNNQuery(req *SearchRequest) (query.Query, []int64, int64, error) {
knnQuery.SetFieldVal(knn.Field)
knnQuery.SetK(knn.K)
knnQuery.SetBoost(knn.Boost.Value())
knnQuery.SetParams(knn.Params)
subQueries = append(subQueries, knnQuery)
kArray = append(kArray, knn.K)
sumOfK += knn.K
Expand Down

0 comments on commit 359a033

Please sign in to comment.