Skip to content

Commit

Permalink
MB-59616: Adding vector_base64 field (#2012)
Browse files Browse the repository at this point in the history
- Added a new field type called vector_base64.
 - Acts similar to vector in most cases.
- When a new document arrives in the bleve layer, during the parsing of
all its fields in processProperty, if the field mapping type is
vector-base64, then its value is decoded into a vector field and
processed like a vector.
 - The standard golang base64 library is used for the decode operation.

---------

Co-authored-by: Abhinav Dangeti <abhinav@couchbase.com>
  • Loading branch information
Likith101 and abhinavdangeti authored Apr 18, 2024
1 parent 6d02ec6 commit 757705e
Show file tree
Hide file tree
Showing 8 changed files with 475 additions and 5 deletions.
140 changes: 140 additions & 0 deletions document/field_vector_base64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2024 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package document

import (
"encoding/base64"
"encoding/binary"
"fmt"
"math"

"github.com/blevesearch/bleve/v2/size"
index "github.com/blevesearch/bleve_index_api"
)

type VectorBase64Field struct {
vectorField *VectorField
base64Encoding string
}

func (n *VectorBase64Field) Size() int {
return n.vectorField.Size()
}

func (n *VectorBase64Field) Name() string {
return n.vectorField.Name()
}

func (n *VectorBase64Field) ArrayPositions() []uint64 {
return n.vectorField.ArrayPositions()
}

func (n *VectorBase64Field) Options() index.FieldIndexingOptions {
return n.vectorField.Options()
}

func (n *VectorBase64Field) NumPlainTextBytes() uint64 {
return n.vectorField.NumPlainTextBytes()
}

func (n *VectorBase64Field) AnalyzedLength() int {
return n.vectorField.AnalyzedLength()
}

func (n *VectorBase64Field) EncodedFieldType() byte {
return 'e'
}

func (n *VectorBase64Field) AnalyzedTokenFrequencies() index.TokenFrequencies {
return n.vectorField.AnalyzedTokenFrequencies()
}

func (n *VectorBase64Field) Analyze() {
}

func (n *VectorBase64Field) Value() []byte {
return n.vectorField.Value()
}

func (n *VectorBase64Field) GoString() string {
return fmt.Sprintf("&document.vectorFieldBase64Field{Name:%s, Options: %s, "+
"Value: %+v}", n.vectorField.Name(), n.vectorField.Options(), n.vectorField.Value())
}

// For the sake of not polluting the API, we are keeping arrayPositions as a
// parameter, but it is not used.
func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 string,
dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) {

vector, err := DecodeVector(vectorBase64)
if err != nil {
return nil, err
}

return &VectorBase64Field{
vectorField: NewVectorFieldWithIndexingOptions(name, arrayPositions,
vector, dims, similarity,
vectorIndexOptimizedFor, DefaultVectorIndexingOptions),

base64Encoding: vectorBase64,
}, nil
}

// This function takes a base64 encoded string and decodes it into
// a vector.
func DecodeVector(encodedValue string) ([]float32, error) {

// We first decode the encoded string into a byte array.
decodedString, err := base64.StdEncoding.DecodeString(encodedValue)
if err != nil {
return nil, err
}

// The array is expected to be divisible by 4 because each float32
// should occupy 4 bytes
if len(decodedString)%size.SizeOfFloat32 != 0 {
return nil, fmt.Errorf("Decoded byte array not divisible by %d", size.SizeOfFloat32)
}
dims := int(len(decodedString) / size.SizeOfFloat32)
decodedVector := make([]float32, dims)

// We iterate through the array 4 bytes at a time and convert each of
// them to a float32 value by reading them in a little endian notation
for i := 0; i < dims; i++ {
bytes := decodedString[i*size.SizeOfFloat32 : (i+1)*size.SizeOfFloat32]
decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
}

return decodedVector, nil
}

func (n *VectorBase64Field) Vector() []float32 {
return n.vectorField.Vector()
}

func (n *VectorBase64Field) Dims() int {
return n.vectorField.Dims()
}

func (n *VectorBase64Field) Similarity() string {
return n.vectorField.Similarity()
}

func (n *VectorBase64Field) IndexOptimizedFor() string {
return n.vectorField.IndexOptimizedFor()
}
113 changes: 113 additions & 0 deletions document/field_vector_base64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2024 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package document

import (
"bytes"
"encoding/base64"
"encoding/binary"
"fmt"
"math/rand"
"testing"
)

func TestDecodeVector(t *testing.T) {
vec := make([]float32, 2048)
for i := range vec {
vec[i] = rand.Float32()
}

vecBytes := bytifyVec(vec)
encodedVec := base64.StdEncoding.EncodeToString(vecBytes)

decodedVec, err := DecodeVector(encodedVec)
if err != nil {
t.Error(err)
}
if len(decodedVec) != len(vec) {
t.Errorf("Decoded vector dimensions not same as original vector dimensions")
}

for i := range vec {
if vec[i] != decodedVec[i] {
t.Errorf("Decoded vector not the same as original vector")
}
}
}

func BenchmarkDecodeVector128(b *testing.B) {
vec := make([]float32, 128)
for i := range vec {
vec[i] = rand.Float32()
}

vecBytes := bytifyVec(vec)
encodedVec := base64.StdEncoding.EncodeToString(vecBytes)

b.ResetTimer()

for i := 0; i < b.N; i++ {
_, _ = DecodeVector(encodedVec)
}
}

