Skip to content

Commit

Permalink
Refactor JsonResponder and JsonRequestResponder middlewares
Browse files Browse the repository at this point in the history
using generics.
  • Loading branch information
dsh2dsh committed Sep 11, 2024
1 parent e4ac5cf commit 7f2589d
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 81 deletions.
7 changes: 3 additions & 4 deletions client/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,22 @@ func runVersionCmd() error {
return errors.New("show flag must be 'client' or 'server' or be left empty")
}

var clientVersion, daemonVersion *version.ZreplVersionInformation
var clientVersion, daemonVersion version.ZreplVersionInformation
if args.Show == "client" || args.Show == "" {
clientVersion = version.NewZreplVersionInformation()
fmt.Printf("client: %s\n", clientVersion.String())
}

if args.Show == "daemon" || args.Show == "" {
if args.ConfigErr != nil {
return fmt.Errorf("config parsing error: %s", args.ConfigErr)
}

var info version.ZreplVersionInformation
err := jsonRequestResponse(args.Config.Global.Control.SockPath,
daemon.ControlJobEndpointVersion, nil, &info)
daemon.ControlJobEndpointVersion, nil, &daemonVersion)
if err != nil {
return fmt.Errorf("server: error: %s\n", err)
}
daemonVersion = &info
fmt.Printf("server: %s\n", daemonVersion.String())
}

Expand Down
44 changes: 17 additions & 27 deletions daemon/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package daemon

import (
"context"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -150,19 +149,17 @@ func (j *controlJob) Run(ctx context.Context, cron *cron.Cron) {
}

func (j *controlJob) mux() *http.ServeMux {
mux := http.NewServeMux()
logRequest := middleware.RequestLogger(j.log,
middleware.WithPrometheusMetrics(j.requestBegin, j.requestFinished))

mux := http.NewServeMux()
mux.Handle(ControlJobEndpointPProf, middleware.New(
logRequest,
middleware.JsonRequestResponder(j.log, j.pprof)))

mux.Handle(ControlJobEndpointVersion, middleware.New(
logRequest,
middleware.JsonResponder(j.log, func() (any, error) {
return version.NewZreplVersionInformation(), nil
})))
middleware.JsonResponder(j.log, j.version)))

