diff --git a/index/scorch/empty_vec.go b/index/scorch/empty_vec.go new file mode 100644 index 000000000..d9d2199d9 --- /dev/null +++ b/index/scorch/empty_vec.go @@ -0,0 +1,30 @@ +//go:build vectors +// +build vectors + +package scorch + +import segment "github.com/blevesearch/scorch_segment_api/v2" + +type emptyVecPostingsIterator struct{} + +func (e *emptyVecPostingsIterator) Next() (segment.VecPosting, error) { + return nil, nil +} + +func (e *emptyVecPostingsIterator) Advance(uint64) (segment.VecPosting, error) { + return nil, nil +} + +func (e *emptyVecPostingsIterator) Size() int { + return 0 +} + +func (e *emptyVecPostingsIterator) BytesRead() uint64 { + return 0 +} + +func (e *emptyVecPostingsIterator) ResetBytesRead(uint64) {} + +func (e *emptyVecPostingsIterator) BytesWritten() uint64 { return 0 } + +var anemptyVecPostingsIterator = &emptyVecPostingsIterator{} diff --git a/index/scorch/optimize_knn.go b/index/scorch/optimize_knn.go index ca179574c..2774a9bd2 100644 --- a/index/scorch/optimize_knn.go +++ b/index/scorch/optimize_knn.go @@ -114,17 +114,22 @@ func (o *OptimizeVR) Finish() error { eligibleVectorInternalIDs.And(snapshotGlobalDocNums[index]) } - eligibleLocalDocNums := make([]uint64, - eligibleVectorInternalIDs.GetCardinality()) - // get the (segment-)local document numbers - for i, docNum := range eligibleVectorInternalIDs.ToArray() { - localDocNum := o.snapshot.localDocNumFromGlobal(index, - uint64(docNum)) - eligibleLocalDocNums[i] = localDocNum + if eligibleVectorInternalIDs.GetCardinality() > 0 { + eligibleLocalDocNums := make([]uint64, 0, + eligibleVectorInternalIDs.GetCardinality()) + // get the (segment-)local document numbers + for _, docNum := range eligibleVectorInternalIDs.ToArray() { + localDocNum := o.snapshot.localDocNumFromGlobal(index, + uint64(docNum)) + eligibleLocalDocNums = append(eligibleLocalDocNums, + localDocNum) + } + + if len(eligibleLocalDocNums) > 0 { + pl, err = vecIndex.SearchWithFilter(vr.vector, vr.k, + eligibleLocalDocNums, vr.searchParams) + } } - - pl, err = vecIndex.SearchWithFilter(vr.vector, vr.k, - eligibleLocalDocNums, vr.searchParams) } else { pl, err = vecIndex.Search(vr.vector, vr.k, vr.searchParams) } @@ -139,10 +144,14 @@ func (o *OptimizeVR) Finish() error { atomic.AddUint64(&o.snapshot.parent.stats.TotKNNSearches, uint64(1)) - // postings and iterators are already alloc'ed when - // IndexSnapshotVectorReader is created - vr.postings[index] = pl - vr.iterators[index] = pl.Iterator(vr.iterators[index]) + if pl != nil && pl.Count() > 0 { + // postings and iterators are already alloc'ed when + // IndexSnapshotVectorReader is created + vr.postings[index] = pl + vr.iterators[index] = pl.Iterator(vr.iterators[index]) + } else { + vr.iterators[index] = &emptyVecPostingsIterator{} + } } go vecIndex.Close() } diff --git a/index/scorch/snapshot_vector_index.go b/index/scorch/snapshot_vector_index.go index bcb05024d..51f5fb5b5 100644 --- a/index/scorch/snapshot_vector_index.go +++ b/index/scorch/snapshot_vector_index.go @@ -65,9 +65,13 @@ func (is *IndexSnapshot) VectorReaderWithFilter(ctx context.Context, vector []fl if rv.postings == nil { rv.postings = make([]segment_api.VecPostingsList, len(is.segment)) + } if rv.iterators == nil { rv.iterators = make([]segment_api.VecPostingsIterator, len(is.segment)) + for index := 0; index < len(is.segment); index++ { + rv.iterators[index] = anemptyVecPostingsIterator + } } // initialize postings and iterators within the OptimizeVR's Finish()