diff --git a/auth.go b/auth.go index 7518941..1c587ea 100644 --- a/auth.go +++ b/auth.go @@ -7,18 +7,21 @@ Released under MIT license. package authkit import ( + "context" "crypto/tls" "crypto/x509" "fmt" "net/http" "os" + "github.com/acronis/go-appkit/httpserver/middleware" "github.com/acronis/go-appkit/log" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "github.com/acronis/go-authkit/idptoken" "github.com/acronis/go-authkit/internal/idputil" + "github.com/acronis/go-authkit/internal/libinfo" "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) @@ -26,13 +29,11 @@ import ( // NewJWTParser creates a new JWTParser with the given configuration. // If cfg.JWT.ClaimsCache.Enabled is true, then jwt.CachingParser created, otherwise - jwt.Parser. func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { - var options jwtParserOptions + options := jwtParserOptions{loggerProvider: middleware.GetLoggerFromContext} for _, opt := range opts { opt(&options) } - logger := idputil.PrepareLogger(options.logger) - // Make caching JWKS client. jwksCacheUpdateMinInterval := cfg.JWKS.Cache.UpdateMinInterval if jwksCacheUpdateMinInterval == 0 { @@ -40,8 +41,8 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { } jwksClientOpts := jwks.CachingClientOpts{ ClientOpts: jwks.ClientOpts{ - Logger: logger, - HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, logger), + LoggerProvider: options.loggerProvider, + HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, options.loggerProvider), PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, }, CacheUpdateMinInterval: jwksCacheUpdateMinInterval, @@ -51,17 +52,19 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { // Make JWT parser. if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { - logger.Warn("list of trusted issuers is empty, jwt parsing may not work properly") + idputil.GetLoggerFromProvider(context.Background(), options.loggerProvider).Warn( + "list of trusted issuers is empty, jwt parsing may not work properly") } parserOpts := jwt.ParserOpts{ RequireAudience: cfg.JWT.RequireAudience, ExpectedAudience: cfg.JWT.ExpectedAudience, TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, + LoggerProvider: options.loggerProvider, } if cfg.JWT.ClaimsCache.Enabled { - cachingJWTParser, err := jwt.NewCachingParserWithOpts(jwksClient, logger, jwt.CachingParserOpts{ + cachingJWTParser, err := jwt.NewCachingParserWithOpts(jwksClient, jwt.CachingParserOpts{ ParserOpts: parserOpts, CacheMaxEntries: cfg.JWT.ClaimsCache.MaxEntries, }) @@ -74,7 +77,7 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { return cachingJWTParser, nil } - jwtParser := jwt.NewParserWithOpts(jwksClient, logger, parserOpts) + jwtParser := jwt.NewParserWithOpts(jwksClient, parserOpts) if err := addTrustedIssuers(jwtParser, cfg.JWT.TrustedIssuers, cfg.JWT.TrustedIssuerURLs); err != nil { return nil, err } @@ -82,7 +85,7 @@ func NewJWTParser(cfg *Config, opts ...JWTParserOption) (JWTParser, error) { } type jwtParserOptions struct { - logger log.FieldLogger + loggerProvider func(ctx context.Context) log.FieldLogger prometheusLibInstanceLabel string trustedIssuerNotFoundFallback jwt.TrustedIssNotFoundFallback } @@ -90,10 +93,10 @@ type jwtParserOptions struct { // JWTParserOption is an option for creating JWTParser. type JWTParserOption func(options *jwtParserOptions) -// WithJWTParserLogger sets the logger for JWTParser. -func WithJWTParserLogger(logger log.FieldLogger) JWTParserOption { +// WithJWTParserLoggerProvider sets the logger provider for JWTParser. +func WithJWTParserLoggerProvider(loggerProvider func(ctx context.Context) log.FieldLogger) JWTParserOption { return func(options *jwtParserOptions) { - options.logger = logger + options.loggerProvider = loggerProvider } } @@ -122,15 +125,14 @@ func NewTokenIntrospector( scopeFilter []idptoken.IntrospectionScopeFilterAccessPolicy, opts ...TokenIntrospectorOption, ) (*idptoken.Introspector, error) { - var options tokenIntrospectorOptions + options := tokenIntrospectorOptions{loggerProvider: middleware.GetLoggerFromContext} for _, opt := range opts { opt(&options) } - logger := idputil.PrepareLogger(options.logger) - if len(cfg.JWT.TrustedIssuers) == 0 && len(cfg.JWT.TrustedIssuerURLs) == 0 { - logger.Warn("list of trusted issuers is empty, jwt introspection may not work properly") + idputil.GetLoggerFromProvider(context.Background(), options.loggerProvider).Warn( + "list of trusted issuers is empty, jwt introspection may not work properly") } var grpcClient *idptoken.GRPCClient @@ -139,9 +141,14 @@ func NewTokenIntrospector( if err != nil { return nil, fmt.Errorf("make grpc transport credentials: %w", err) } - grpcClient, err = idptoken.NewGRPCClientWithOpts(cfg.Introspection.GRPC.Endpoint, transportCreds, - idptoken.GRPCClientOpts{RequestTimeout: cfg.GRPCClient.RequestTimeout, Logger: logger}) - if err != nil { + grpcClientOpts := idptoken.GRPCClientOpts{ + RequestTimeout: cfg.GRPCClient.RequestTimeout, + LoggerProvider: options.loggerProvider, + UserAgent: libinfo.UserAgent(), + } + if grpcClient, err = idptoken.NewGRPCClientWithOpts( + cfg.Introspection.GRPC.Endpoint, transportCreds, grpcClientOpts, + ); err != nil { return nil, fmt.Errorf("new grpc client: %w", err) } } @@ -149,9 +156,9 @@ func NewTokenIntrospector( introspectorOpts := idptoken.IntrospectorOpts{ HTTPEndpoint: cfg.Introspection.Endpoint, GRPCClient: grpcClient, - HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, logger), + HTTPClient: idputil.MakeDefaultHTTPClient(cfg.HTTPClient.RequestTimeout, options.loggerProvider), AccessTokenScope: cfg.Introspection.AccessTokenScope, - Logger: logger, + LoggerProvider: options.loggerProvider, ScopeFilter: scopeFilter, TrustedIssuerNotFoundFallback: options.trustedIssuerNotFoundFallback, PrometheusLibInstanceLabel: options.prometheusLibInstanceLabel, @@ -179,7 +186,7 @@ func NewTokenIntrospector( } type tokenIntrospectorOptions struct { - logger log.FieldLogger + loggerProvider func(ctx context.Context) log.FieldLogger prometheusLibInstanceLabel string trustedIssuerNotFoundFallback idptoken.TrustedIssNotFoundFallback } @@ -187,10 +194,10 @@ type tokenIntrospectorOptions struct { // TokenIntrospectorOption is an option for creating TokenIntrospector. type TokenIntrospectorOption func(options *tokenIntrospectorOptions) -// WithTokenIntrospectorLogger sets the logger for TokenIntrospector. -func WithTokenIntrospectorLogger(logger log.FieldLogger) TokenIntrospectorOption { +// WithTokenIntrospectorLoggerProvider sets the logger provider for TokenIntrospector. +func WithTokenIntrospectorLoggerProvider(loggerProvider func(ctx context.Context) log.FieldLogger) TokenIntrospectorOption { return func(options *tokenIntrospectorOptions) { - options.logger = logger + options.loggerProvider = loggerProvider } } @@ -242,6 +249,11 @@ func NewVerifyAccessByRolesInJWTMaker(namespace string) func(roleNames ...string } } +// SetDefaultLogger sets the default logger for the library. +func SetDefaultLogger(logger log.FieldLogger) { + idputil.DefaultLogger = logger +} + type issuerParser interface { AddTrustedIssuer(issName string, issURL string) AddTrustedIssuerURL(issURL string) error diff --git a/examples/authn-middleware/main.go b/examples/authn-middleware/main.go index 90d0320..6fac6c5 100644 --- a/examples/authn-middleware/main.go +++ b/examples/authn-middleware/main.go @@ -39,20 +39,19 @@ func runApp() error { logger, loggerClose := log.NewLogger(cfg.Log) defer loggerClose() - jwtParser, err := authkit.NewJWTParser(cfg.Auth, authkit.WithJWTParserLogger(logger)) + jwtParser, err := authkit.NewJWTParser(cfg.Auth) if err != nil { return fmt.Errorf("create JWT parser: %w", err) } - logMw := middleware.Logging(logger) authNMw := authkit.JWTAuthMiddleware(serviceErrorDomain, jwtParser) srvMux := http.NewServeMux() - srvMux.Handle("/", logMw(authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject))) - })))) - if err = http.ListenAndServe(":8080", srvMux); err != nil && !errors.Is(err, http.ErrServerClosed) { + }))) + if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("listen and HTTP server: %w", err) } diff --git a/examples/idp-test-server/main.go b/examples/idp-test-server/main.go index b119970..80d15ee 100644 --- a/examples/idp-test-server/main.go +++ b/examples/idp-test-server/main.go @@ -9,6 +9,7 @@ package main import ( "context" "errors" + "fmt" golog "log" "net/http" "os" @@ -22,7 +23,6 @@ import ( "github.com/acronis/go-authkit" "github.com/acronis/go-authkit/idptest" "github.com/acronis/go-authkit/idptoken" - "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) @@ -38,9 +38,11 @@ func runApp() error { logger, loggerClose := log.NewLogger(&log.Config{Output: log.OutputStdout, Level: log.LevelInfo, Format: log.FormatJSON}) defer loggerClose() - jwksClientOpts := jwks.CachingClientOpts{ClientOpts: jwks.ClientOpts{Logger: logger}} - jwtParser := jwt.NewParser(jwks.NewCachingClientWithOpts(jwksClientOpts), logger) - _ = jwtParser.AddTrustedIssuerURL("http://" + idpAddr) + jwtParser, err := authkit.NewJWTParser( + &authkit.Config{JWT: authkit.JWTConfig{TrustedIssuerURLs: []string{"http://" + idpAddr}}}) + if err != nil { + return fmt.Errorf("create JWT parser: %w", err) + } idpSrv := idptest.NewHTTPServer( idptest.WithHTTPAddress(idpAddr), @@ -48,8 +50,8 @@ func runApp() error { idptest.WithHTTPClaimsProvider(&demoClaimsProvider{}), idptest.WithHTTPTokenIntrospector(&demoTokenIntrospector{jwtParser: jwtParser}), ) - if err := idpSrv.StartAndWaitForReady(time.Second * 3); err != nil { - return err + if err = idpSrv.StartAndWaitForReady(time.Second * 3); err != nil { + return fmt.Errorf("start HTTP server: %w", err) } logger.Info("HTTP IDP server is running on " + idpAddr) @@ -61,13 +63,13 @@ func runApp() error { defer shutdownCancel() if stopErr := idpSrv.Shutdown(shutdownCtx); stopErr != nil && !errors.Is(stopErr, http.ErrServerClosed) { - return stopErr + return fmt.Errorf("shutdown HTTP server: %w", stopErr) } return nil } type demoTokenIntrospector struct { - jwtParser *jwt.Parser + jwtParser authkit.JWTParser } func (dti *demoTokenIntrospector) IntrospectToken(r *http.Request, token string) (idptoken.IntrospectionResult, error) { diff --git a/examples/token-introspection/config.yml b/examples/token-introspection/config.yml index 366ec3d..987c95c 100644 --- a/examples/token-introspection/config.yml +++ b/examples/token-introspection/config.yml @@ -20,4 +20,4 @@ auth: maxEntries: 1000 ttl: 5m grpc: - endpoint: 127.0.0.1:50051 + endpoint: "" diff --git a/examples/token-introspection/grpc-server/main.go b/examples/token-introspection/grpc-server/main.go index ee0c106..bad5e40 100644 --- a/examples/token-introspection/grpc-server/main.go +++ b/examples/token-introspection/grpc-server/main.go @@ -8,6 +8,7 @@ package main import ( "context" + "fmt" golog "log" "os" "os/signal" @@ -17,9 +18,9 @@ import ( "github.com/acronis/go-appkit/log" "google.golang.org/grpc/metadata" + "github.com/acronis/go-authkit" "github.com/acronis/go-authkit/idptest" "github.com/acronis/go-authkit/idptoken/pb" - "github.com/acronis/go-authkit/jwks" "github.com/acronis/go-authkit/jwt" ) @@ -38,15 +39,17 @@ func runApp() error { logger, loggerClose := log.NewLogger(&log.Config{Output: log.OutputStdout, Level: log.LevelInfo, Format: log.FormatJSON}) defer loggerClose() - jwksClientOpts := jwks.CachingClientOpts{ClientOpts: jwks.ClientOpts{Logger: logger}} - jwtParser := jwt.NewParser(jwks.NewCachingClientWithOpts(jwksClientOpts), logger) - _ = jwtParser.AddTrustedIssuerURL("http://" + idpAddr) + jwtParser, err := authkit.NewJWTParser( + &authkit.Config{JWT: authkit.JWTConfig{TrustedIssuerURLs: []string{"http://" + idpAddr}}}) + if err != nil { + return fmt.Errorf("create JWT parser: %w", err) + } grpcSrv := idptest.NewGRPCServer( idptest.WithGRPCAddr(grpcAddr), idptest.WithGRPCTokenIntrospector(&demoGRPCTokenIntrospector{jwtParser: jwtParser, logger: logger}), ) - if err := grpcSrv.StartAndWaitForReady(time.Second * 3); err != nil { + if err = grpcSrv.StartAndWaitForReady(time.Second * 3); err != nil { return err } logger.Info("GRPC server for token introspection is running on " + grpcAddr) @@ -62,18 +65,26 @@ func runApp() error { const accessTokenWithIntrospectionPermission = "access-token-with-introspection-permission" type demoGRPCTokenIntrospector struct { - jwtParser *jwt.Parser + jwtParser authkit.JWTParser logger log.FieldLogger } func (dti *demoGRPCTokenIntrospector) IntrospectToken( ctx context.Context, req *pb.IntrospectTokenRequest, ) (*pb.IntrospectTokenResponse, error) { - dti.logger.Info("got IntrospectTokenRequest") + var userAgent string var authMeta string - if mdVal := metadata.ValueFromIncomingContext(ctx, "authorization"); len(mdVal) != 0 { - authMeta = mdVal[0] + if md, ok := metadata.FromIncomingContext(ctx); ok { + if userAgentList := md.Get("user-agent"); len(userAgentList) > 0 { + userAgent = userAgentList[0] + } + if authList := md.Get("authorization"); len(authList) > 0 { + authMeta = authList[0] + } } + + dti.logger.Info("got IntrospectTokenRequest", log.String("user_agent", userAgent)) + if authMeta != "Bearer "+accessTokenWithIntrospectionPermission { return nil, idptest.ErrUnauthorized } diff --git a/examples/token-introspection/main.go b/examples/token-introspection/main.go index b2c9da1..b4eb0c7 100644 --- a/examples/token-introspection/main.go +++ b/examples/token-introspection/main.go @@ -43,16 +43,14 @@ func runApp() error { defer loggerClose() // Create JWT parser. - jwtParser, err := authkit.NewJWTParser(cfg.Auth, authkit.WithJWTParserLogger(logger)) + jwtParser, err := authkit.NewJWTParser(cfg.Auth) if err != nil { return fmt.Errorf("create JWT parser: %w", err) } // Create token introspector. - introspectionScopeFilter := []idptoken.IntrospectionScopeFilterAccessPolicy{ - {ResourceNamespace: serviceAccessPolicy}} - tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth, introspectionTokenProvider{}, - introspectionScopeFilter, authkit.WithTokenIntrospectorLogger(logger)) + introspectionScopeFilter := []idptoken.IntrospectionScopeFilterAccessPolicy{{ResourceNamespace: serviceAccessPolicy}} + tokenIntrospector, err := authkit.NewTokenIntrospector(cfg.Auth, introspectionTokenProvider{}, introspectionScopeFilter) if err != nil { return fmt.Errorf("create token introspector: %w", err) } @@ -65,8 +63,6 @@ func runApp() error { }() } - logMw := middleware.Logging(logger) - // Configure JWTAuthMiddleware that performs only authentication via OAuth2 token introspection endpoint. authNMw := authkit.JWTAuthMiddleware(serviceErrorDomain, jwtParser, authkit.WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector)) @@ -81,16 +77,16 @@ func runApp() error { // Create HTTP server and start it. srvMux := http.NewServeMux() // "/" endpoint will be available for all authenticated users. - srvMux.Handle("/", logMw(authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + srvMux.Handle("/", authNMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // get JWT claims from the request context _, _ = rw.Write([]byte(fmt.Sprintf("Hello, %s", jwtClaims.Subject))) - })))) + }))) // "/admin" endpoint will be available only for users with the "admin" role. - srvMux.Handle("/admin", logMw(authZMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + srvMux.Handle("/admin", authZMw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { jwtClaims := authkit.GetJWTClaimsFromContext(r.Context()) // Get JWT claims from the request context. _, _ = rw.Write([]byte(fmt.Sprintf("Hi, %s", jwtClaims.Subject))) - })))) - if err = http.ListenAndServe(":8080", srvMux); err != nil && !errors.Is(err, http.ErrServerClosed) { + }))) + if err = http.ListenAndServe(":8080", middleware.Logging(logger)(srvMux)); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("listen and HTTP server: %w", err) } diff --git a/go.mod b/go.mod index d710935..2548c6f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/acronis/go-authkit go 1.20 require ( - github.com/acronis/go-appkit v1.3.0 + github.com/acronis/go-appkit v1.5.0 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103 @@ -21,6 +21,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudflare/ahocorasick v0.0.0-20240916140611-054963ec9396 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect diff --git a/go.sum b/go.sum index 8a2a365..148277c 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9 h1:8KlrGCtoaWaa code.cloudfoundry.org/bytefmt v0.0.0-20240808182453-a379845013d9/go.mod h1:eF2ZbltNI7Pv+8Cuyeksu9up5FN5konuH0trDJBuscw= github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b h1:5/++qT1/z812ZqBvqQt6ToRswSuPZ/B33m6xVHRzADU= github.com/RussellLuo/slidingwindow v0.0.0-20200528002341-535bb99d338b/go.mod h1:4+EPqMRApwwE/6yo6CxiHoSnBzjRr3jsqer7frxP8y4= -github.com/acronis/go-appkit v1.3.0 h1:IaX0DbD7HWp8ykqnK9F+c8757AmP4uBHVBe8J0Wv2sw= -github.com/acronis/go-appkit v1.3.0/go.mod h1:ouqWNe1/69fwjhx+2vV81Y6iqstfDzhmC6HpZ0E/gp4= +github.com/acronis/go-appkit v1.5.0 h1:dhZZyZjKS3WD9zsomznRnDMxdzIW2pG+EI08/AqtH+E= +github.com/acronis/go-appkit v1.5.0/go.mod h1:bDNkQ2ENdEz6vGHId22sE1NhZcScS61HY6GmVOkaS1s= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= @@ -14,6 +14,8 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudflare/ahocorasick v0.0.0-20240916140611-054963ec9396 h1:W2HK1IdCnCGuLUeyizSCkwvBjdj0ZL7mxnJYQ3poyzI= +github.com/cloudflare/ahocorasick v0.0.0-20240916140611-054963ec9396/go.mod h1:tGWUZLZp9ajsxUOnHmFFLnqnlKXsCn6GReG4jAD59H0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= diff --git a/idptest/http_server.go b/idptest/http_server.go index 993257c..a756f33 100644 --- a/idptest/http_server.go +++ b/idptest/http_server.go @@ -14,7 +14,6 @@ import ( "sync/atomic" "time" - "github.com/acronis/go-appkit/log" "github.com/acronis/go-appkit/testutil" "github.com/acronis/go-authkit/idptoken" @@ -258,7 +257,7 @@ func (s *HTTPServer) StartAndWaitForReady(timeout time.Duration) error { } func (s *HTTPServer) makeJWTParser() *jwt.Parser { - p := jwt.NewParser(jwks.NewClient(), log.NewDisabledLogger()) + p := jwt.NewParser(jwks.NewClient()) _ = p.AddTrustedIssuerURL(s.URL()) return p } diff --git a/idptest/jwt_test.go b/idptest/jwt_test.go index b3caa91..d86bfe5 100644 --- a/idptest/jwt_test.go +++ b/idptest/jwt_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/acronis/go-appkit/log" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" @@ -29,8 +28,6 @@ func TestMakeTokenStringWithHeader(t *testing.T) { issuerConfigServer := httptest.NewServer(&OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - logger := log.NewDisabledLogger() - jwtClaims := &jwt.Claims{ RegisteredClaims: jwtgo.RegisteredClaims{ Issuer: testIss, @@ -41,7 +38,7 @@ func TestMakeTokenStringWithHeader(t *testing.T) { }, } - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), MustMakeTokenStringSignedWithTestKey(jwtClaims)) require.NoError(t, err) diff --git a/idptoken/grpc_client.go b/idptoken/grpc_client.go index 69cd158..0ee9aae 100644 --- a/idptoken/grpc_client.go +++ b/idptoken/grpc_client.go @@ -34,12 +34,15 @@ const grpcMetaAuthorization = "authorization" // GRPCClientOpts contains options for the GRPCClient. type GRPCClientOpts struct { - // Logger is a logger for the client. - Logger log.FieldLogger + // LoggerProvider is a function that provides a logger for the client. + LoggerProvider func(ctx context.Context) log.FieldLogger // RequestTimeout is a timeout for the gRPC requests. RequestTimeout time.Duration + // UserAgent is a user agent string for the client. + UserAgent string + // PrometheusLibInstanceLabel is a label for Prometheus metrics. // It allows distinguishing metrics from different instances of the same library. PrometheusLibInstanceLabel string @@ -65,14 +68,14 @@ func NewGRPCClient( func NewGRPCClientWithOpts( target string, transportCreds credentials.TransportCredentials, opts GRPCClientOpts, ) (*GRPCClient, error) { - opts.Logger = idputil.PrepareLogger(opts.Logger) if opts.RequestTimeout == 0 { opts.RequestTimeout = DefaultGRPCClientRequestTimeout } conn, err := grpc.NewClient(target, grpc.WithTransportCredentials(transportCreds), - grpc.WithStatsHandler(&statsHandler{logger: opts.Logger}), + grpc.WithStatsHandler(&statsHandler{loggerProvider: opts.LoggerProvider}), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + grpc.WithUserAgent(opts.UserAgent), ) if err != nil { return nil, fmt.Errorf("dial to %q: %w", target, err) @@ -207,7 +210,7 @@ func (c *GRPCClient) do(ctx context.Context, methodName string, call func(ctx co } type statsHandler struct { - logger log.FieldLogger + loggerProvider func(ctx context.Context) log.FieldLogger } func (sh *statsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { @@ -224,8 +227,8 @@ func (sh *statsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) co func (sh *statsHandler) HandleConn(ctx context.Context, s stats.ConnStats) { switch s.(type) { case *stats.ConnBegin: - sh.logger.Infof("grpc connection established") + idputil.GetLoggerFromProvider(ctx, sh.loggerProvider).Infof("grpc connection established") case *stats.ConnEnd: - sh.logger.Infof("grpc connection closed") + idputil.GetLoggerFromProvider(ctx, sh.loggerProvider).Infof("grpc connection closed") } } diff --git a/idptoken/introspector.go b/idptoken/introspector.go index 20693e3..c64c23c 100644 --- a/idptoken/introspector.go +++ b/idptoken/introspector.go @@ -103,8 +103,8 @@ type IntrospectorOpts struct { // If it's set, then only access policies in scope that match at least one of the filtering policies will be returned. ScopeFilter IntrospectionScopeFilter - // Logger is a logger for logging errors and debug information. - Logger log.FieldLogger + // LoggerProvider is a function that provides a logger for the Introspector. + LoggerProvider func(ctx context.Context) log.FieldLogger // TrustedIssuerNotFoundFallback is a function called // when given issuer from JWT is not found in the list of trusted ones. @@ -155,7 +155,7 @@ type Introspector struct { scopeFilter IntrospectionScopeFilter scopeFilterFormURLEncoded string - logger log.FieldLogger + loggerProvider func(ctx context.Context) log.FieldLogger trustedIssuerStore *idputil.TrustedIssuerStore trustedIssuerNotFoundFallback TrustedIssNotFoundFallback @@ -205,9 +205,8 @@ func NewIntrospector(tokenProvider IntrospectionTokenProvider) (*Introspector, e // NewIntrospectorWithOpts creates a new Introspector with the given token provider and options. // See IntrospectorOpts for more details. func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opts IntrospectorOpts) (*Introspector, error) { - opts.Logger = idputil.PrepareLogger(opts.Logger) if opts.HTTPClient == nil { - opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.Logger) + opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider) } values := url.Values{} @@ -256,7 +255,7 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt accessTokenProvider: accessTokenProvider, accessTokenScope: opts.AccessTokenScope, jwtParser: jwtgo.NewParser(), - logger: opts.Logger, + loggerProvider: opts.LoggerProvider, GRPCClient: opts.GRPCClient, HTTPClient: opts.HTTPClient, httpEndpoint: opts.HTTPEndpoint, @@ -462,7 +461,8 @@ func (i *Introspector) makeIntrospectFuncHTTP(introspectionEndpointURL string) i } defer func() { if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { - i.logger.Error(fmt.Sprintf("closing response body error for POST %s", introspectionEndpointURL), + idputil.GetLoggerFromProvider(ctx, i.loggerProvider).Error( + fmt.Sprintf("closing response body error for POST %s", introspectionEndpointURL), log.Error(closeBodyErr)) } }() @@ -502,9 +502,10 @@ func (i *Introspector) makeIntrospectFuncGRPC() introspectFunc { } func (i *Introspector) getWellKnownIntrospectionEndpointURL(ctx context.Context, issuerURL string) (string, error) { + logger := idputil.GetLoggerFromProvider(ctx, i.loggerProvider) openIDCfgURL := strings.TrimSuffix(issuerURL, "/") + wellKnownPath openIDCfg, err := idputil.GetOpenIDConfiguration( - ctx, i.HTTPClient, openIDCfgURL, nil, i.logger, i.promMetrics) + ctx, i.HTTPClient, openIDCfgURL, nil, logger, i.promMetrics) if err != nil { return "", fmt.Errorf("get OpenID configuration: %w", err) } diff --git a/idptoken/introspector_test.go b/idptoken/introspector_test.go index 4047d81..65868b3 100644 --- a/idptoken/introspector_test.go +++ b/idptoken/introspector_test.go @@ -12,7 +12,6 @@ import ( gotesting "testing" "time" - "github.com/acronis/go-appkit/log" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -48,7 +47,7 @@ func TestIntrospector_IntrospectToken(t *gotesting.T) { require.NoError(t, err) defer func() { require.NoError(t, grpcClient.Close()) }() - jwtParser := jwt.NewParser(jwks.NewClient(), log.NewDisabledLogger()) + jwtParser := jwt.NewParser(jwks.NewClient()) require.NoError(t, jwtParser.AddTrustedIssuerURL(httpIDPSrv.URL())) httpServerIntrospector.JWTParser = jwtParser grpcServerIntrospector.JWTParser = jwtParser @@ -353,8 +352,7 @@ func TestCachingIntrospector_IntrospectTokenWithCache(t *gotesting.T) { require.NoError(t, idpSrv.StartAndWaitForReady(time.Second)) defer func() { _ = idpSrv.Shutdown(context.Background()) }() - logger := log.NewDisabledLogger() - jwtParser := jwt.NewParser(jwks.NewClient(), logger) + jwtParser := jwt.NewParser(jwks.NewClient()) require.NoError(t, jwtParser.AddTrustedIssuerURL(idpSrv.URL())) serverIntrospector.JWTParser = jwtParser diff --git a/idptoken/provider.go b/idptoken/provider.go index cbbc995..020d302 100644 --- a/idptoken/provider.go +++ b/idptoken/provider.go @@ -185,7 +185,8 @@ func NewMultiSourceProviderWithOpts(sources []Source, opts ProviderOpts) *MultiS p.cache = NewInMemoryTokenCache() } if p.httpClient == nil { - p.httpClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, p.logger) + p.httpClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, + func(_ context.Context) log.FieldLogger { return p.logger }) } for _, source := range sources { diff --git a/internal/idputil/idp_util.go b/internal/idputil/idp_util.go index ecfc1bf..e1dd772 100644 --- a/internal/idputil/idp_util.go +++ b/internal/idputil/idp_util.go @@ -7,7 +7,7 @@ Released under MIT license. package idputil import ( - "fmt" + "context" "net/http" "time" @@ -28,20 +28,34 @@ const ( DefaultHTTPRequestMaxRetryAttempts = 3 ) -func MakeDefaultHTTPClient(reqTimeout time.Duration, logger log.FieldLogger) *http.Client { +var DefaultLogger = log.NewDisabledLogger() + +func MakeDefaultHTTPClient(reqTimeout time.Duration, loggerProvider func(ctx context.Context) log.FieldLogger) *http.Client { if reqTimeout == 0 { reqTimeout = DefaultHTTPRequestTimeout } var tr http.RoundTripper = http.DefaultTransport.(*http.Transport).Clone() - tr, _ = httpclient.NewRetryableRoundTripperWithOpts(tr, httpclient.RetryableRoundTripperOpts{ - MaxRetryAttempts: DefaultHTTPRequestMaxRetryAttempts, Logger: logger}) // error is always nil - tr = httpclient.NewUserAgentRoundTripper(tr, libinfo.LibName+"/"+libinfo.GetLibVersion()) + retryableOpts := httpclient.RetryableRoundTripperOpts{ + MaxRetryAttempts: DefaultHTTPRequestMaxRetryAttempts, + LoggerProvider: loggerProvider, + } + tr, _ = httpclient.NewRetryableRoundTripperWithOpts(tr, retryableOpts) // error is always nil + tr = httpclient.NewUserAgentRoundTripper(tr, libinfo.UserAgent()) return &http.Client{Timeout: reqTimeout, Transport: tr} } func PrepareLogger(logger log.FieldLogger) log.FieldLogger { if logger == nil { - return log.NewDisabledLogger() + return DefaultLogger + } + return log.NewPrefixedLogger(logger, libinfo.LogPrefix()) +} + +func GetLoggerFromProvider(ctx context.Context, provider func(ctx context.Context) log.FieldLogger) log.FieldLogger { + if provider != nil { + if logger := provider(ctx); logger != nil { + return log.NewPrefixedLogger(logger, libinfo.LogPrefix()) + } } - return log.NewPrefixedLogger(logger, fmt.Sprintf("[%s/%s] ", libinfo.LibName, libinfo.GetLibVersion())) + return DefaultLogger } diff --git a/internal/libinfo/lib_info.go b/internal/libinfo/lib_info.go index ea8cd2e..74c65e8 100644 --- a/internal/libinfo/lib_info.go +++ b/internal/libinfo/lib_info.go @@ -5,3 +5,11 @@ Released under MIT license. */ package libinfo + +func UserAgent() string { + return LibName + "/" + GetLibVersion() +} + +func LogPrefix() string { + return "[" + LibName + "/" + GetLibVersion() + "] " +} diff --git a/jwks/client.go b/jwks/client.go index 69e3e0f..06a5506 100644 --- a/jwks/client.go +++ b/jwks/client.go @@ -34,8 +34,8 @@ type ClientOpts struct { // HTTPClient is an HTTP client for making requests. HTTPClient *http.Client - // Logger is a logger for the client. - Logger log.FieldLogger + // LoggerProvider is a function that provides a logger for the Client. + LoggerProvider func(ctx context.Context) log.FieldLogger // PrometheusLibInstanceLabel is a label for Prometheus metrics. // It allows distinguishing metrics from different instances of the same library. @@ -47,9 +47,9 @@ type ClientOpts struct { // NOTE: CachingClient should be used in a typical service // to avoid making HTTP requests on each JWT verification. type Client struct { - httpClient *http.Client - logger log.FieldLogger - promMetrics *metrics.PrometheusMetrics + httpClient *http.Client + loggerProvider func(ctx context.Context) log.FieldLogger + promMetrics *metrics.PrometheusMetrics } // NewClient returns a new Client. @@ -60,37 +60,38 @@ func NewClient() *Client { // NewClientWithOpts returns a new Client with options. func NewClientWithOpts(opts ClientOpts) *Client { promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client") - opts.Logger = idputil.PrepareLogger(opts.Logger) if opts.HTTPClient == nil { - opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.Logger) + opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider) } - return &Client{httpClient: opts.HTTPClient, logger: opts.Logger, promMetrics: promMetrics} + return &Client{httpClient: opts.HTTPClient, loggerProvider: opts.LoggerProvider, promMetrics: promMetrics} } func (c *Client) getRSAPubKeysForIssuer(ctx context.Context, issuerURL string) (map[string]interface{}, error) { + logger := idputil.GetLoggerFromProvider(ctx, c.loggerProvider) + openIDConfigURL := strings.TrimPrefix(issuerURL, "/") + OpenIDConfigurationPath openIDConfig, err := idputil.GetOpenIDConfiguration( - ctx, c.httpClient, openIDConfigURL, nil, c.logger, c.promMetrics) + ctx, c.httpClient, openIDConfigURL, nil, logger, c.promMetrics) if err != nil { return nil, &GetOpenIDConfigurationError{Inner: err, URL: openIDConfigURL} } - jwksRespData, err := c.getJWKS(ctx, openIDConfig.JWKSURI) + jwksRespData, err := c.getJWKS(ctx, openIDConfig.JWKSURI, logger) if err != nil { return nil, &GetJWKSError{Inner: err, URL: openIDConfig.JWKSURI, OpenIDConfigurationURL: openIDConfigURL} } - c.logger.Info(fmt.Sprintf("%d keys fetched (jwks_url: %s)", len(jwksRespData.Keys), openIDConfig.JWKSURI)) + logger.Info(fmt.Sprintf("%d keys fetched (jwks_url: %s)", len(jwksRespData.Keys), openIDConfig.JWKSURI)) pubKeys := make(map[string]interface{}, len(jwksRespData.Keys)) for _, jwk := range jwksRespData.Keys { var pubKey crypto.PublicKey if pubKey, err = jwk.DecodePublicKey(); err != nil { - c.logger.Error(fmt.Sprintf("decoding JWK (kid: %s, jwks_url: %s) to public key error", + logger.Error(fmt.Sprintf("decoding JWK (kid: %s, jwks_url: %s) to public key error", jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) continue } rsaPubKey, ok := pubKey.(*rsa.PublicKey) if !ok { - c.logger.Error(fmt.Sprintf("converting JWK (kid: %s, jwks_url: %s) to RSA public key error", + logger.Error(fmt.Sprintf("converting JWK (kid: %s, jwks_url: %s) to RSA public key error", jwk.Kid, openIDConfig.JWKSURI), log.Error(err)) continue } @@ -112,7 +113,7 @@ func (c *Client) GetRSAPublicKey(ctx context.Context, issuerURL, keyID string) ( return pubKey, nil } -func (c *Client) getJWKS(ctx context.Context, jwksURL string) (jwksData, error) { +func (c *Client) getJWKS(ctx context.Context, jwksURL string, logger log.FieldLogger) (jwksData, error) { req, err := http.NewRequest(http.MethodGet, jwksURL, http.NoBody) if err != nil { return jwksData{}, fmt.Errorf("new request: %w", err) @@ -126,7 +127,7 @@ func (c *Client) getJWKS(ctx context.Context, jwksURL string) (jwksData, error) } defer func() { if closeBodyErr := resp.Body.Close(); closeBodyErr != nil { - c.logger.Error(fmt.Sprintf("closing response body error for GET %s", jwksURL), log.Error(closeBodyErr)) + logger.Error(fmt.Sprintf("closing response body error for GET %s", jwksURL), log.Error(closeBodyErr)) } }() diff --git a/jwt/caching_parser.go b/jwt/caching_parser.go index f0cfdc0..edc80c1 100644 --- a/jwt/caching_parser.go +++ b/jwt/caching_parser.go @@ -12,7 +12,6 @@ import ( "fmt" "unsafe" - "github.com/acronis/go-appkit/log" "github.com/acronis/go-appkit/lrucache" jwtgo "github.com/golang-jwt/jwt/v5" @@ -41,12 +40,12 @@ type CachingParser struct { ClaimsCache ClaimsCache } -func NewCachingParser(keysProvider KeysProvider, logger log.FieldLogger) (*CachingParser, error) { - return NewCachingParserWithOpts(keysProvider, logger, CachingParserOpts{}) +func NewCachingParser(keysProvider KeysProvider) (*CachingParser, error) { + return NewCachingParserWithOpts(keysProvider, CachingParserOpts{}) } func NewCachingParserWithOpts( - keysProvider KeysProvider, logger log.FieldLogger, opts CachingParserOpts, + keysProvider KeysProvider, opts CachingParserOpts, ) (*CachingParser, error) { promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser") if opts.CacheMaxEntries == 0 { @@ -57,7 +56,7 @@ func NewCachingParserWithOpts( return nil, err } return &CachingParser{ - Parser: NewParserWithOpts(keysProvider, logger, opts.ParserOpts), + Parser: NewParserWithOpts(keysProvider, opts.ParserOpts), ClaimsCache: cache, }, nil } diff --git a/jwt/caching_parser_test.go b/jwt/caching_parser_test.go index 8b50adb..e7648e6 100644 --- a/jwt/caching_parser_test.go +++ b/jwt/caching_parser_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/acronis/go-appkit/log" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" @@ -43,7 +42,6 @@ func TestGetTokenHash(t *testing.T) { } func TestCachingParser_Parse(t *testing.T) { - logger := log.NewDisabledLogger() jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) defer jwksServer.Close() @@ -53,7 +51,7 @@ func TestCachingParser_Parse(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) - parser, err := jwt.NewCachingParser(jwks.NewCachingClient(), logger) + parser, err := jwt.NewCachingParser(jwks.NewCachingClient()) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -78,7 +76,6 @@ func TestCachingParser_Parse(t *testing.T) { func TestCachingParser_CheckExpiration(t *testing.T) { const jwtTTL = 2 * time.Second - logger := log.NewDisabledLogger() jwksServer := httptest.NewServer(&idptest.JWKSHandler{}) defer jwksServer.Close() @@ -88,7 +85,7 @@ func TestCachingParser_CheckExpiration(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(jwtTTL))}} tokenString := idptest.MustMakeTokenStringSignedWithTestKey(claims) - parser, err := jwt.NewCachingParser(jwks.NewCachingClient(), logger) + parser, err := jwt.NewCachingParser(jwks.NewCachingClient()) require.NoError(t, err) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) diff --git a/jwt/jwt.go b/jwt/parser.go similarity index 94% rename from jwt/jwt.go rename to jwt/parser.go index f14469f..6177980 100644 --- a/jwt/jwt.go +++ b/jwt/parser.go @@ -36,6 +36,7 @@ type ParserOpts struct { RequireAudience bool ExpectedAudience []string TrustedIssuerNotFoundFallback TrustedIssNotFoundFallback + LoggerProvider func(ctx context.Context) log.FieldLogger } type audienceMatcher func(aud string) bool @@ -55,16 +56,16 @@ type Parser struct { trustedIssuerStore *idputil.TrustedIssuerStore trustedIssuerNotFoundFallback TrustedIssNotFoundFallback - logger log.FieldLogger + loggerProvider func(ctx context.Context) log.FieldLogger } // NewParser creates new JWT parser with specified keys provider. -func NewParser(keysProvider KeysProvider, logger log.FieldLogger) *Parser { - return NewParserWithOpts(keysProvider, logger, ParserOpts{}) +func NewParser(keysProvider KeysProvider) *Parser { + return NewParserWithOpts(keysProvider, ParserOpts{}) } // NewParserWithOpts creates new JWT parser with specified keys provider and additional options. -func NewParserWithOpts(keysProvider KeysProvider, logger log.FieldLogger, opts ParserOpts) *Parser { +func NewParserWithOpts(keysProvider KeysProvider, opts ParserOpts) *Parser { var audienceMatchers []audienceMatcher for _, audPattern := range opts.ExpectedAudience { audienceMatchers = append(audienceMatchers, glob.Compile(audPattern)) @@ -81,7 +82,7 @@ func NewParserWithOpts(keysProvider KeysProvider, logger log.FieldLogger, opts P keysProvider: keysProvider, trustedIssuerStore: idputil.NewTrustedIssuerStore(), trustedIssuerNotFoundFallback: opts.TrustedIssuerNotFoundFallback, - logger: logger, + loggerProvider: opts.LoggerProvider, } } @@ -120,7 +121,8 @@ func (p *Parser) Parse(ctx context.Context, token string) (*Claims, error) { return nil, err } if err = cachingKeysProvider.InvalidateCacheIfNeeded(ctx, issuerURL); err != nil { - p.logger.Error(fmt.Sprintf("keys provider invalidating cache error for issuer %q", issuerURL), + idputil.GetLoggerFromProvider(ctx, p.loggerProvider).Error( + fmt.Sprintf("keys provider invalidating cache error for issuer %q", issuerURL), log.Error(err)) return nil, err } diff --git a/jwt/jwt_test.go b/jwt/parser_test.go similarity index 91% rename from jwt/jwt_test.go rename to jwt/parser_test.go index a539890..fdf2a59 100644 --- a/jwt/jwt_test.go +++ b/jwt/parser_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/acronis/go-appkit/log" jwtgo "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" @@ -30,8 +29,6 @@ func TestJWTParser_Parse(t *testing.T) { issuerConfigServer := httptest.NewServer(&idptest.OpenIDConfigurationHandler{JWKSURL: jwksServer.URL}) defer issuerConfigServer.Close() - logger := log.NewDisabledLogger() - t.Run("ok", func(t *testing.T) { claims := &jwt.Claims{ RegisteredClaims: jwtgo.RegisteredClaims{ @@ -42,7 +39,7 @@ func TestJWTParser_Parse(t *testing.T) { TOTPTime: time.Now().Unix(), SubType: "task_manager", } - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -58,7 +55,7 @@ func TestJWTParser_Parse(t *testing.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }, } - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -79,7 +76,7 @@ func TestJWTParser_Parse(t *testing.T) { "http://127.*", } for _, issURL := range issURLs { - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) require.NoError(t, parser.AddTrustedIssuerURL(issURL)) parsedClaims, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.NoError(t, err) @@ -97,7 +94,7 @@ func TestJWTParser_Parse(t *testing.T) { }, Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, } - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), jwt.ParserOpts{ ExpectedAudience: []string{"*.cloud.com"}, }) require.NoError(t, parser.AddTrustedIssuerURL(issuerConfigServer.URL)) @@ -108,7 +105,7 @@ func TestJWTParser_Parse(t *testing.T) { }) t.Run("malformed jwt", func(t *testing.T) { - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) _, err := parser.Parse(context.Background(), "invalid-jwt") require.ErrorIs(t, err, jwtgo.ErrTokenMalformed) require.ErrorContains(t, err, "token contains an invalid number of segments") @@ -126,7 +123,7 @@ func TestJWTParser_Parse(t *testing.T) { tokenString, err := token.SignedString(jwtgo.UnsafeAllowNoneSignatureType) require.NoError(t, err) - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) _, err = parser.Parse(context.Background(), tokenString) require.ErrorIs(t, err, jwtgo.NoneSignatureTypeDisallowedError) }) @@ -135,7 +132,7 @@ func TestJWTParser_Parse(t *testing.T) { claims := &jwt.Claims{ RegisteredClaims: jwtgo.RegisteredClaims{Audience: []string{"https://cloud.acronis.com"}}, } - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) var issMissingErr *jwt.IssuerMissingError @@ -146,7 +143,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer", func(t *testing.T) { const issuer = "untrusted-issuer" claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) var issUntrustedErr *jwt.IssuerUntrustedError @@ -157,7 +154,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt has untrusted issuer url", func(t *testing.T) { const issuer = "https://3rd-party-idp.com" claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: issuer}} - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) require.NoError(t, parser.AddTrustedIssuerURL("https://*.acronis.com")) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenUnverifiable) @@ -176,7 +173,7 @@ func TestJWTParser_Parse(t *testing.T) { }, Scope: []jwt.AccessPolicy{{Role: "company_admin"}}, } - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), jwt.ParserOpts{ TrustedIssuerNotFoundFallback: func(ctx context.Context, p *jwt.Parser, iss string) (issURL string, issFound bool) { callbackCallCount++ addErr := p.AddTrustedIssuerURL(iss) @@ -199,7 +196,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt exp is missing", func(t *testing.T) { claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss}} - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -209,7 +206,7 @@ func TestJWTParser_Parse(t *testing.T) { t.Run("jwt expired", func(t *testing.T) { expiresAt := time.Now().Add(-time.Second) claims := &jwt.Claims{RegisteredClaims: jwtgo.RegisteredClaims{Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(expiresAt)}} - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -223,7 +220,7 @@ func TestJWTParser_Parse(t *testing.T) { ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Hour)), NotBefore: jwtgo.NewNumericDate(notBefore), }} - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) _, err := parser.Parse(context.Background(), idptest.MustMakeTokenStringSignedWithTestKey(claims)) require.ErrorIs(t, err, jwtgo.ErrTokenInvalidClaims) @@ -235,7 +232,7 @@ func TestJWTParser_Parse(t *testing.T) { Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }} - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), jwt.ParserOpts{ RequireAudience: true, }) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -254,7 +251,7 @@ func TestJWTParser_Parse(t *testing.T) { Issuer: testIss, ExpiresAt: jwtgo.NewNumericDate(time.Now().Add(time.Minute)), }} - parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), logger, jwt.ParserOpts{ + parser := jwt.NewParserWithOpts(jwks.NewCachingClient(), jwt.ParserOpts{ ExpectedAudience: []string{"expected-audience"}, }) parser.AddTrustedIssuer(testIss, issuerConfigServer.URL) @@ -280,7 +277,7 @@ func TestJWTParser_Parse(t *testing.T) { tokenString, err := idptest.MakeTokenString(claims, "737c5114f09b5ed05276bd4b520245982f7fb29f", idptest.GetTestRSAPrivateKey()) require.NoError(t, err) jwksClient := jwks.NewCachingClientWithOpts(jwks.CachingClientOpts{CacheUpdateMinInterval: cacheUpdateMinInterval}) - parser := jwt.NewParser(jwksClient, logger) + parser := jwt.NewParser(jwksClient) parser.AddTrustedIssuer(testIss, openIDCfgServer2.URL) for i := 0; i < 2; i++ { @@ -335,8 +332,7 @@ func TestParser_getURLForIssuer(t *testing.T) { for i := range tests { tt := tests[i] t.Run(tt.Name, func(t *testing.T) { - logger := log.NewDisabledLogger() - parser := jwt.NewParser(jwks.NewCachingClient(), logger) + parser := jwt.NewParser(jwks.NewCachingClient()) require.NoError(t, parser.AddTrustedIssuerURL(tt.IssURLPattern)) for _, issURL := range tt.TrustedIssURLs { u, ok := parser.GetURLForIssuer(issURL) diff --git a/middleware.go b/middleware.go index 7964d23..6a8a1c2 100644 --- a/middleware.go +++ b/middleware.go @@ -17,6 +17,7 @@ import ( "github.com/acronis/go-appkit/restapi" "github.com/acronis/go-authkit/idptoken" + "github.com/acronis/go-authkit/internal/idputil" "github.com/acronis/go-authkit/jwt" ) @@ -68,11 +69,13 @@ type jwtAuthHandler struct { jwtParser JWTParser verifyAccess func(r *http.Request, claims *jwt.Claims) bool tokenIntrospector TokenIntrospector + loggerProvider func(ctx context.Context) log.FieldLogger } type jwtAuthMiddlewareOpts struct { verifyAccess func(r *http.Request, claims *jwt.Claims) bool tokenIntrospector TokenIntrospector + loggerProvider func(ctx context.Context) log.FieldLogger } // JWTAuthMiddlewareOption is an option for JWTAuthMiddleware. @@ -92,6 +95,13 @@ func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) } } +// WithJWTAuthMiddlewareLoggerProvider is an option to set a logger provider for JWTAuthMiddleware. +func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context) log.FieldLogger) JWTAuthMiddlewareOption { + return func(options *jwtAuthMiddlewareOpts) { + options.loggerProvider = loggerProvider + } +} + // JWTAuthMiddleware is a middleware that does authentication // by Access Token from the "Authorization" HTTP header of incoming request. // errorDomain is used for error responses. It is usually the name of the service that uses the middleware, @@ -101,23 +111,29 @@ func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) // // {"error": {"domain": "MyService", "code": "bearerTokenMissing", "message": "Authorization bearer token is missing."}} func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthMiddlewareOption) func(next http.Handler) http.Handler { - var options jwtAuthMiddlewareOpts + options := jwtAuthMiddlewareOpts{loggerProvider: middleware.GetLoggerFromContext} for _, opt := range opts { opt(&options) } return func(next http.Handler) http.Handler { - return &jwtAuthHandler{next, errorDomain, jwtParser, options.verifyAccess, options.tokenIntrospector} + return &jwtAuthHandler{ + next: next, + errorDomain: errorDomain, + jwtParser: jwtParser, + verifyAccess: options.verifyAccess, + tokenIntrospector: options.tokenIntrospector, + loggerProvider: options.loggerProvider, + } } } func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { reqCtx := r.Context() - logger := middleware.GetLoggerFromContext(reqCtx) bearerToken := GetBearerTokenFromRequest(r) if bearerToken == "" { apiErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) return } @@ -127,37 +143,40 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { switch { case errors.Is(err, idptoken.ErrTokenIntrospectionNotNeeded): // Do nothing. Access Token already contains all necessary information for authN/authZ. + h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logFunc("token's introspection is not needed") + }) case errors.Is(err, idptoken.ErrTokenNotIntrospectable): // Token is not introspectable by some reason. // In this case, we will parse it as JWT and use it for authZ. - if logger != nil { - logger.Warn("token is not introspectable, it will be used for authentication and authorization as is", - log.Error(err)) - } + h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is", + log.Error(err)) default: - if logger != nil { - logger.Error("token introspection failed", log.Error(err)) - } + logger := h.logger(reqCtx) + logger.Error("token's introspection failed", log.Error(err)) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return } } else { if !introspectionResult.Active { + h.logger(reqCtx).Warn("token was successfully introspected, but it is not active") apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) - restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) + restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx)) return } jwtClaims = &introspectionResult.Claims + h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) { + logFunc("token was successfully introspected") + }) } } if jwtClaims == nil { var err error if jwtClaims, err = h.jwtParser.Parse(reqCtx, bearerToken); err != nil { - if logger != nil { - logger.Error("authentication failed", log.Error(err)) - } + logger := h.logger(reqCtx) + logger.Error("authentication failed", log.Error(err)) apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed) restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger) return @@ -167,7 +186,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if h.verifyAccess != nil { if !h.verifyAccess(r, jwtClaims) { apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed) - restapi.RespondError(rw, http.StatusForbidden, apiErr, logger) + restapi.RespondError(rw, http.StatusForbidden, apiErr, h.logger(reqCtx)) return } } @@ -177,6 +196,10 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.next.ServeHTTP(rw, r.WithContext(reqCtx)) } +func (h *jwtAuthHandler) logger(ctx context.Context) log.FieldLogger { + return idputil.GetLoggerFromProvider(ctx, h.loggerProvider) +} + // GetBearerTokenFromRequest extracts jwt token from request headers. func GetBearerTokenFromRequest(r *http.Request) string { authHeader := strings.TrimSpace(r.Header.Get(HeaderAuthorization))