mux.Handle(ControlJobEndpointStatus, middleware.New(
// don't log requests to status endpoint, too spammy
Expand All @@ -174,40 +171,33 @@ func (j *controlJob) mux() *http.ServeMux {
return mux
}

func (j *controlJob) pprof(decoder middleware.JsonDecoder) (any, error) {
var msg PprofServerControlMsg
err := decoder(&msg)
if err != nil {
return nil, errors.New("decode failed")
}
j.pprofServer.Control(msg)
func (j *controlJob) pprof(msg *PprofServerControlMsg) (struct{}, error) {
j.pprofServer.Control(*msg)
return struct{}{}, nil
}

func (j *controlJob) status() (any, error) {
jobs := j.jobs.status()
globalZFS := zfscmd.GetReport()
envconstReport := envconst.GetReport()
func (j *controlJob) version() (version.ZreplVersionInformation, error) {
return version.NewZreplVersionInformation(), nil
}

func (j *controlJob) status() (Status, error) {
s := Status{
Jobs: jobs,
Jobs: j.jobs.status(),
Global: GlobalStatus{
ZFSCmds: globalZFS,
Envconst: envconstReport,
ZFSCmds: zfscmd.GetReport(),
Envconst: envconst.GetReport(),
OsEnviron: os.Environ(),
},
}
return s, nil
}

func (j *controlJob) signal(decoder middleware.JsonDecoder) (any, error) {
req := struct {
Op string
Name string
}{}
if decoder(&req) != nil {
return nil, errors.New("decode failed")
}
type signalRequest struct {
Op string
Name string
}

func (j *controlJob) signal(req *signalRequest) (struct{}, error) {
log := j.log.WithField("op", req.Op)
if req.Name != "" {
log.WithField("name", req.Name)
Expand Down
73 changes: 30 additions & 43 deletions daemon/middleware/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,33 @@ import (
"github.com/dsh2dsh/zrepl/logger"
)

func JsonResponder(log logger.Logger, producer func() (any, error),
) Middleware {
func JsonResponder[T any](log logger.Logger, h func() (T, error)) Middleware {
return func(next http.Handler) http.Handler {
return NewJsonResponder(log, producer)
return &jsonResponder[T]{log: log, handler: h}
}
}

func NewJsonResponder(log logger.Logger, producer func() (any, error),
) *jsonResponder {
return &jsonResponder{log: log, producer: producer}
}

type jsonResponder struct {
log logger.Logger
producer func() (any, error)
type jsonResponder[T any] struct {
log logger.Logger
handler func() (T, error)
}

func (self *jsonResponder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
res, err := self.producer()
func (self *jsonResponder[T]) ServeHTTP(w http.ResponseWriter,
r *http.Request,
) {
res, err := self.handler()
if err != nil {
self.writeError(err, w, "control handler error")
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(res); err != nil {
if err := json.NewEncoder(w).Encode(&res); err != nil {
self.writeError(err, w, "control handler json marshal error")
}
}

func (self *jsonResponder) writeError(err error, w http.ResponseWriter,
func (self *jsonResponder[T]) writeError(err error, w http.ResponseWriter,
msg string,
) {
self.log.WithError(err).Error(msg)
Expand All @@ -48,56 +44,47 @@ func (self *jsonResponder) writeError(err error, w http.ResponseWriter,
}
}

func JsonRequestResponder(log logger.Logger,
producer func(decoder JsonDecoder) (any, error),
// --------------------------------------------------

func JsonRequestResponder[T1 any, T2 any](log logger.Logger,
h func(req *T1) (T2, error),
) Middleware {
return func(next http.Handler) http.Handler {
return NewJsonRequestResponder(log, producer)
return &jsonRequestResponder[T1, T2]{log: log, handler: h}
}
}

func NewJsonRequestResponder(log logger.Logger,
producer func(decoder JsonDecoder) (any, error),
) *jsonRequestResponder {
return &jsonRequestResponder{log: log, producer: producer}
type jsonRequestResponder[T1 any, T2 any] struct {
log logger.Logger
handler func(req *T1) (T2, error)
}

type jsonRequestResponder struct {
log logger.Logger
producer func(decoder JsonDecoder) (any, error)
}

type JsonDecoder = func(any) error

func (self *jsonRequestResponder) ServeHTTP(w http.ResponseWriter,
func (self *jsonRequestResponder[T1, T2]) ServeHTTP(w http.ResponseWriter,
r *http.Request,
) {
var decodeErr error
resp, err := self.producer(func(req any) error {
decodeErr = json.NewDecoder(r.Body).Decode(&req)
return decodeErr
})

// If we had a decode error ignore output of producer and return error
if decodeErr != nil {
self.writeError(decodeErr, w, "control handler json unmarshal error",
var req T1
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
self.writeError(err, w, "control handler json unmarshal error",
http.StatusBadRequest)
return
} else if err != nil {
}

resp, err := self.handler(&req)
if err != nil {
self.writeError(err, w, "control handler error",
http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
if err := json.NewEncoder(w).Encode(&resp); err != nil {
self.writeError(err, w, "control handler json marshal error",
http.StatusInternalServerError)
}
}

func (self *jsonRequestResponder) writeError(err error, w http.ResponseWriter,
msg string, statusCode int,
func (self *jsonRequestResponder[T1, T2]) writeError(err error,
w http.ResponseWriter, msg string, statusCode int,
) {
self.log.WithError(err).Error(msg)
w.WriteHeader(statusCode)
Expand Down
13 changes: 6 additions & 7 deletions version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

var (
zreplVersion string // set by build infrastructure
)
var zreplVersion string // set by build infrastructure

type ZreplVersionInformation struct {
Version string
Expand All @@ -19,8 +17,8 @@ type ZreplVersionInformation struct {
RUNTIMECompiler string
}

func NewZreplVersionInformation() *ZreplVersionInformation {
return &ZreplVersionInformation{
func NewZreplVersionInformation() ZreplVersionInformation {
return ZreplVersionInformation{
Version: zreplVersion,
RuntimeGo: runtime.Version(),
RuntimeGOOS: runtime.GOOS,
Expand All @@ -29,9 +27,10 @@ func NewZreplVersionInformation() *ZreplVersionInformation {
}
}

func (i *ZreplVersionInformation) String() string {
func (self ZreplVersionInformation) String() string {
return fmt.Sprintf("zrepl version=%s go=%s GOOS=%s GOARCH=%s Compiler=%s",
i.Version, i.RuntimeGo, i.RuntimeGOOS, i.RuntimeGOARCH, i.RUNTIMECompiler)
self.Version, self.RuntimeGo, self.RuntimeGOOS, self.RuntimeGOARCH,
self.RUNTIMECompiler)
}

var prometheusMetric = prometheus.NewGauge(
Expand Down

0 comments on commit 7f2589d

Please sign in to comment.