Skip to content

Commit

Permalink
Remove unneeded log param from handlers and middlewares
Browse files Browse the repository at this point in the history
We can extract logger.Logger from context.Context.
  • Loading branch information
dsh2dsh committed Oct 31, 2024
1 parent 864b66c commit 29ddcf8
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 89 deletions.
25 changes: 14 additions & 11 deletions internal/daemon/control.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package daemon

import (
"context"
"fmt"
"net/http"
"os"

"github.com/dsh2dsh/zrepl/internal/daemon/logging"
"github.com/dsh2dsh/zrepl/internal/daemon/middleware"
"github.com/dsh2dsh/zrepl/internal/logger"
"github.com/dsh2dsh/zrepl/internal/util/envconst"
"github.com/dsh2dsh/zrepl/internal/version"
"github.com/dsh2dsh/zrepl/internal/zfs/zfscmd"
Expand All @@ -18,32 +19,33 @@ const (
ControlJobEndpointVersion = "/version"
)

func newControlJob(jobs *jobs, log logger.Logger) *controlJob {
return &controlJob{jobs: jobs, log: log}
func newControlJob(jobs *jobs) *controlJob {
return &controlJob{jobs: jobs}
}

type controlJob struct {
jobs *jobs
log logger.Logger
}

func (j *controlJob) Endpoints(mux *http.ServeMux, m ...middleware.Middleware,
) {
mux.Handle(ControlJobEndpointVersion, middleware.Append(m,
middleware.JsonResponder(j.log, j.version)))
middleware.JsonResponder(j.version)))

mux.Handle(ControlJobEndpointStatus, middleware.Append(m,
middleware.JsonResponder(j.log, j.status)))
middleware.JsonResponder(j.status)))

mux.Handle(ControlJobEndpointSignal, middleware.Append(m,
middleware.JsonRequestResponder(j.log, j.signal)))
middleware.JsonRequestResponder(j.signal)))
}

