From 29ddcf8f4e5a32387d066e4afbe815aaf066464d Mon Sep 17 00:00:00 2001 From: Denis Shaposhnikov <993498+dsh2dsh@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:40:39 +0100 Subject: [PATCH] Remove unneeded log param from handlers and middlewares We can extract logger.Logger from context.Context. --- internal/daemon/control.go | 25 +++--- internal/daemon/daemon.go | 2 +- internal/daemon/middleware/json.go | 107 +++++++++-------------- internal/daemon/middleware/log.go | 12 +-- internal/daemon/middleware/middleware.go | 7 ++ internal/daemon/middleware/request_id.go | 5 +- internal/daemon/server.go | 2 +- 7 files changed, 71 insertions(+), 89 deletions(-) diff --git a/internal/daemon/control.go b/internal/daemon/control.go index bb76758c..db78733c 100644 --- a/internal/daemon/control.go +++ b/internal/daemon/control.go @@ -1,12 +1,13 @@ package daemon import ( + "context" "fmt" "net/http" "os" + "github.com/dsh2dsh/zrepl/internal/daemon/logging" "github.com/dsh2dsh/zrepl/internal/daemon/middleware" - "github.com/dsh2dsh/zrepl/internal/logger" "github.com/dsh2dsh/zrepl/internal/util/envconst" "github.com/dsh2dsh/zrepl/internal/version" "github.com/dsh2dsh/zrepl/internal/zfs/zfscmd" @@ -18,32 +19,33 @@ const ( ControlJobEndpointVersion = "/version" ) -func newControlJob(jobs *jobs, log logger.Logger) *controlJob { - return &controlJob{jobs: jobs, log: log} +func newControlJob(jobs *jobs) *controlJob { + return &controlJob{jobs: jobs} } type controlJob struct { jobs *jobs - log logger.Logger } func (j *controlJob) Endpoints(mux *http.ServeMux, m ...middleware.Middleware, ) { mux.Handle(ControlJobEndpointVersion, middleware.Append(m, - middleware.JsonResponder(j.log, j.version))) + middleware.JsonResponder(j.version))) mux.Handle(ControlJobEndpointStatus, middleware.Append(m, - middleware.JsonResponder(j.log, j.status))) + middleware.JsonResponder(j.status))) mux.Handle(ControlJobEndpointSignal, middleware.Append(m, - middleware.JsonRequestResponder(j.log, j.signal))) + middleware.JsonRequestResponder(j.signal))) } -func (j *controlJob) version() (version.ZreplVersionInformation, error) { +func (j *controlJob) version(_ context.Context) ( + version.ZreplVersionInformation, error, +) { return version.NewZreplVersionInformation(), nil } -func (j *controlJob) status() (Status, error) { +func (j *controlJob) status(_ context.Context) (Status, error) { s := Status{ Jobs: j.jobs.status(), Global: GlobalStatus{ @@ -60,8 +62,9 @@ type signalRequest struct { Name string } -func (j *controlJob) signal(req *signalRequest) (struct{}, error) { - log := j.log.WithField("op", req.Op) +func (j *controlJob) signal(ctx context.Context, req *signalRequest, +) (struct{}, error) { + log := logging.FromContext(ctx).WithField("op", req.Op) if req.Name != "" { log = log.WithField("name", req.Name) } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 7d70d8ac..0a31b0ed 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -82,7 +82,7 @@ func registerTraceCallbacks() { func startServer(log logger.Logger, conf *config.Config, jobs *jobs, logOutlets *logger.Outlets, ) error { - server := newServerJob(log, newControlJob(jobs, log)) + server := newServerJob(log, newControlJob(jobs)) var hasControl, hasMetrics bool for i := range conf.Listen { diff --git a/internal/daemon/middleware/json.go b/internal/daemon/middleware/json.go index 123eeeab..652c091d 100644 --- a/internal/daemon/middleware/json.go +++ b/internal/daemon/middleware/json.go @@ -1,87 +1,62 @@ package middleware import ( + "context" "encoding/json" "net/http" - - "github.com/dsh2dsh/zrepl/internal/logger" ) -func JsonResponder[T any](log logger.Logger, h func() (T, error)) Middleware { - return func(next http.Handler) http.Handler { - return &jsonResponder[T]{log: log, handler: h} +func JsonResponder[T any](h func(context.Context) (T, error)) Middleware { + fn := func(w http.ResponseWriter, r *http.Request) { + res, err := h(r.Context()) + if err != nil { + writeError(w, r, err, "control handler error") + return + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(&res); err != nil { + writeError(w, r, err, "control handler json marshal error") + } } + return func(next http.Handler) http.Handler { return http.HandlerFunc(fn) } } -type jsonResponder[T any] struct { - log logger.Logger - handler func() (T, error) -} - -func (self *jsonResponder[T]) ServeHTTP(w http.ResponseWriter, - r *http.Request, +func writeError(w http.ResponseWriter, r *http.Request, err error, msg string, ) { - res, err := self.handler() - if err != nil { - self.writeError(err, w, "control handler error") - return - } - - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&res); err != nil { - self.writeError(err, w, "control handler json marshal error") - } + writeErrorCode(w, r, http.StatusInternalServerError, err, msg) } -func (self *jsonResponder[T]) writeError(err error, w http.ResponseWriter, - msg string, +func writeErrorCode(w http.ResponseWriter, r *http.Request, statusCode int, + err error, msg string, ) { - self.log.WithError(err).Error(msg) - http.Error(w, err.Error(), http.StatusInternalServerError) + getLogger(r).WithError(err).Error(msg) + http.Error(w, err.Error(), statusCode) } // -------------------------------------------------- -func JsonRequestResponder[T1 any, T2 any](log logger.Logger, - h func(req *T1) (T2, error), +func JsonRequestResponder[T1 any, T2 any](h func(ctx context.Context, req *T1, +) (T2, error), ) Middleware { - return func(next http.Handler) http.Handler { - return &jsonRequestResponder[T1, T2]{log: log, handler: h} + fn := func(w http.ResponseWriter, r *http.Request) { + var req T1 + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeErrorCode(w, r, http.StatusBadRequest, err, + "control handler json unmarshal error", + ) + return + } + + resp, err := h(r.Context(), &req) + if err != nil { + writeError(w, r, err, "control handler error") + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(&resp); err != nil { + writeError(w, r, err, "control handler json marshal error") + } } -} - -type jsonRequestResponder[T1 any, T2 any] struct { - log logger.Logger - handler func(req *T1) (T2, error) -} - -func (self *jsonRequestResponder[T1, T2]) ServeHTTP(w http.ResponseWriter, - r *http.Request, -) { - var req T1 - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - self.writeError(err, w, "control handler json unmarshal error", - http.StatusBadRequest) - return - } - - resp, err := self.handler(&req) - if err != nil { - self.writeError(err, w, "control handler error", - http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&resp); err != nil { - self.writeError(err, w, "control handler json marshal error", - http.StatusInternalServerError) - } -} - -func (self *jsonRequestResponder[T1, T2]) writeError(err error, - w http.ResponseWriter, msg string, statusCode int, -) { - self.log.WithError(err).Error(msg) - http.Error(w, err.Error(), statusCode) + return func(next http.Handler) http.Handler { return http.HandlerFunc(fn) } } diff --git a/internal/daemon/middleware/log.go b/internal/daemon/middleware/log.go index 741de4cd..b25bdcfe 100644 --- a/internal/daemon/middleware/log.go +++ b/internal/daemon/middleware/log.go @@ -9,11 +9,9 @@ import ( "github.com/dsh2dsh/zrepl/internal/logger" ) -func RequestLogger(log logger.Logger, opts ...LoggerOption) Middleware { +func RequestLogger(opts ...LoggerOption) Middleware { l := &LogReq{ - log: log, - levels: make(map[string]logger.Level, 1), - + levels: make(map[string]logger.Level, 1), completedLevel: logger.Debug, } @@ -34,7 +32,6 @@ func WithCustomLevel(url string, level logger.Level) LoggerOption { } type LogReq struct { - log logger.Logger levels map[string]logger.Level completedLevel logger.Level @@ -47,11 +44,8 @@ func (self *LogReq) WithCustomLevel(url string, level logger.Level) *LogReq { func (self *LogReq) middleware(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - log := self.log + log := getLogger(r) logLevel := self.requestLevel(r) - if requestId := GetRequestId(r.Context()); requestId != "" { - log = log.WithField("rid", requestId) - } methodURL := r.Method + " " + r.URL.String() log.Log(logLevel, "\""+methodURL+"\"") diff --git a/internal/daemon/middleware/middleware.go b/internal/daemon/middleware/middleware.go index a95245ae..a4c1caf0 100644 --- a/internal/daemon/middleware/middleware.go +++ b/internal/daemon/middleware/middleware.go @@ -3,6 +3,9 @@ package middleware import ( "net/http" "slices" + + "github.com/dsh2dsh/zrepl/internal/daemon/logging" + "github.com/dsh2dsh/zrepl/internal/logger" ) type Middleware func(next http.Handler) http.Handler @@ -24,3 +27,7 @@ func Append(m1 []Middleware, m2 ...Middleware) http.Handler { func AppendHandler(m []Middleware, h http.Handler) http.Handler { return Append(m, func(http.Handler) http.Handler { return h }) } + +func getLogger(r *http.Request) logger.Logger { + return logging.FromContext(r.Context()) +} diff --git a/internal/daemon/middleware/request_id.go b/internal/daemon/middleware/request_id.go index 98f34c66..bd737414 100644 --- a/internal/daemon/middleware/request_id.go +++ b/internal/daemon/middleware/request_id.go @@ -5,6 +5,8 @@ import ( "net/http" "strconv" "sync/atomic" + + "github.com/dsh2dsh/zrepl/internal/daemon/logging" ) type ctxKeyRequestId struct{} @@ -22,12 +24,13 @@ func RequestId(next http.Handler) http.Handler { requestId := genRequestId() ctx := context.WithValue(r.Context(), RequestIdKey, strconv.FormatUint(requestId, 10)) + ctx = logging.WithLogger(ctx, getLogger(r).WithField("rid", requestId)) next.ServeHTTP(w, r.WithContext(ctx)) } return http.HandlerFunc(fn) } -func GetRequestId(ctx context.Context) string { +func RequestIdFrom(ctx context.Context) string { if ctx == nil { return "" } diff --git a/internal/daemon/server.go b/internal/daemon/server.go index 35b5439e..d248e645 100644 --- a/internal/daemon/server.go +++ b/internal/daemon/server.go @@ -60,7 +60,7 @@ type serverJob struct { func (self *serverJob) init() *serverJob { self.defaultMiddles = []middleware.Middleware{ - middleware.RequestLogger(self.log, + middleware.RequestLogger( // don't log requests to status endpoint, too spammy middleware.WithCustomLevel(ControlJobEndpointStatus, logger.Debug)), middleware.PrometheusMetrics(self.reqBegin, self.reqFinished),