func BenchmarkDecodeVector784(b *testing.B) {
vec := make([]float32, 784)
for i := range vec {
vec[i] = rand.Float32()
}

vecBytes := bytifyVec(vec)
encodedVec := base64.StdEncoding.EncodeToString(vecBytes)

b.ResetTimer()

for i := 0; i < b.N; i++ {
_, _ = DecodeVector(encodedVec)
}
}

func BenchmarkDecodeVector1536(b *testing.B) {
vec := make([]float32, 1536)
for i := range vec {
vec[i] = rand.Float32()
}

vecBytes := bytifyVec(vec)
encodedVec := base64.StdEncoding.EncodeToString(vecBytes)

b.ResetTimer()

for i := 0; i < b.N; i++ {
_, _ = DecodeVector(encodedVec)
}
}

func bytifyVec(vec []float32) []byte {

buf := new(bytes.Buffer)

for _, v := range vec {
err := binary.Write(buf, binary.LittleEndian, v)
if err != nil {
fmt.Println(err)
}
}

return buf.Bytes()
}
2 changes: 2 additions & 0 deletions mapping/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string,
fieldMapping.processGeoShape(property, pathString, path, indexes, context)
} else if fieldMapping.Type == "geopoint" {
fieldMapping.processGeoPoint(property, pathString, path, indexes, context)
} else if fieldMapping.Type == "vector_base64" {
fieldMapping.processVectorBase64(property, pathString, path, indexes, context)
} else {
fieldMapping.processString(propertyValueString, pathString, path, indexes, context)
}
Expand Down
9 changes: 9 additions & 0 deletions mapping/mapping_no_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,20 @@ func NewVectorFieldMapping() *FieldMapping {
return nil
}

func NewVectorBase64FieldMapping() *FieldMapping {
return nil
}

func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
pathString string, path []string, indexes []uint64, context *walkContext) bool {
return false
}

func (fm *FieldMapping) processVectorBase64(propertyMightBeVector interface{},
pathString string, path []string, indexes []uint64, context *walkContext) {

}

// -----------------------------------------------------------------------------
// document validation functions

Expand Down
28 changes: 27 additions & 1 deletion mapping/mapping_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ func NewVectorFieldMapping() *FieldMapping {
}
}

func NewVectorBase64FieldMapping() *FieldMapping {
return &FieldMapping{
Type: "vector_base64",
Store: false,
Index: true,
IncludeInAll: false,
DocValues: false,
SkipFreqNorm: true,
}
}

// validate and process a flat vector
func processFlatVector(vecV reflect.Value, dims int) ([]float32, bool) {
if vecV.Len() != dims {
Expand Down Expand Up @@ -140,13 +151,28 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
return true
}

func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interface{},
pathString string, path []string, indexes []uint64, context *walkContext) {
encodedString, ok := propertyMightBeVectorBase64.(string)
if !ok {
return
}

propertyMightBeVector, err := document.DecodeVector(encodedString)
if err != nil {
return
}

fm.processVector(propertyMightBeVector, pathString, path, indexes, context)
}

// -----------------------------------------------------------------------------
// document validation functions

func validateFieldMapping(field *FieldMapping, parentName string,
fieldAliasCtx map[string]*FieldMapping) error {
switch field.Type {
case "vector":
case "vector", "vector_base64":
return validateVectorFieldAlias(field, parentName, fieldAliasCtx)
default: // non-vector field
return validateFieldType(field)
Expand Down
4 changes: 4 additions & 0 deletions mapping_vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ import "github.com/blevesearch/bleve/v2/mapping"
func NewVectorFieldMapping() *mapping.FieldMapping {
return mapping.NewVectorFieldMapping()
}

func NewVectorBase64FieldMapping() *mapping.FieldMapping {
return mapping.NewVectorBase64FieldMapping()
}
21 changes: 17 additions & 4 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"sort"

"github.com/blevesearch/bleve/v2/document"
"github.com/blevesearch/bleve/v2/search"
"github.com/blevesearch/bleve/v2/search/collector"
"github.com/blevesearch/bleve/v2/search/query"
Expand Down Expand Up @@ -67,11 +68,13 @@ type SearchRequest struct {
sortFunc func(sort.Interface)
}

// Vector takes precedence over vectorBase64 in case both fields are given
type KNNRequest struct {
Field string `json:"field"`
Vector []float32 `json:"vector"`
K int64 `json:"k"`
Boost *query.Boost `json:"boost,omitempty"`
Field string `json:"field"`
Vector []float32 `json:"vector"`
VectorBase64 string `json:"vector_base64"`
K int64 `json:"k"`
Boost *query.Boost `json:"boost,omitempty"`
}

func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost float64) {
Expand Down Expand Up @@ -231,6 +234,16 @@ func validateKNN(req *SearchRequest) error {
if q == nil {
return fmt.Errorf("knn query cannot be nil")
}
if q.VectorBase64 != "" {
if q.Vector == nil {
vec, err := document.DecodeVector(q.VectorBase64)
if err != nil {
return err
}

q.Vector = vec
}
}
if q.K <= 0 || len(q.Vector) == 0 {
return fmt.Errorf("k must be greater than 0 and vector must be non-empty")
}
Expand Down
Loading

0 comments on commit 757705e

Please sign in to comment.