func (j *controlJob) version() (version.ZreplVersionInformation, error) {
func (j *controlJob) version(_ context.Context) (
version.ZreplVersionInformation, error,
) {
return version.NewZreplVersionInformation(), nil
}

func (j *controlJob) status() (Status, error) {
func (j *controlJob) status(_ context.Context) (Status, error) {
s := Status{
Jobs: j.jobs.status(),
Global: GlobalStatus{
Expand All @@ -60,8 +62,9 @@ type signalRequest struct {
Name string
}

func (j *controlJob) signal(req *signalRequest) (struct{}, error) {
log := j.log.WithField("op", req.Op)
func (j *controlJob) signal(ctx context.Context, req *signalRequest,
) (struct{}, error) {
log := logging.FromContext(ctx).WithField("op", req.Op)
if req.Name != "" {
log = log.WithField("name", req.Name)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func registerTraceCallbacks() {
func startServer(log logger.Logger, conf *config.Config, jobs *jobs,
logOutlets *logger.Outlets,
) error {
server := newServerJob(log, newControlJob(jobs, log))
server := newServerJob(log, newControlJob(jobs))

var hasControl, hasMetrics bool
for i := range conf.Listen {
Expand Down
107 changes: 41 additions & 66 deletions internal/daemon/middleware/json.go
Original file line number Diff line number Diff line change
@@ -1,87 +1,62 @@
package middleware

import (
"context"
"encoding/json"
"net/http"

"github.com/dsh2dsh/zrepl/internal/logger"
)

func JsonResponder[T any](log logger.Logger, h func() (T, error)) Middleware {
return func(next http.Handler) http.Handler {
return &jsonResponder[T]{log: log, handler: h}
func JsonResponder[T any](h func(context.Context) (T, error)) Middleware {
fn := func(w http.ResponseWriter, r *http.Request) {
res, err := h(r.Context())
if err != nil {
writeError(w, r, err, "control handler error")
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&res); err != nil {
writeError(w, r, err, "control handler json marshal error")
}
}
return func(next http.Handler) http.Handler { return http.HandlerFunc(fn) }
}

type jsonResponder[T any] struct {
log logger.Logger
handler func() (T, error)
}

func (self *jsonResponder[T]) ServeHTTP(w http.ResponseWriter,
r *http.Request,
func writeError(w http.ResponseWriter, r *http.Request, err error, msg string,
) {
res, err := self.handler()
if err != nil {
self.writeError(err, w, "control handler error")
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&res); err != nil {
self.writeError(err, w, "control handler json marshal error")
}
writeErrorCode(w, r, http.StatusInternalServerError, err, msg)
}

func (self *jsonResponder[T]) writeError(err error, w http.ResponseWriter,
msg string,
func writeErrorCode(w http.ResponseWriter, r *http.Request, statusCode int,
err error, msg string,
) {
self.log.WithError(err).Error(msg)
http.Error(w, err.Error(), http.StatusInternalServerError)
getLogger(r).WithError(err).Error(msg)
http.Error(w, err.Error(), statusCode)
}

// --------------------------------------------------

func JsonRequestResponder[T1 any, T2 any](log logger.Logger,
h func(req *T1) (T2, error),
func JsonRequestResponder[T1 any, T2 any](h func(ctx context.Context, req *T1,
) (T2, error),
) Middleware {
return func(next http.Handler) http.Handler {
return &jsonRequestResponder[T1, T2]{log: log, handler: h}
fn := func(w http.ResponseWriter, r *http.Request) {
var req T1
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeErrorCode(w, r, http.StatusBadRequest, err,
"control handler json unmarshal error",
)
return
}

resp, err := h(r.Context(), &req)
if err != nil {
writeError(w, r, err, "control handler error")
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&resp); err != nil {
writeError(w, r, err, "control handler json marshal error")
}
}
}

type jsonRequestResponder[T1 any, T2 any] struct {
log logger.Logger
handler func(req *T1) (T2, error)
}

func (self *jsonRequestResponder[T1, T2]) ServeHTTP(w http.ResponseWriter,
r *http.Request,
) {
var req T1
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
self.writeError(err, w, "control handler json unmarshal error",
http.StatusBadRequest)
return
}

resp, err := self.handler(&req)
if err != nil {
self.writeError(err, w, "control handler error",
http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&resp); err != nil {
self.writeError(err, w, "control handler json marshal error",
http.StatusInternalServerError)
}
}

func (self *jsonRequestResponder[T1, T2]) writeError(err error,
w http.ResponseWriter, msg string, statusCode int,
) {
self.log.WithError(err).Error(msg)
http.Error(w, err.Error(), statusCode)
return func(next http.Handler) http.Handler { return http.HandlerFunc(fn) }
}
12 changes: 3 additions & 9 deletions internal/daemon/middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ import (
"github.com/dsh2dsh/zrepl/internal/logger"
)

func RequestLogger(log logger.Logger, opts ...LoggerOption) Middleware {
func RequestLogger(opts ...LoggerOption) Middleware {
l := &LogReq{
log: log,
levels: make(map[string]logger.Level, 1),

levels: make(map[string]logger.Level, 1),
completedLevel: logger.Debug,
}

Expand All @@ -34,7 +32,6 @@ func WithCustomLevel(url string, level logger.Level) LoggerOption {
}

type LogReq struct {
log logger.Logger
levels map[string]logger.Level

completedLevel logger.Level
Expand All @@ -47,11 +44,8 @@ func (self *LogReq) WithCustomLevel(url string, level logger.Level) *LogReq {

func (self *LogReq) middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
log := self.log
log := getLogger(r)
logLevel := self.requestLevel(r)
if requestId := GetRequestId(r.Context()); requestId != "" {
log = log.WithField("rid", requestId)
}

methodURL := r.Method + " " + r.URL.String()
log.Log(logLevel, "\""+methodURL+"\"")
Expand Down
7 changes: 7 additions & 0 deletions internal/daemon/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package middleware
import (
"net/http"
"slices"

"github.com/dsh2dsh/zrepl/internal/daemon/logging"
"github.com/dsh2dsh/zrepl/internal/logger"
)

type Middleware func(next http.Handler) http.Handler
Expand All @@ -24,3 +27,7 @@ func Append(m1 []Middleware, m2 ...Middleware) http.Handler {
func AppendHandler(m []Middleware, h http.Handler) http.Handler {
return Append(m, func(http.Handler) http.Handler { return h })
}

func getLogger(r *http.Request) logger.Logger {
return logging.FromContext(r.Context())
}
5 changes: 4 additions & 1 deletion internal/daemon/middleware/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net/http"
"strconv"
"sync/atomic"

"github.com/dsh2dsh/zrepl/internal/daemon/logging"
)

type ctxKeyRequestId struct{}
Expand All @@ -22,12 +24,13 @@ func RequestId(next http.Handler) http.Handler {
requestId := genRequestId()
ctx := context.WithValue(r.Context(),
RequestIdKey, strconv.FormatUint(requestId, 10))
ctx = logging.WithLogger(ctx, getLogger(r).WithField("rid", requestId))
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

func GetRequestId(ctx context.Context) string {
func RequestIdFrom(ctx context.Context) string {
if ctx == nil {
return ""
}
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type serverJob struct {

func (self *serverJob) init() *serverJob {
self.defaultMiddles = []middleware.Middleware{
middleware.RequestLogger(self.log,
middleware.RequestLogger(
// don't log requests to status endpoint, too spammy
middleware.WithCustomLevel(ControlJobEndpointStatus, logger.Debug)),
middleware.PrometheusMetrics(self.reqBegin, self.reqFinished),
Expand Down

0 comments on commit 29ddcf8

Please sign in to comment.