Skip to content

Commit

Permalink
feat: sup cookie and multip val read (#46)
Browse files Browse the repository at this point in the history
Co-authored-by: xuxin.vinci <xuxin.vinci@bytedance.com>
  • Loading branch information
ShiningRush and xuxin.vinci authored Feb 7, 2025
1 parent 411d6f2 commit 6233b61
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 36 deletions.
94 changes: 58 additions & 36 deletions middleware/http_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,52 +161,74 @@ func (mw *HttpInputMiddleware) injectFieldFromUrlAndMap(ptr interface{}) error {
}

src, name := getSourceWayAndName(elType.Field(i))
if src == "" && mw.opt.IsReadFromBody {
if mw.searchMap != nil {
if v, ok := mw.searchMap[name]; ok {
if input.Field(i).Kind() == reflect.String {
input.Field(i).Set(reflect.ValueOf(string(v)))
} else if input.Field(i).Kind() == reflect.Slice {
input.Field(i).Set(reflect.ValueOf(v))
}
}
sources := strings.Split(src, "|")
for _, v := range sources {
findVal, err := mw.searchVal(v, name, input.Field(i))
if err != nil {
return fmt.Errorf("source %s[%s] read failed: %w", name, src, err)
}
if name == "@body" {
if input.Field(i).Type().Implements(reflect.TypeOf((*io.ReadCloser)(nil)).Elem()) {
input.Field(i).Set(reflect.ValueOf(mw.req.Body))
continue
}
if findVal {
break
}
}
}

bs, err := data.CopyBody(mw.req)
if err != nil {
return err
return nil
}

func (mw *HttpInputMiddleware) searchVal(src, name string, field reflect.Value) (findVal bool, err error) {
if src == "" && mw.opt.IsReadFromBody {
if mw.searchMap != nil {
if v, ok := mw.searchMap[name]; ok {
if field.Kind() == reflect.String {
field.Set(reflect.ValueOf(string(v)))
return true, nil
} else if field.Kind() == reflect.Slice {
field.Set(reflect.ValueOf(v))
return true, nil
}
input.Field(i).Set(reflect.ValueOf(bs))
}
continue
}
if name == "@body" {
if field.Type().Implements(reflect.TypeOf((*io.ReadCloser)(nil)).Elem()) {
field.Set(reflect.ValueOf(mw.req.Body))
return true, nil
}

val := ""
switch src {
case "path":
val = mw.opt.PathParamsFunc(name)
case "header":
val = mw.req.Header.Get(name)
default:
val = mw.req.FormValue(name)
bs, err := data.CopyBody(mw.req)
if err != nil {
return false, fmt.Errorf("read body failed: %w", err)
}
field.Set(reflect.ValueOf(bs))
}
return false, nil
}

tarVal, err := changeToFieldKind(val, input.Field(i).Type())
if err != nil {
return err
}
if tarVal == nil {
continue
val := ""
switch src {
case "path":
val = mw.opt.PathParamsFunc(name)
case "header":
val = mw.req.Header.Get(name)
case "cookie":
ck, err := mw.req.Cookie(name)
if err != nil && errors.Is(err, http.ErrNoCookie) {
return false, nil
}
input.Field(i).Set(reflect.ValueOf(tarVal))
val = ck.Value
default:
val = mw.req.FormValue(name)
}

return nil
tarVal, err := changeToFieldKind(val, field.Type())
if err != nil {
return false, fmt.Errorf("field[%s] covert failed: %w", name, err)
}
if tarVal == nil {
return false, nil
}
field.Set(reflect.ValueOf(tarVal))
return true, nil
}

func recoverPager(pInput interface{}) (bool, error) {
Expand Down Expand Up @@ -334,5 +356,5 @@ func changeToFieldKind(str string, t reflect.Type) (interface{}, error) {
return i, nil
}

return nil, fmt.Errorf("unsupport type: %s", kind.String())
return nil, fmt.Errorf("unsupport convert type: %s", kind.String())
}
67 changes: 67 additions & 0 deletions middleware/http_input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ type TestInput struct {
HeaderInt int `auto_read:"header-int,header"`
DefaultIntPtr *int `auto_read:"query_int"`
PathStrPtr *string `auto_read:"path_str,path"`
CookieStrPtr *string `auto_read:"cookie_str,cookie"`
MixedStrPtr *string `auto_read:"mixed, query|header|cookie"`
Body []byte `auto_read:"@body"`
}

Expand All @@ -110,6 +112,68 @@ func TestInputMiddleWare_injectFieldFromUrlAndMap(t *testing.T) {
}{
{
name: "normal",
giveMw: &HttpInputMiddleware{
opt: HttpInputOption{
PathParamsFunc: FixedPathFunc,
IsReadFromBody: true,
},
req: &http.Request{
URL: &url.URL{RawQuery: "query_str=query_string&test=2&mixed=query"},
Method: http.MethodPost,
Header: map[string][]string{
"Header-Int": {"10"},
"Cookie": {"cookie_str=c_str;mixed=cookie"},
"Mixed": {"header"},
},
Body: io.NopCloser(bytes.NewBufferString("all body")),
},
searchMap: nil,
},
givePtr: &TestInput{},
wantPtr: &TestInput{
QueryString: "query_string",
HeaderInt: 10,
DefaultIntPtr: nil,
PathStrPtr: strPtr("path_str"),
CookieStrPtr: strPtr("c_str"),
Body: []byte("all body"),
MixedStrPtr: strPtr("query"),
},
wantErr: require.NoError,
},
{
name: "mixed-header",
giveMw: &HttpInputMiddleware{
opt: HttpInputOption{
PathParamsFunc: FixedPathFunc,
IsReadFromBody: true,
},
req: &http.Request{
URL: &url.URL{RawQuery: "query_str=query_string&test=2"},
Method: http.MethodPost,
Header: map[string][]string{
"Header-Int": {"10"},
"Cookie": {"cookie_str=c_str;mixed=cookie"},
"Mixed": {"header"},
},
Body: io.NopCloser(bytes.NewBufferString("all body")),
},
searchMap: nil,
},
givePtr: &TestInput{},
wantPtr: &TestInput{
QueryString: "query_string",
HeaderInt: 10,
DefaultIntPtr: nil,
PathStrPtr: strPtr("path_str"),
CookieStrPtr: strPtr("c_str"),
Body: []byte("all body"),
MixedStrPtr: strPtr("header"),
},
wantErr: require.NoError,
},
{
name: "mixed-cookie",
giveMw: &HttpInputMiddleware{
opt: HttpInputOption{
PathParamsFunc: FixedPathFunc,
Expand All @@ -120,6 +184,7 @@ func TestInputMiddleWare_injectFieldFromUrlAndMap(t *testing.T) {
Method: http.MethodPost,
Header: map[string][]string{
"Header-Int": {"10"},
"Cookie": {"cookie_str=c_str;mixed=cookie"},
},
Body: io.NopCloser(bytes.NewBufferString("all body")),
},
Expand All @@ -131,7 +196,9 @@ func TestInputMiddleWare_injectFieldFromUrlAndMap(t *testing.T) {
HeaderInt: 10,
DefaultIntPtr: nil,
PathStrPtr: strPtr("path_str"),
CookieStrPtr: strPtr("c_str"),
Body: []byte("all body"),
MixedStrPtr: strPtr("cookie"),
},
wantErr: require.NoError,
},
Expand Down

0 comments on commit 6233b61

Please sign in to comment.