Skip to content

Commit

Permalink
GODRIVER-3470 Correct BSON unmarshaling logic for null values (#1924) (
Browse files Browse the repository at this point in the history
…#1955)

Co-authored-by: Preston Vasquez <prestonvasquez@icloud.com>
  • Loading branch information
matthewdale and prestonvasquez authored Feb 21, 2025
1 parent 4aeff60 commit d19dc26
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 0 deletions.
12 changes: 12 additions & 0 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,18 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}

// If BSON value is null and the go value is a pointer, then don't call
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
// non-nil), encountering null in BSON will result in the pointer being
// directly set to nil here. Since the pointer is being replaced with nil,
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
// to be called.
if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr {
val.Set(reflect.Zero(val.Type()))

return vr.ReadNull()
}

if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
Expand Down
3 changes: 3 additions & 0 deletions bson/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type ValueUnmarshaler interface {
// Unmarshal parses the BSON-encoded data and stores the result in the value
// pointed to by val. If val is nil or not a pointer, Unmarshal returns
// InvalidUnmarshalError.
//
// When unmarshaling BSON, if the BSON value is null and the Go value is a
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
func Unmarshal(data []byte, val interface{}) error {
return UnmarshalWithRegistry(DefaultRegistry, data, val)
}
Expand Down
24 changes: 24 additions & 0 deletions bson/unmarshal_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

Expand Down Expand Up @@ -93,6 +94,29 @@ func TestUnmarshalValue(t *testing.T) {
})
}

func TestInitializedPointerDataWithBSONNull(t *testing.T) {
// Set up the test case with initialized pointers.
tc := unmarshalBehaviorTestCase{
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{},
BSONPtrTracker: &unmarshalBSONCallTracker{},
}

// Create BSON data where the '*_ptr_tracker' fields are explicitly set to
// null.
bytes := docToBytes(D{
{Key: "bv_ptr_tracker", Value: nil},
{Key: "b_ptr_tracker", Value: nil},
})

// Unmarshal the BSON data into the test case struct. This should set the
// pointer fields to nil due to the BSON null value.
err := Unmarshal(bytes, &tc)
require.NoError(t, err)

assert.Nil(t, tc.BSONValuePtrTracker)
assert.Nil(t, tc.BSONPtrTracker)
}

// tests covering GODRIVER-2779
func BenchmarkSliceCodecUnmarshal(b *testing.B) {
benchmarks := []struct {
Expand Down
101 changes: 101 additions & 0 deletions bson/unmarshaling_cases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)

type unmarshalingTestCase struct {
Expand Down Expand Up @@ -114,6 +115,26 @@ func unmarshalingTestCases() []unmarshalingTestCase {
},
data: docToBytes(D{{"fooBar", int32(10)}}),
},
{
name: "nil pointer and non-pointer type with literal null BSON",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: nil,
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: nil},
{Key: "bv_ptr_tracker", Value: nil},
{Key: "b_tracker", Value: nil},
{Key: "b_ptr_tracker", Value: nil},
}),
},
// GODRIVER-2252
// Test that a struct of pointer types with UnmarshalBSON functions defined marshal and
// unmarshal to the same Go values when the pointer values are "nil".
Expand Down Expand Up @@ -174,6 +195,50 @@ func unmarshalingTestCases() []unmarshalingTestCase {
want: &valNonPtrStruct,
data: docToBytes(valNonPtrStruct),
},
{
name: "nil pointer and non-pointer type with BSON minkey",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
called: true,
},
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: primitive.MinKey{}},
{Key: "bv_ptr_tracker", Value: primitive.MinKey{}},
{Key: "b_tracker", Value: primitive.MinKey{}},
{Key: "b_ptr_tracker", Value: primitive.MinKey{}},
}),
},
{
name: "nil pointer and non-pointer type with BSON maxkey",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
called: true,
},
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: primitive.MaxKey{}},
{Key: "bv_ptr_tracker", Value: primitive.MaxKey{}},
{Key: "b_tracker", Value: primitive.MaxKey{}},
{Key: "b_ptr_tracker", Value: primitive.MaxKey{}},
}),
},
}
}

Expand Down Expand Up @@ -250,3 +315,39 @@ func (ms *myString) UnmarshalBSON(bytes []byte) error {
*ms = myString(s)
return nil
}

// unmarshalBSONValueCallTracker is a test struct that tracks whether the
// UnmarshalBSONValue method has been called.
type unmarshalBSONValueCallTracker struct {
called bool // called is set to true when UnmarshalBSONValue is invoked.
}

var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{}

// unmarshalBSONCallTracker is a test struct that tracks whether the
// UnmarshalBSON method has been called.
type unmarshalBSONCallTracker struct {
called bool // called is set to true when UnmarshalBSON is invoked.
}

// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface.
var _ Unmarshaler = &unmarshalBSONCallTracker{}

// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON
// unmarshaling behavior.
type unmarshalBehaviorTestCase struct {
BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value.
BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer.
BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value.
BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer.
}

func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(bsontype.Type, []byte) error {
tracker.called = true
return nil
}

func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error {
tracker.called = true
return nil
}

0 comments on commit d19dc26

Please sign in to comment.