diff --git a/document/field_vector_base64.go b/document/field_vector_base64.go
new file mode 100644
index 000000000..e62dbe0a2
--- /dev/null
+++ b/document/field_vector_base64.go
@@ -0,0 +1,164 @@
+// Copyright (c) 2024 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package document
+
+import (
+	"encoding/base64"
+	"encoding/binary"
+	"fmt"
+	"math"
+
+	"github.com/blevesearch/bleve/v2/size"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+// VectorBase64Field is a vector field whose value was supplied as a base64
+// string. It keeps that string and delegates every accessor to the
+// VectorField built from the decoded vector.
+type VectorBase64Field struct {
+	vectorField    *VectorField
+	base64Encoding string
+}
+
+// Size returns the size reported by the wrapped VectorField.
+func (n *VectorBase64Field) Size() int {
+	return n.vectorField.Size()
+}
+
+// Name returns the name of the wrapped VectorField.
+func (n *VectorBase64Field) Name() string {
+	return n.vectorField.Name()
+}
+
+// ArrayPositions returns the array positions of the wrapped VectorField.
+func (n *VectorBase64Field) ArrayPositions() []uint64 {
+	return n.vectorField.ArrayPositions()
+}
+
+// Options returns the indexing options of the wrapped VectorField.
+func (n *VectorBase64Field) Options() index.FieldIndexingOptions {
+	return n.vectorField.Options()
+}
+
+// NumPlainTextBytes returns the value reported by the wrapped VectorField.
+func (n *VectorBase64Field) NumPlainTextBytes() uint64 {
+	return n.vectorField.NumPlainTextBytes()
+}
+
+// AnalyzedLength returns the value reported by the wrapped VectorField.
+func (n *VectorBase64Field) AnalyzedLength() int {
+	return n.vectorField.AnalyzedLength()
+}
+
+// EncodedFieldType returns 'e', the type byte of a VectorBase64Field.
+func (n *VectorBase64Field) EncodedFieldType() byte {
+	return 'e'
+}
+
+// AnalyzedTokenFrequencies returns the value reported by the wrapped
+// VectorField.
+func (n *VectorBase64Field) AnalyzedTokenFrequencies() index.TokenFrequencies {
+	return n.vectorField.AnalyzedTokenFrequencies()
+}
+
+// Analyze is a no-op for this field type.
+func (n *VectorBase64Field) Analyze() {
+}
+
+// Value returns the value reported by the wrapped VectorField.
+func (n *VectorBase64Field) Value() []byte {
+	return n.vectorField.Value()
+}
+
+// GoString returns a Go-syntax-like description of the field for debugging.
+func (n *VectorBase64Field) GoString() string {
+	return fmt.Sprintf("&document.VectorBase64Field{Name:%s, Options: %s, "+
+		"Value: %+v}", n.vectorField.Name(), n.vectorField.Options(), n.vectorField.Value())
+}
+
+// NewVectorBase64Field decodes vectorBase64 with DecodeVector and wraps the
+// resulting vector in a VectorField created with DefaultVectorIndexingOptions.
+// It returns the decoding error, if any.
+//
+// For the sake of not polluting the API, arrayPositions is kept as a
+// parameter; it is forwarded as-is to NewVectorFieldWithIndexingOptions.
+func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 string,
+	dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) {
+
+	vector, err := DecodeVector(vectorBase64)
+	if err != nil {
+		return nil, err
+	}
+
+	return &VectorBase64Field{
+		vectorField: NewVectorFieldWithIndexingOptions(name, arrayPositions,
+			vector, dims, similarity,
+			vectorIndexOptimizedFor, DefaultVectorIndexingOptions),
+
+		base64Encoding: vectorBase64,
+	}, nil
+}
+
+// DecodeVector decodes a standard-base64 string into a float32 vector. The
+// decoded bytes are read as consecutive little-endian 4-byte float32 values.
+// An error is returned when the input is not valid base64 or when the decoded
+// length is not a multiple of size.SizeOfFloat32.
+func DecodeVector(encodedValue string) ([]float32, error) {
+	// We first decode the encoded string into a byte array.
+	decoded, err := base64.StdEncoding.DecodeString(encodedValue)
+	if err != nil {
+		return nil, err
+	}
+
+	// The array is expected to be divisible by 4 because each float32
+	// should occupy 4 bytes.
+	if len(decoded)%size.SizeOfFloat32 != 0 {
+		return nil, fmt.Errorf("decoded byte array not divisible by %d", size.SizeOfFloat32)
+	}
+	dims := len(decoded) / size.SizeOfFloat32
+	decodedVector := make([]float32, dims)
+
+	// We iterate through the array 4 bytes at a time and convert each of
+	// them to a float32 value by reading them in a little endian notation.
+	for i := range decodedVector {
+		chunk := decoded[i*size.SizeOfFloat32 : (i+1)*size.SizeOfFloat32]
+		decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(chunk))
+	}
+
+	return decodedVector, nil
+}
+
+// Vector returns the vector reported by the wrapped VectorField.
+func (n *VectorBase64Field) Vector() []float32 {
+	return n.vectorField.Vector()
+}
+
+// Dims returns the dimensionality reported by the wrapped VectorField.
+func (n *VectorBase64Field) Dims() int {
+	return n.vectorField.Dims()
+}
+
+// Similarity returns the similarity reported by the wrapped VectorField.
+func (n *VectorBase64Field) Similarity() string {
+	return n.vectorField.Similarity()
+}
+
+// IndexOptimizedFor returns the value reported by the wrapped VectorField.
+func (n *VectorBase64Field) IndexOptimizedFor() string {
+	return n.vectorField.IndexOptimizedFor()
+}
diff --git a/document/field_vector_base64_test.go b/document/field_vector_base64_test.go
new file mode 100644
index 000000000..ac4bd8d4e
--- /dev/null
+++ b/document/field_vector_base64_test.go
@@ -0,0 +1,113 @@
+// Copyright (c) 2024 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package document
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/binary"
+	"fmt"
+	"math/rand"
+	"testing"
+)
+
+func TestDecodeVector(t *testing.T) {
+	vec := make([]float32, 2048)
+	for i := range vec {
+		vec[i] = rand.Float32()
+	}
+
+	vecBytes := bytifyVec(vec)
+	encodedVec := base64.StdEncoding.EncodeToString(vecBytes)
+
+	decodedVec, err := DecodeVector(encodedVec)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(decodedVec) != len(vec) {
+		t.Fatalf("Decoded vector dimensions not same as original vector dimensions")
+	}
+
+	for i := range vec {
+		if vec[i] != decodedVec[i] {
+			t.Errorf("Decoded vector not the same as original vector")
+		}
+	}
+}
+
+func BenchmarkDecodeVector128(b *testing.B) {
+	vec := make([]float32, 128)
+	for i := range vec {
+		vec[i] = rand.Float32()
+	}
+
+	vecBytes := bytifyVec(vec)
+	encodedVec := base64.StdEncoding.EncodeToString(vecBytes)
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_, _ = DecodeVector(encodedVec)
+	}
+}
+
+func BenchmarkDecodeVector784(b *testing.B) {
+	vec := make([]float32, 784)
+	for i := range vec {
+		vec[i] = rand.Float32()
+	}
+
+	vecBytes := bytifyVec(vec)
+	encodedVec := base64.StdEncoding.EncodeToString(vecBytes)
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_, _ = DecodeVector(encodedVec)
+	}
+}
+
+func BenchmarkDecodeVector1536(b *testing.B) {
+	vec := make([]float32, 1536)
+	for i := range vec {
+		vec[i] = rand.Float32()
+	}
+
+	vecBytes := bytifyVec(vec)
+	encodedVec := base64.StdEncoding.EncodeToString(vecBytes)
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_, _ = DecodeVector(encodedVec)
+	}
+}
+
+// bytifyVec serializes vec as consecutive little-endian float32 values.
+func bytifyVec(vec []float32) []byte {
+	buf := new(bytes.Buffer)
+
+	for _, v := range vec {
+		err := binary.Write(buf, binary.LittleEndian, v)
+		if err != nil {
+			fmt.Println(err)
+		}
+	}
+
+	return buf.Bytes()
+}
diff --git a/mapping/document.go b/mapping/document.go
index 73bb124db..3131f33bf 100644
--- a/mapping/document.go
+++ b/mapping/document.go
@@ -443,6 +443,8 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string,
 					fieldMapping.processGeoShape(property, pathString, path, indexes, context)
 				} else if fieldMapping.Type == "geopoint" {
 					fieldMapping.processGeoPoint(property, pathString, path, indexes, context)
+				} else if fieldMapping.Type == "vector_base64" {
+					fieldMapping.processVectorBase64(property, pathString, path, indexes, context)
 				} else {
 					fieldMapping.processString(propertyValueString, pathString, path, indexes, context)
 				}
diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go
index f9f35f57c..90cb1e225 100644
--- a/mapping/mapping_no_vectors.go
+++ b/mapping/mapping_no_vectors.go
@@ -21,11 +21,21 @@ func NewVectorFieldMapping() *FieldMapping {
 	return nil
 }
 
+// NewVectorBase64FieldMapping always returns nil in this no-vectors variant.
+func NewVectorBase64FieldMapping() *FieldMapping {
+	return nil
+}
+
 func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
 	pathString string, path []string, indexes []uint64, context *walkContext) bool {
 	return false
 }
 
+// processVectorBase64 is a no-op in this no-vectors variant.
+func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{},
+	pathString string, path []string, indexes []uint64, context *walkContext) {
+}
+
 // -----------------------------------------------------------------------------
 // document validation functions
 
diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go
index a0b712608..0ec7c0f9f 100644
--- a/mapping/mapping_vectors.go
+++ b/mapping/mapping_vectors.go
@@ -43,6 +43,20 @@ func NewVectorFieldMapping() *FieldMapping {
 	}
 }
 
+// NewVectorBase64FieldMapping returns a default field mapping of type
+// "vector_base64": indexed, not stored, not included in all, without doc
+// values and with frequency/norm tracking skipped.
+func NewVectorBase64FieldMapping() *FieldMapping {
+	return &FieldMapping{
+		Type:         "vector_base64",
+		Store:        false,
+		Index:        true,
+		IncludeInAll: false,
+		DocValues:    false,
+		SkipFreqNorm: true,
+	}
+}
+
 // validate and process a flat vector
 func processFlatVector(vecV reflect.Value, dims int) ([]float32, bool) {
 	if vecV.Len() != dims {
@@ -140,13 +154,32 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
 	return true
 }
 
+// processVectorBase64 handles a property mapped as "vector_base64". The
+// property must be a string that document.DecodeVector can decode; the decoded
+// vector is then handed to processVector. Any other input is silently
+// ignored.
+func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{},
+	pathString string, path []string, indexes []uint64, context *walkContext) {
+	encodedString, ok := propertyMightBeVectorBase64.(string)
+	if !ok {
+		return
+	}
+
+	propertyMightBeVector, err := document.DecodeVector(encodedString)
+	if err != nil {
+		return
+	}
+
+	fm.processVector(propertyMightBeVector, pathString, path, indexes, context)
+}
+
 // -----------------------------------------------------------------------------
 // document validation functions
 
 func validateFieldMapping(field *FieldMapping, parentName string,
 	fieldAliasCtx map[string]*FieldMapping) error {
 	switch field.Type {
-	case "vector":
+	case "vector", "vector_base64":
 		return validateVectorFieldAlias(field, parentName, fieldAliasCtx)
 	default: // non-vector field
 		return validateFieldType(field)
diff --git a/mapping_vector.go b/mapping_vector.go
index 594313861..c73dac9e5 100644
--- a/mapping_vector.go
+++ b/mapping_vector.go
@@ -22,3 +22,9 @@ import "github.com/blevesearch/bleve/v2/mapping"
 func NewVectorFieldMapping() *mapping.FieldMapping {
 	return mapping.NewVectorFieldMapping()
 }
+
+// NewVectorBase64FieldMapping returns the field mapping produced by
+// mapping.NewVectorBase64FieldMapping.
+func NewVectorBase64FieldMapping() *mapping.FieldMapping {
+	return mapping.NewVectorBase64FieldMapping()
+}
diff --git a/search_knn.go b/search_knn.go
index ba6b005f7..36f749706 100644
--- a/search_knn.go
+++ b/search_knn.go
@@ -23,6 +23,7 @@ import (
 	"fmt"
 	"sort"
 
+	"github.com/blevesearch/bleve/v2/document"
 	"github.com/blevesearch/bleve/v2/search"
 	"github.com/blevesearch/bleve/v2/search/collector"
 	"github.com/blevesearch/bleve/v2/search/query"
@@ -67,11 +68,14 @@ type SearchRequest struct {
 	sortFunc func(sort.Interface)
 }
 
+// KNNRequest describes a single kNN clause of a search request.
+// Vector takes precedence over VectorBase64 in case both fields are given.
 type KNNRequest struct {
-	Field  string       `json:"field"`
-	Vector []float32    `json:"vector"`
-	K      int64        `json:"k"`
-	Boost  *query.Boost `json:"boost,omitempty"`
+	Field        string       `json:"field"`
+	Vector       []float32    `json:"vector"`
+	VectorBase64 string       `json:"vector_base64"`
+	K            int64        `json:"k"`
+	Boost        *query.Boost `json:"boost,omitempty"`
 }
 
 func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost float64) {
@@ -231,6 +235,15 @@ func validateKNN(req *SearchRequest) error {
 		if q == nil {
 			return fmt.Errorf("knn query cannot be nil")
 		}
+		// Vector takes precedence; VectorBase64 is decoded only when no
+		// (or an empty) vector was supplied directly.
+		if q.VectorBase64 != "" && len(q.Vector) == 0 {
+			vec, err := document.DecodeVector(q.VectorBase64)
+			if err != nil {
+				return err
+			}
+			q.Vector = vec
+		}
 		if q.K <= 0 || len(q.Vector) == 0 {
 			return fmt.Errorf("k must be greater than 0 and vector must be non-empty")
 		}
diff --git a/search_knn_test.go b/search_knn_test.go
index b54ce5a93..c1629f427 100644
--- a/search_knn_test.go
+++ b/search_knn_test.go
@@ -19,6 +19,7 @@ package bleve
 
 import (
 	"archive/zip"
+	"encoding/base64"
 	"encoding/json"
 	"fmt"
 	"math"
@@ -397,6 +398,182 @@ func min(a, b int) int {
 	return b
 }
 
+func TestVectorBase64Index(t *testing.T) {
+	dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile)
+	if err != nil {
+		t.Fatal(err)
+	}
+	documents := makeDatasetIntoDocuments(dataset)
+
+	// Load an independent copy of the queries so that they can be pointed at
+	// the base64 field without touching the control queries.
+	_, searchRequestsCopy, err := readDatasetAndQueries(testInputCompressedFile)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// document.DecodeVector reads the decoded payload as consecutive
+	// little-endian float32 values, so the vector is packed that way before
+	// being base64 encoded.
+	// NOTE(review): assumes makeDatasetIntoDocuments stores "vector" as
+	// []float32 - confirm; the test fails loudly if it does not.
+	for _, doc := range documents {
+		vec, ok := doc["vector"].([]float32)
+		if !ok {
+			t.Fatalf("expected vector of type []float32, got %T", doc["vector"])
+		}
+		raw := make([]byte, 0, 4*len(vec))
+		for _, v := range vec {
+			bits := math.Float32bits(v)
+			raw = append(raw, byte(bits), byte(bits>>8), byte(bits>>16), byte(bits>>24))
+		}
+		doc["vectorEncoded"] = base64.StdEncoding.EncodeToString(raw)
+	}
+
+	for _, sr := range searchRequestsCopy {
+		for _, kr := range sr.KNN {
+			kr.Field = "vectorEncoded"
+		}
+	}
+
+	contentFM := NewTextFieldMapping()
+	contentFM.Analyzer = en.AnalyzerName
+
+	vecFML2 := mapping.NewVectorFieldMapping()
+	vecFML2.Dims = testDatasetDims
+	vecFML2.Similarity = index.EuclideanDistance
+
+	vecBFML2 := mapping.NewVectorBase64FieldMapping()
+	vecBFML2.Dims = testDatasetDims
+	vecBFML2.Similarity = index.EuclideanDistance
+
+	vecFMDot := mapping.NewVectorFieldMapping()
+	vecFMDot.Dims = testDatasetDims
+	vecFMDot.Similarity = index.CosineSimilarity
+
+	vecBFMDot := mapping.NewVectorBase64FieldMapping()
+	vecBFMDot.Dims = testDatasetDims
+	vecBFMDot.Similarity = index.CosineSimilarity
+
+	indexMappingL2 := NewIndexMapping()
+	indexMappingL2.DefaultMapping.AddFieldMappingsAt("content", contentFM)
+	indexMappingL2.DefaultMapping.AddFieldMappingsAt("vector", vecFML2)
+	indexMappingL2.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFML2)
+
+	indexMappingDot := NewIndexMapping()
+	indexMappingDot.DefaultMapping.AddFieldMappingsAt("content", contentFM)
+	indexMappingDot.DefaultMapping.AddFieldMappingsAt("vector", vecFMDot)
+	indexMappingDot.DefaultMapping.AddFieldMappingsAt("vectorEncoded", vecBFMDot)
+
+	tmpIndexPathL2 := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPathL2)
+
+	tmpIndexPathDot := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPathDot)
+
+	indexL2, err := New(tmpIndexPathL2, indexMappingL2)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err := indexL2.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	indexDot, err := New(tmpIndexPathDot, indexMappingDot)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err := indexDot.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batchL2 := indexL2.NewBatch()
+	batchDot := indexDot.NewBatch()
+
+	for _, doc := range documents {
+		err = batchL2.Index(doc["id"].(string), doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = batchDot.Index(doc["id"].(string), doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	err = indexL2.Batch(batchL2)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = indexDot.Batch(batchDot)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Every unmodified control query must return the same hits, in the same
+	// order, as its counterpart that targets the base64 field.
+	for i := range searchRequests {
+		for _, operator := range knnOperators {
+			controlQuery := searchRequests[i]
+			testQuery := searchRequestsCopy[i]
+
+			controlQuery.AddKNNOperator(operator)
+			testQuery.AddKNNOperator(operator)
+
+			controlResultL2, err := indexL2.Search(controlQuery)
+			if err != nil {
+				t.Fatal(err)
+			}
+			testResultL2, err := indexL2.Search(testQuery)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if controlResultL2 != nil && testResultL2 != nil {
+				if len(controlResultL2.Hits) != len(testResultL2.Hits) {
+					t.Fatalf("testcase %d failed: expected %d hits, got %d hits", i, len(controlResultL2.Hits), len(testResultL2.Hits))
+				}
+				for j := range controlResultL2.Hits {
+					if controlResultL2.Hits[j].ID != testResultL2.Hits[j].ID {
+						t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultL2.Hits[j].ID, testResultL2.Hits[j].ID)
+					}
+				}
+			} else if (controlResultL2 == nil) != (testResultL2 == nil) {
+				t.Fatalf("testcase %d failed: expected result %v, got result %v", i, controlResultL2, testResultL2)
+			}
+
+			controlResultDot, err := indexDot.Search(controlQuery)
+			if err != nil {
+				t.Fatal(err)
+			}
+			testResultDot, err := indexDot.Search(testQuery)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if controlResultDot != nil && testResultDot != nil {
+				if len(controlResultDot.Hits) != len(testResultDot.Hits) {
+					t.Fatalf("testcase %d failed: expected %d hits, got %d hits", i, len(controlResultDot.Hits), len(testResultDot.Hits))
+				}
+				for j := range controlResultDot.Hits {
+					if controlResultDot.Hits[j].ID != testResultDot.Hits[j].ID {
+						t.Fatalf("testcase %d failed: expected hit id %s, got hit id %s", i, controlResultDot.Hits[j].ID, testResultDot.Hits[j].ID)
+					}
+				}
+			} else if (controlResultDot == nil) != (testResultDot == nil) {
+				t.Fatalf("testcase %d failed: expected result %v, got result %v", i, controlResultDot, testResultDot)
+			}
+		}
+	}
+}
+
 type testDocument struct {
 	ID      string `json:"id"`
 	Content string `json:"content"`