Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: test cors, and respect option on rest endpoints #1054

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions internal/services/jsonrpc/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ func (s *Server) handleMethod(ctx context.Context, method jsonrpc.Method, params

argsPtr, handler := maker(ctx, s)

// Treat omitted params as null, which may or may not be acceptable
// depending on the handler's parameters type. Otherwise json.Unmarshal
// always errors with "unexpected end of JSON input".
if params == nil {
params = []byte(`null`)
}

err := json.Unmarshal(params, argsPtr)
if err != nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, err.Error(), nil)
Expand Down
12 changes: 9 additions & 3 deletions internal/services/jsonrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,21 +305,27 @@ func NewServer(addr string, log log.Logger, opts ...Opt) (*Server, error) {
w.Header().Set("content-type", "application/json; charset=utf-8")
http.ServeContent(w, r, "openrpc.json", time.Time{}, bytes.NewReader(s.spec))
})
specHandler = corsHandler(specHandler)
if cfg.enableCORS {
specHandler = corsHandler(specHandler)
}
specHandler = recoverer(specHandler, log)
mux.Handle(pathSpecV1, specHandler)

// aggregate health endpoint handler
var healthHandler http.Handler
healthHandler = http.HandlerFunc(s.healthMethodHandler)
healthHandler = corsHandler(healthHandler)
if cfg.enableCORS {
healthHandler = corsHandler(healthHandler)
}
healthHandler = recoverer(healthHandler, log)
mux.Handle(pathHealthV1, healthHandler)

// service specific health endpoint handler with wild card for service
var userHealthHandler http.Handler
userHealthHandler = http.HandlerFunc(s.handleSvcHealth)
userHealthHandler = corsHandler(userHealthHandler)
if cfg.enableCORS {
userHealthHandler = corsHandler(userHealthHandler)
}
userHealthHandler = recoverer(userHealthHandler, log)
mux.Handle(pathSvcHealthV1, userHealthHandler)

Expand Down
149 changes: 149 additions & 0 deletions internal/services/jsonrpc/server_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package rpcserver

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -69,3 +73,148 @@ func Test_timeout(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, resp.Error.Code, jsonrpc.ErrorTimeout)
}

func Test_options(t *testing.T) {
logger := log.NewStdOut(log.WarnLevel)

const testOrigin = "whoever"

wantCorsHeaders := http.Header{
"Access-Control-Allow-Credentials": {"true"},
"Access-Control-Allow-Headers": {strings.Join([]string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "Authorization", "ResponseType", "Range"}, ", ")},
"Access-Control-Allow-Methods": {strings.Join([]string{http.MethodGet, http.MethodPost, http.MethodOptions}, ", ")},
"Access-Control-Allow-Origin": {testOrigin},
}

for _, tt := range []struct {
name string
path string
withcors bool
reqMeth string
expectStatus int
reqBody io.Reader
}{
// JSON-RPC endpoint
{
name: "no cors, options req",
path: pathRPCV1,
withcors: false,
reqMeth: http.MethodOptions,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, options req",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
{
name: "no cors, get req",
path: pathRPCV1,
withcors: false,
reqMeth: http.MethodGet,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, post empty req",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusBadRequest, // not a jsonrpc req => 400 status code
reqBody: nil,
},
{
name: "with cors, post json req no method",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusNotFound, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.nope"}`),
},
{
name: "with cors, post json req valid method",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusOK, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.dummy","params":null}`),
},
{
name: "with cors, post json req valid method (no params)",
path: pathRPCV1,
withcors: true,
reqMeth: http.MethodPost,
expectStatus: http.StatusOK, // method not found => 404 status code
reqBody: strings.NewReader(`{"jsonrpc":"2.0","id":2,"method":"rpc.dummy"}`),
},
// REST endpoints
{
name: "no cors, rest options req",
path: pathSpecV1,
withcors: false,
reqMeth: http.MethodOptions,
expectStatus: http.StatusMethodNotAllowed,
},
{
name: "with cors, rest options req",
path: pathSpecV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
{
name: "with cors, rest get req",
path: pathSpecV1,
withcors: true,
reqMeth: http.MethodGet,
expectStatus: http.StatusOK,
},
{
name: "with cors, rest health options req",
path: pathHealthV1,
withcors: true,
reqMeth: http.MethodOptions,
expectStatus: http.StatusOK,
},
} {
t.Run(tt.name, func(t *testing.T) {
opts := []Opt{}
if tt.withcors {
opts = append(opts, WithCORS())
}
srv, err := NewServer("127.0.0.1:", logger, opts...)
require.NoError(t, err)

srv.RegisterMethodHandler(
"rpc.dummy",
MakeMethodHandler(func(context.Context, *any) (*json.RawMessage, *jsonrpc.Error) {
respjson := []byte(`"hi"`)
return (*json.RawMessage)(&respjson), nil
}),
)

r := httptest.NewRequest(tt.reqMeth, tt.path, tt.reqBody)
r.Header.Set("origin", testOrigin)
w := httptest.NewRecorder()
srv.srv.Handler.ServeHTTP(w, r)

assert.Equal(t, tt.expectStatus, w.Code)

if tt.withcors && tt.expectStatus == http.StatusOK {
// expect the cors headers fields
rhdr := w.Result().Header
for hk, hvs := range wantCorsHeaders {
vs, have := rhdr[hk]
if !have {
t.Fatalf("missing cors header %v", hk)
}
if !slices.Equal(vs, hvs) {
t.Errorf("different cors headers: got %v, want %v", vs, hvs)
}
}

}
})
}
}
2 changes: 1 addition & 1 deletion internal/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// - 0.6.0+release
// - 0.6.1
// - 0.6.2-alpha0+go1.21.nocgo
const kwilVersion = "0.9.0-pre"
const kwilVersion = "0.9.2-pre" // remove "-pre" for the tagged commit

// KwildVersion may be set at compile time by:
//
Expand Down
Loading