diff --git a/.helper_test.go.swp b/.helper_test.go.swp new file mode 100644 index 0000000..5048f30 Binary files /dev/null and b/.helper_test.go.swp differ diff --git a/.models_test.go.swp b/.models_test.go.swp new file mode 100644 index 0000000..de5f4e7 Binary files /dev/null and b/.models_test.go.swp differ diff --git a/.request_test.go.swp b/.request_test.go.swp new file mode 100644 index 0000000..d43f40f Binary files /dev/null and b/.request_test.go.swp differ diff --git a/constants.go b/constants.go index 23288d3..9940c33 100644 --- a/constants.go +++ b/constants.go @@ -10,6 +10,7 @@ const ( annotationOmitEmpty = "omitempty" annotationISO8601 = "iso8601" annotationSeperator = "," + annotationExtend = "extend" iso8601TimeFormat = "2006-01-02T15:04:05Z" diff --git a/helper.go b/helper.go new file mode 100644 index 0000000..5d92f2b --- /dev/null +++ b/helper.go @@ -0,0 +1,65 @@ +package jsonapi + +import ( + "reflect" + "strings" +) + +type structExtractedField struct { + reflect.Value + reflect.Kind + Annotation string + Args []string + IsPtr bool +} + +func extractFields(model reflect.Value) ([]structExtractedField, error) { + modelValue := model.Elem() + modelType := model.Type().Elem() + + var fields [](structExtractedField) + + for i := 0; i < modelValue.NumField(); i++ { + structField := modelValue.Type().Field(i) + fieldValue := modelValue.Field(i) + fieldType := modelType.Field(i) + + tag := structField.Tag.Get(annotationJSONAPI) + if tag == "" { + continue + } + if tag == annotationExtend && fieldType.Anonymous { + extendedFields, er := extractFields(modelValue.Field(i).Addr()) + if er != nil { + return nil, er + } + fields = append(fields, extendedFields...) + continue + } + + args := strings.Split(tag, annotationSeperator) + + if len(args) < 1 { + return nil, ErrBadJSONAPIStructTag + } + + annotation, args := args[0], args[1:] + + if (annotation == annotationClientID && len(args) != 0) || + (annotation != annotationClientID && len(args) < 1) { + return nil, ErrBadJSONAPIStructTag + } + + // Deal with PTRS + kind := fieldValue.Kind() + isPtr := fieldValue.Kind() == reflect.Ptr + if isPtr { + kind = fieldType.Type.Elem().Kind() + } + + field := structExtractedField{Value: fieldValue, Kind: kind, Annotation: annotation, Args: args, IsPtr: isPtr} + fields = append(fields, field) + } + + return fields, nil +} diff --git a/helper_test.go b/helper_test.go new file mode 100644 index 0000000..1095223 --- /dev/null +++ b/helper_test.go @@ -0,0 +1,187 @@ +package jsonapi + +import ( + "reflect" + "testing" + "time" + "unsafe" +) + +func TestHelper_BadPrimaryAnnotation(t *testing.T) { + fields, err := extractFields(reflect.ValueOf(new(BadModel))) + + if fields != nil { + t.Fatalf("Was expecting results to be nil") + } + + if expected, actual := ErrBadJSONAPIStructTag, err; expected != actual { + t.Fatalf("Was expecting error to be `%s`, got `%s`", expected, actual) + } +} + +func TestHelper_BadExtendedAnonymousField(t *testing.T) { + fields, err := extractFields(reflect.ValueOf(new(WithBadExtendedAnonymousField))) + + if fields != nil { + t.Fatalf("Was expecting results to be nil") + } + + if expected, actual := ErrBadJSONAPIStructTag, err; expected != actual { + t.Fatalf("Was expecting error to be `%s`, got `%s`", expected, actual) + } +} + +func TestHelper_returnsProperValue(t *testing.T) { + comment := &Comment{} + fields, err := extractFields(reflect.ValueOf(comment)) + + if err != nil { + t.Fatalf("Was expecting error to be nil, got `%s`", err) + } + + if expected, actual := 4, len(fields); expected != actual { + t.Fatalf("Was expecting fields to have `%d` items, got `%d`", expected, actual) + } + + // Check Annotation value + if expected, actual := "primary", fields[0].Annotation; expected != actual { + t.Fatalf("Was expecting fields[0].Annotation to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := "client-id", fields[1].Annotation; expected != actual { + t.Fatalf("Was expecting fields[1].Annotation to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := "attr", fields[2].Annotation; expected != actual { + t.Fatalf("Was expecting fields[2].Annotation to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := "attr", fields[3].Annotation; expected != actual { + t.Fatalf("Was expecting fields[3].Annotation to be `%s`, got `%s`", expected, actual) + } + + // Check Args value + if expected, actual := []string{"comments"}, fields[0].Args; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[0].Args to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := []string{}, fields[1].Args; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[1].Args to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := []string{"post_id"}, fields[2].Args; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[2].Args to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := []string{"body"}, fields[3].Args; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[3].Args to be `%s`, got `%s`", expected, actual) + } + + // Check IsPtr + if expected, actual := false, fields[0].IsPtr; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[0].IsPtr to be `%t`, got `%t`", expected, actual) + } + + if expected, actual := false, fields[1].IsPtr; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[1].IsPtr to be `%t`, got `%t`", expected, actual) + } + + if expected, actual := false, fields[2].IsPtr; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[2].IsPtr to be `%t`, got `%t`", expected, actual) + } + + if expected, actual := false, fields[3].IsPtr; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[3].IsPtr to be `%t`, got `%t`", expected, actual) + } + + // Check Value value + if uintptr(unsafe.Pointer(&comment.ID)) != fields[0].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[0].Value to point to comment.ID") + } + + if uintptr(unsafe.Pointer(&comment.ClientID)) != fields[1].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[1].Value to point to comment.ClientID") + } + + if uintptr(unsafe.Pointer(&comment.PostID)) != fields[2].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[2].Value to point to comment.PostID") + } + + if uintptr(unsafe.Pointer(&comment.Body)) != fields[3].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[3].Value to point to comment.Body") + } + + // Check Kind value + if expected, actual := reflect.Int, fields[0].Kind; expected != actual { + t.Fatalf("Was expecting fields[0].Kind to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := reflect.String, fields[1].Kind; expected != actual { + t.Fatalf("Was expecting fields[1].Kind to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := reflect.Int, fields[2].Kind; expected != actual { + t.Fatalf("Was expecting fields[2].Kind to be `%s`, got `%s`", expected, actual) + } + + if expected, actual := reflect.String, fields[3].Kind; expected != actual { + t.Fatalf("Was expecting fields[3].Kind to be `%s`, got `%s`", expected, actual) + } + +} + +func TestHelper_ignoreFieldWithoutAnnotation(t *testing.T) { + book := &Book{ + ID: 0, + Author: "aren55555", + PublishedAt: time.Now().AddDate(0, -1, 0), + } + fields, err := extractFields(reflect.ValueOf(book)) + if err != nil { + t.Fatalf("Was expecting error to be nil, got `%s`", err) + } + + if expected, actual := 7, len(fields); expected != actual { + t.Fatalf("Was expecting fields to have `%d` items, got `%d`", expected, actual) + } +} + +func TestHelper_WithExtendedAnonymousField(t *testing.T) { + model := &WithExtendedAnonymousField{} + fields, err := extractFields(reflect.ValueOf(model)) + if err != nil { + t.Fatalf("Was expecting error to be nil, got `%s`", err) + } + + if expected, actual := 2, len(fields); expected != actual { + t.Fatalf("Was expecting fields to have `%d` items, got `%d`", expected, actual) + } + + if uintptr(unsafe.Pointer(&model.CommonField)) != fields[0].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[0].Value to point to comment.CommonField") + } + + if uintptr(unsafe.Pointer(&model.ID)) != fields[1].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[1].Value to point to comment.ID") + } +} + +func TestHelper_WithPointer(t *testing.T) { + model := &WithPointer{} + fields, err := extractFields(reflect.ValueOf(model)) + if err != nil { + t.Fatalf("Was expecting error to be nil, got `%s`", err) + } + + if expected, actual := 5, len(fields); expected != actual { + t.Fatalf("Was expecting fields to have `%d` items, got `%d`", expected, actual) + } + + if expected, actual := true, fields[0].IsPtr; !reflect.DeepEqual(expected, actual) { + t.Fatalf("Was expecting fields[0].IsPtr to be `%t`, got `%t`", expected, actual) + } + + if uintptr(unsafe.Pointer(&model.ID)) != fields[0].Value.UnsafeAddr() { + t.Fatalf("Was expecting fields[0].Value to point to comment.ID") + } +} diff --git a/models_test.go b/models_test.go index a53dd61..7a728e0 100644 --- a/models_test.go +++ b/models_test.go @@ -25,6 +25,24 @@ type WithPointer struct { FloatVal *float32 `jsonapi:"attr,float-val"` } +type base struct { + CommonField string `jsonapi:"attr,common_field"` +} + +type WithExtendedAnonymousField struct { + base `jsonapi:"extend"` + ID int `jsonapi:"primary,with-extended-anonymous-fields"` +} + +type badBase struct { + CommonField string `jsonapi:"attr"` +} + +type WithBadExtendedAnonymousField struct { + badBase `jsonapi:"extend"` + ID int `jsonapi:"primary,with-bad-extended-anonymous-fields"` +} + type Timestamp struct { ID int `jsonapi:"primary,timestamps"` Time time.Time `jsonapi:"attr,timestamp,iso8601"` diff --git a/request.go b/request.go index ea6ae45..657abf8 100644 --- a/request.go +++ b/request.go @@ -8,7 +8,6 @@ import ( "io" "reflect" "strconv" - "strings" "time" ) @@ -124,34 +123,16 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } }() - modelValue := model.Elem() - modelType := model.Type().Elem() - var er error - for i := 0; i < modelValue.NumField(); i++ { - fieldType := modelType.Field(i) - tag := fieldType.Tag.Get("jsonapi") - if tag == "" { - continue - } - - fieldValue := modelValue.Field(i) - - args := strings.Split(tag, ",") + fields, er := extractFields(model) - if len(args) < 1 { - er = ErrBadJSONAPIStructTag - break - } - - annotation := args[0] + if er != nil { + return er + } - if (annotation == annotationClientID && len(args) != 1) || - (annotation != annotationClientID && len(args) < 2) { - er = ErrBadJSONAPIStructTag - break - } + for _, field := range fields { + fieldValue, annotation, kind, args := field.Value, field.Annotation, field.Kind, field.Args if annotation == annotationPrimary { if data.ID == "" { @@ -159,11 +140,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } // Check the JSON API Type - if data.Type != args[1] { + if data.Type != args[0] { er = fmt.Errorf( "Trying to Unmarshal an object of type %#v, but %#v does not match", data.Type, - args[1], + args[0], ) break } @@ -171,14 +152,6 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) // ID will have to be transmitted as astring per the JSON API spec v := reflect.ValueOf(data.ID) - // Deal with PTRS - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() - } - // Handle String case if kind == reflect.String { assign(fieldValue, v) @@ -250,15 +223,15 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) var iso8601 bool - if len(args) > 2 { - for _, arg := range args[2:] { + if len(args) > 1 { + for _, arg := range args[1:] { if arg == annotationISO8601 { iso8601 = true } } } - val := attributes[args[1]] + val := attributes[args[0]] // continue if the attribute was not included in the request if val == nil { @@ -362,15 +335,6 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) if v.Kind() == reflect.Float64 { floatValue := v.Interface().(float64) - // The field may or may not be a pointer to a numeric; the kind var - // will not contain a pointer type - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() - } - var numericValue reflect.Value switch kind { @@ -454,7 +418,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } else if annotation == annotationRelation { isSlice := fieldValue.Type().Kind() == reflect.Slice - if data.Relationships == nil || data.Relationships[args[1]] == nil { + if data.Relationships == nil || data.Relationships[args[0]] == nil { continue } @@ -464,7 +428,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) - json.NewEncoder(buf).Encode(data.Relationships[args[1]]) + json.NewEncoder(buf).Encode(data.Relationships[args[0]]) json.NewDecoder(buf).Decode(relationship) data := relationship.Data @@ -493,7 +457,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) json.NewEncoder(buf).Encode( - data.Relationships[args[1]], + data.Relationships[args[0]], ) json.NewDecoder(buf).Decode(relationship) diff --git a/request_test.go b/request_test.go index 1066733..bb31eca 100644 --- a/request_test.go +++ b/request_test.go @@ -70,6 +70,32 @@ func TestUnmarshalToStructWithPointerAttr(t *testing.T) { } } +func TestUnmarshall_attrFromExtendedAnonymousField(t *testing.T) { + out := new(WithExtendedAnonymousField) + commonField := "Common value" + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "with-extended-anonymous-fields", + "id": "1", + "attributes": map[string]interface{}{ + "common_field": commonField, + }, + }, + } + b, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + if err := UnmarshalPayload(bytes.NewReader(b), out); err != nil { + t.Fatal(err) + } + + if expected, actual := commonField, out.CommonField; expected != actual { + t.Fatalf("Was expecting CommonField to be `%s`, got `%s`", expected, actual) + } +} + func TestUnmarshalPayload_ptrsAllNil(t *testing.T) { out := new(WithPointer) if err := UnmarshalPayload( diff --git a/response.go b/response.go index a8ac71f..9f2bad3 100644 --- a/response.go +++ b/response.go @@ -7,7 +7,6 @@ import ( "io" "reflect" "strconv" - "strings" "time" ) @@ -206,44 +205,15 @@ func visitModelNode(model interface{}, included *map[string]*Node, var er error - modelValue := reflect.ValueOf(model).Elem() - modelType := reflect.ValueOf(model).Type().Elem() + fields, er := extractFields(reflect.ValueOf(model)) - for i := 0; i < modelValue.NumField(); i++ { - structField := modelValue.Type().Field(i) - tag := structField.Tag.Get(annotationJSONAPI) - if tag == "" { - continue - } - - fieldValue := modelValue.Field(i) - fieldType := modelType.Field(i) - - args := strings.Split(tag, annotationSeperator) - - if len(args) < 1 { - er = ErrBadJSONAPIStructTag - break - } - - annotation := args[0] - - if (annotation == annotationClientID && len(args) != 1) || - (annotation != annotationClientID && len(args) < 2) { - er = ErrBadJSONAPIStructTag - break - } + for _, field := range fields { + fieldValue, annotation, kind, args := field.Value, field.Annotation, field.Kind, field.Args if annotation == annotationPrimary { v := fieldValue - - // Deal with PTRS - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - v = reflect.Indirect(fieldValue) - } else { - kind = fieldType.Type.Kind() + if field.IsPtr { + v = reflect.Indirect(v) } // Handle allowed types @@ -277,7 +247,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, break } - node.Type = args[1] + node.Type = args[0] } else if annotation == annotationClientID { clientID := fieldValue.String() if clientID != "" { @@ -286,8 +256,8 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else if annotation == annotationAttribute { var omitEmpty, iso8601 bool - if len(args) > 2 { - for _, arg := range args[2:] { + if len(args) > 1 { + for _, arg := range args[1:] { switch arg { case annotationOmitEmpty: omitEmpty = true @@ -309,9 +279,9 @@ func visitModelNode(model interface{}, included *map[string]*Node, } if iso8601 { - node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) + node.Attributes[args[0]] = t.UTC().Format(iso8601TimeFormat) } else { - node.Attributes[args[1]] = t.Unix() + node.Attributes[args[0]] = t.Unix() } } else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { // A time pointer may be nil @@ -320,7 +290,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, continue } - node.Attributes[args[1]] = nil + node.Attributes[args[0]] = nil } else { tm := fieldValue.Interface().(*time.Time) @@ -329,9 +299,9 @@ func visitModelNode(model interface{}, included *map[string]*Node, } if iso8601 { - node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat) + node.Attributes[args[0]] = tm.UTC().Format(iso8601TimeFormat) } else { - node.Attributes[args[1]] = tm.Unix() + node.Attributes[args[0]] = tm.Unix() } } } else { @@ -345,17 +315,17 @@ func visitModelNode(model interface{}, included *map[string]*Node, strAttr, ok := fieldValue.Interface().(string) if ok { - node.Attributes[args[1]] = strAttr + node.Attributes[args[0]] = strAttr } else { - node.Attributes[args[1]] = fieldValue.Interface() + node.Attributes[args[0]] = fieldValue.Interface() } } } else if annotation == annotationRelation { var omitEmpty bool //add support for 'omitempty' struct tag for marshaling as absent - if len(args) > 2 { - omitEmpty = args[2] == annotationOmitEmpty + if len(args) > 1 { + omitEmpty = args[1] == annotationOmitEmpty } isSlice := fieldValue.Type().Kind() == reflect.Slice @@ -371,12 +341,12 @@ func visitModelNode(model interface{}, included *map[string]*Node, var relLinks *Links if linkableModel, ok := model.(RelationshipLinkable); ok { - relLinks = linkableModel.JSONAPIRelationshipLinks(args[1]) + relLinks = linkableModel.JSONAPIRelationshipLinks(args[0]) } var relMeta *Meta if metableModel, ok := model.(RelationshipMetable); ok { - relMeta = metableModel.JSONAPIRelationshipMeta(args[1]) + relMeta = metableModel.JSONAPIRelationshipMeta(args[0]) } if isSlice { @@ -400,20 +370,20 @@ func visitModelNode(model interface{}, included *map[string]*Node, shallowNodes = append(shallowNodes, toShallowNode(n)) } - node.Relationships[args[1]] = &RelationshipManyNode{ + node.Relationships[args[0]] = &RelationshipManyNode{ Data: shallowNodes, Links: relationship.Links, Meta: relationship.Meta, } } else { - node.Relationships[args[1]] = relationship + node.Relationships[args[0]] = relationship } } else { // to-one relationships // Handle null relationship case if fieldValue.IsNil() { - node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} + node.Relationships[args[0]] = &RelationshipOneNode{Data: nil} continue } @@ -429,13 +399,13 @@ func visitModelNode(model interface{}, included *map[string]*Node, if sideload { appendIncluded(included, relationship) - node.Relationships[args[1]] = &RelationshipOneNode{ + node.Relationships[args[0]] = &RelationshipOneNode{ Data: toShallowNode(relationship), Links: relLinks, Meta: relMeta, } } else { - node.Relationships[args[1]] = &RelationshipOneNode{ + node.Relationships[args[0]] = &RelationshipOneNode{ Data: relationship, Links: relLinks, Meta: relMeta, diff --git a/response_test.go b/response_test.go index 331023e..99a135f 100644 --- a/response_test.go +++ b/response_test.go @@ -211,6 +211,33 @@ func TestMarshall_invalidIDType(t *testing.T) { } } +func TestMarshall_attrFromExtendedAnonymousField(t *testing.T) { + id, commonField := 1, "Common value" + model := &WithExtendedAnonymousField{} + model.ID = id + model.CommonField = commonField + + out := bytes.NewBuffer(nil) + + if err := MarshalOnePayload(out, model); err != nil { + t.Fatal(err) + } + + var jsonData map[string]interface{} + if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + + attributes := jsonData["data"].(map[string]interface{})["attributes"].(map[string]interface{}) + val, exists := attributes["common_field"] + if !exists { + t.Fatal("Was expecting the data.attributes.common_field member to exist") + } + if val != commonField { + t.Fatalf("Was expecting the data.attributes.common_field member to be `%s`, got `%s`", commonField, val) + } +} + func TestOmitsEmptyAnnotation(t *testing.T) { book := &Book{ Author: "aren55555",