diff --git a/examples/simplehttp/config.go b/examples/simplehttp/config.go index 5e9d920..2c51088 100644 --- a/examples/simplehttp/config.go +++ b/examples/simplehttp/config.go @@ -1,6 +1,10 @@ package main -import "github.com/XDoubleU/essentia/pkg/config" +import ( + "log/slog" + + "github.com/XDoubleU/essentia/pkg/config" +) type Config struct { Env string @@ -9,13 +13,15 @@ type Config struct { AllowedOrigins []string } -func NewConfig() Config { +func NewConfig(logger *slog.Logger) Config { + c := config.New(logger) + var cfg Config - cfg.Env = config.EnvStr("ENV", config.ProdEnv) - cfg.Port = config.EnvInt("PORT", 8000) - cfg.DBDsn = config.EnvStr("DB_DSN", "postgres://postgres@localhost/postgres") - cfg.AllowedOrigins = config.EnvStrArray( + cfg.Env = c.EnvStr("ENV", config.ProdEnv) + cfg.Port = c.EnvInt("PORT", 8000) + cfg.DBDsn = c.EnvStr("DB_DSN", "postgres://postgres@localhost/postgres") + cfg.AllowedOrigins = c.EnvStrArray( "ALLOWED_ORIGINS", []string{"http://localhost"}, ) diff --git a/examples/simplehttp/main.go b/examples/simplehttp/main.go index 45851fa..b0382ed 100644 --- a/examples/simplehttp/main.go +++ b/examples/simplehttp/main.go @@ -30,11 +30,12 @@ func NewApp(logger *slog.Logger, config Config, db postgres.DB) application { } func main() { - cfg := NewConfig() + cfg := NewConfig(slog.New(slog.NewTextHandler(os.Stdout, nil))) logger := slog.New( sentrytools.NewLogHandler(cfg.Env, slog.NewTextHandler(os.Stdout, nil)), ) + db, err := postgres.Connect( logger, cfg.DBDsn, diff --git a/examples/simplehttp/main_test.go b/examples/simplehttp/main_test.go index 320a64c..16dce60 100644 --- a/examples/simplehttp/main_test.go +++ b/examples/simplehttp/main_test.go @@ -15,11 +15,11 @@ import ( ) func TestHealth(t *testing.T) { - cfg := NewConfig() - cfg.Env = config.TestEnv - logger := logging.NewNopLogger() + cfg := NewConfig(logger) + cfg.Env = config.TestEnv + db, err := postgres.Connect( logger, cfg.DBDsn, diff --git a/examples/simplewebsocket/config.go b/examples/simplewebsocket/config.go index fa3441e..e62e2f4 100644 --- a/examples/simplewebsocket/config.go +++ b/examples/simplewebsocket/config.go @@ -1,6 +1,10 @@ package main -import "github.com/XDoubleU/essentia/pkg/config" +import ( + "log/slog" + + "github.com/XDoubleU/essentia/pkg/config" +) type Config struct { Env string @@ -8,12 +12,14 @@ type Config struct { AllowedOrigins []string } -func NewConfig() Config { +func NewConfig(logger *slog.Logger) Config { + c := config.New(logger) + var cfg Config - cfg.Env = config.EnvStr("ENV", config.ProdEnv) - cfg.Port = config.EnvInt("PORT", 8000) - cfg.AllowedOrigins = config.EnvStrArray( + cfg.Env = c.EnvStr("ENV", config.ProdEnv) + cfg.Port = c.EnvInt("PORT", 8000) + cfg.AllowedOrigins = c.EnvStrArray( "ALLOWED_ORIGINS", []string{"http://localhost"}, ) diff --git a/examples/simplewebsocket/main.go b/examples/simplewebsocket/main.go index a2fcf1a..7a1fcad 100644 --- a/examples/simplewebsocket/main.go +++ b/examples/simplewebsocket/main.go @@ -25,7 +25,7 @@ func NewApp(logger *slog.Logger, config Config) application { } func main() { - cfg := NewConfig() + cfg := NewConfig(slog.New(slog.NewTextHandler(os.Stdout, nil))) logger := slog.New( sentrytools.NewLogHandler(cfg.Env, slog.NewTextHandler(os.Stdout, nil)), diff --git a/examples/simplewebsocket/main_test.go b/examples/simplewebsocket/main_test.go index 3daa8cd..7167674 100644 --- a/examples/simplewebsocket/main_test.go +++ b/examples/simplewebsocket/main_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/XDoubleU/essentia/pkg/config" + "github.com/XDoubleU/essentia/pkg/logging" sentrytools "github.com/XDoubleU/essentia/pkg/sentry" "github.com/XDoubleU/essentia/pkg/test" "github.com/stretchr/testify/assert" @@ -13,7 +14,7 @@ import ( ) func TestWebSocket(t *testing.T) { - cfg := NewConfig() + cfg := NewConfig(logging.NewNopLogger()) cfg.Env = config.TestEnv logger := slog.New( diff --git a/examples/simplewebsocket/websockethandler.go b/examples/simplewebsocket/websockethandler.go index ad7e478..2c24526 100644 --- a/examples/simplewebsocket/websockethandler.go +++ b/examples/simplewebsocket/websockethandler.go @@ -16,8 +16,9 @@ type ResponseMessageDto struct { Message string `json:"message"` } -func (msg SubscribeMessageDto) Validate() *validate.Validator { - return validate.New() +func (msg SubscribeMessageDto) Validate() (bool, map[string]string) { + v := validate.New() + return v.Valid(), v.Errors() } func (msg SubscribeMessageDto) Topic() string { @@ -34,6 +35,7 @@ func (app *application) websocketRoutes(mux *http.ServeMux) { func (app *application) getWebSocketHandler() http.HandlerFunc { wsHandler := wstools.CreateWebSocketHandler[SubscribeMessageDto]( + app.logger, 1, 10, ) diff --git a/internal/shared/any_to_string.go b/internal/shared/any_to_string.go index 31c9342..72db431 100644 --- a/internal/shared/any_to_string.go +++ b/internal/shared/any_to_string.go @@ -3,7 +3,6 @@ package shared import ( "errors" "fmt" - "strconv" ) func arrayToString[T any](array []T) (string, error) { @@ -27,16 +26,24 @@ func AnyToString(value any) (string, error) { switch value := value.(type) { case string: return value, nil - case int: - return strconv.Itoa(value), nil - case int64: - return strconv.FormatInt(value, 10), nil + case bool: + return fmt.Sprintf("%t", value), nil + case int, int64: + return fmt.Sprintf("%d", value), nil + case float32, float64: + return fmt.Sprintf("%.2f", value), nil case []string: return arrayToString(value) + case []bool: + return arrayToString(value) case []int: return arrayToString(value) case []int64: return arrayToString(value) + case []float32: + return arrayToString(value) + case []float64: + return arrayToString(value) default: return "", errors.New("undefined type") } diff --git a/internal/wsinternal/worker.go b/internal/wsinternal/worker.go index 25e02b3..7a2e913 100644 --- a/internal/wsinternal/worker.go +++ b/internal/wsinternal/worker.go @@ -2,6 +2,7 @@ package wsinternal import ( "context" + "log/slog" "math" "sync" ) @@ -50,7 +51,7 @@ func (worker *Worker) EnqueueEvent(event any) { } // Start makes [Worker] start doing work. -func (worker *Worker) Start(_ context.Context) error { +func (worker *Worker) Start(_ context.Context, _ *slog.Logger) error { // already active if worker.Active() { return nil diff --git a/internal/wsinternal/worker_pool.go b/internal/wsinternal/worker_pool.go index fd28c5a..5a6dca5 100644 --- a/internal/wsinternal/worker_pool.go +++ b/internal/wsinternal/worker_pool.go @@ -3,6 +3,7 @@ package wsinternal import ( "context" "fmt" + "log/slog" "sync" "github.com/XDoubleU/essentia/pkg/sentry" @@ -19,14 +20,20 @@ const stopEvent = "stop" // WorkerPool is used to divide [Subscriber]s between [Worker]s. // This prevents one [Worker] of being very busy. type WorkerPool struct { + logger *slog.Logger subscribers []Subscriber subscribersMu *sync.RWMutex workers []Worker } // NewWorkerPool creates a new [WorkerPool]. -func NewWorkerPool(maxWorkers int, channelBufferSize int) *WorkerPool { +func NewWorkerPool( + logger *slog.Logger, + maxWorkers int, + channelBufferSize int, +) *WorkerPool { pool := &WorkerPool{ + logger: logger, subscribers: []Subscriber{}, subscribersMu: &sync.RWMutex{}, workers: make([]Worker, maxWorkers), @@ -80,8 +87,9 @@ func (pool *WorkerPool) RemoveSubscriber(sub Subscriber) { // Start starts [Worker]s of a [WorkerPool] if they weren't active yet. func (pool *WorkerPool) Start() { for i := range pool.workers { - go sentry.GoRoutineErrorHandler( + go sentry.GoRoutineWrapper( context.Background(), + pool.logger, fmt.Sprintf("Worker %d", i), pool.workers[i].Start, ) diff --git a/internal/wsinternal/worker_pool_test.go b/internal/wsinternal/worker_pool_test.go index e4be892..dcae89a 100644 --- a/internal/wsinternal/worker_pool_test.go +++ b/internal/wsinternal/worker_pool_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/XDoubleU/essentia/internal/wsinternal" + "github.com/XDoubleU/essentia/pkg/logging" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -52,7 +53,9 @@ func (sub *TestSubscriber) Output() string { const sleep = 100 * time.Millisecond func TestBasic(t *testing.T) { - wp := wsinternal.NewWorkerPool(1, 10) + logger := logging.NewNopLogger() + + wp := wsinternal.NewWorkerPool(logger, 1, 10) tSub := NewTestSubscriber() wp.AddSubscriber(tSub) @@ -81,7 +84,9 @@ func TestBasic(t *testing.T) { } func TestMoreWorkersThanSubs(t *testing.T) { - wp := wsinternal.NewWorkerPool(2, 10) + logger := logging.NewNopLogger() + + wp := wsinternal.NewWorkerPool(logger, 2, 10) tSub := NewTestSubscriber() wp.AddSubscriber(tSub) @@ -100,7 +105,9 @@ func TestMoreWorkersThanSubs(t *testing.T) { } func TestAddRemoveSubscriberWhileWorkersActive(t *testing.T) { - wp := wsinternal.NewWorkerPool(2, 10) + logger := logging.NewNopLogger() + + wp := wsinternal.NewWorkerPool(logger, 2, 10) tSub := NewTestSubscriber() wp.AddSubscriber(tSub) @@ -146,7 +153,9 @@ func work(t *testing.T, wp *wsinternal.WorkerPool, nr int) { } func TestToggleWork(t *testing.T) { - wp := wsinternal.NewWorkerPool(1, 10) + logger := logging.NewNopLogger() + + wp := wsinternal.NewWorkerPool(logger, 1, 10) work(t, wp, 1) work(t, wp, 2) diff --git a/pkg/communication/ws/errors_test.go b/pkg/communication/ws/errors_test.go index a43adcd..2f68bbd 100644 --- a/pkg/communication/ws/errors_test.go +++ b/pkg/communication/ws/errors_test.go @@ -9,6 +9,7 @@ import ( wstools "github.com/XDoubleU/essentia/pkg/communication/ws" "github.com/XDoubleU/essentia/pkg/config" errortools "github.com/XDoubleU/essentia/pkg/errors" + "github.com/XDoubleU/essentia/pkg/logging" sentrytools "github.com/XDoubleU/essentia/pkg/sentry" "github.com/XDoubleU/essentia/pkg/test" "github.com/stretchr/testify/assert" @@ -35,7 +36,10 @@ func testErrorStatusCode(t *testing.T, handler http.HandlerFunc) int { func setupWS(t *testing.T, allowedOrigin string) http.Handler { t.Helper() + logger := logging.NewNopLogger() + wsHandler := wstools.CreateWebSocketHandler[TestSubscribeMsg]( + logger, 1, 10, ) diff --git a/pkg/communication/ws/topic.go b/pkg/communication/ws/topic.go index 7312b32..0dc384d 100644 --- a/pkg/communication/ws/topic.go +++ b/pkg/communication/ws/topic.go @@ -2,6 +2,7 @@ package ws import ( "context" + "log/slog" "strings" "github.com/XDoubleU/essentia/internal/wsinternal" @@ -23,6 +24,7 @@ type Topic struct { // NewTopic creates a new [Topic]. func NewTopic( + logger *slog.Logger, name string, allowedOrigins []string, maxWorkers int, @@ -39,6 +41,7 @@ func NewTopic( Name: name, allowedOrigins: allowedOrigins, pool: wsinternal.NewWorkerPool( + logger, maxWorkers, channelBufferSize, ), diff --git a/pkg/communication/ws/websocket.go b/pkg/communication/ws/websocket.go index 82837d2..547e062 100644 --- a/pkg/communication/ws/websocket.go +++ b/pkg/communication/ws/websocket.go @@ -2,6 +2,7 @@ package ws import ( "fmt" + "log/slog" "net/http" "net/url" "path/filepath" @@ -22,6 +23,7 @@ type SubscribeMessageDto interface { // A WebSocketHandler handles incoming requests to a // websocket and makes sure subscriptions are made to the right topics. type WebSocketHandler[T SubscribeMessageDto] struct { + logger *slog.Logger maxTopicWorkers int topicChannelBufferSize int topicMap map[string]*Topic @@ -29,10 +31,12 @@ type WebSocketHandler[T SubscribeMessageDto] struct { // CreateWebSocketHandler creates a new [WebSocketHandler]. func CreateWebSocketHandler[T SubscribeMessageDto]( + logger *slog.Logger, maxTopicWorkers int, topicChannelBufferSize int, ) WebSocketHandler[T] { return WebSocketHandler[T]{ + logger: logger, maxTopicWorkers: maxTopicWorkers, topicChannelBufferSize: topicChannelBufferSize, topicMap: make(map[string]*Topic), @@ -53,6 +57,7 @@ func (h *WebSocketHandler[T]) AddTopic( } topic := NewTopic( + h.logger, topicName, allowedOrigins, h.maxTopicWorkers, @@ -119,8 +124,8 @@ func (h WebSocketHandler[T]) Handler() http.HandlerFunc { return } - if v := msg.Validate(); !v.Valid() { - FailedValidationResponse(r.Context(), conn, v.Errors) + if valid, errors := msg.Validate(); !valid { + FailedValidationResponse(r.Context(), conn, errors) return } diff --git a/pkg/communication/ws/websocket_test.go b/pkg/communication/ws/websocket_test.go index 48486bb..cd49579 100644 --- a/pkg/communication/ws/websocket_test.go +++ b/pkg/communication/ws/websocket_test.go @@ -7,6 +7,7 @@ import ( wstools "github.com/XDoubleU/essentia/pkg/communication/ws" errortools "github.com/XDoubleU/essentia/pkg/errors" + "github.com/XDoubleU/essentia/pkg/logging" "github.com/XDoubleU/essentia/pkg/test" "github.com/XDoubleU/essentia/pkg/validate" "github.com/stretchr/testify/assert" @@ -21,12 +22,12 @@ type TestSubscribeMsg struct { TopicName string `json:"topicName"` } -func (s TestSubscribeMsg) Validate() *validate.Validator { +func (s TestSubscribeMsg) Validate() (bool, map[string]string) { v := validate.New() - validate.Check(v, s.TopicName, validate.IsNotEmpty, "topicName") + validate.Check(v, "topicName", s.TopicName, validate.IsNotEmpty) - return v + return v.Valid(), v.Errors() } func (s TestSubscribeMsg) Topic() string { @@ -36,7 +37,10 @@ func (s TestSubscribeMsg) Topic() string { func setup(t *testing.T) http.Handler { t.Helper() + logger := logging.NewNopLogger() + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg]( + logger, 1, 10, ) @@ -81,7 +85,9 @@ func TestWebSocketUnknownTopic(t *testing.T) { } func TestWebSocketExistingHandler(t *testing.T) { - ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](1, 10) + logger := logging.NewNopLogger() + + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](logger, 1, 10) topic, err := ws.AddTopic( "exists", []string{"http://localhost"}, @@ -100,7 +106,9 @@ func TestWebSocketExistingHandler(t *testing.T) { } func TestWebsocketBasic(t *testing.T) { - ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](1, 10) + logger := logging.NewNopLogger() + + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](logger, 1, 10) topic, err := ws.AddTopic( "exists", []string{"http://localhost"}, @@ -139,7 +147,9 @@ func TestWebsocketBasic(t *testing.T) { } func TestWebSocketUpdateExistingTopic(t *testing.T) { - ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](1, 10) + logger := logging.NewNopLogger() + + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](logger, 1, 10) topic, err := ws.AddTopic( "exists", []string{"http://localhost"}, @@ -156,7 +166,9 @@ func TestWebSocketUpdateExistingTopic(t *testing.T) { } func TestWebSocketUpdateNonExistingTopic(t *testing.T) { - ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](1, 10) + logger := logging.NewNopLogger() + + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](logger, 1, 10) topic, err := ws.UpdateTopicName(&wstools.Topic{ Name: "unknown", }, "exists") @@ -165,7 +177,9 @@ func TestWebSocketUpdateNonExistingTopic(t *testing.T) { } func TestWebSocketRemoveNonExistingTopic(t *testing.T) { - ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](1, 10) + logger := logging.NewNopLogger() + + ws := wstools.CreateWebSocketHandler[TestSubscribeMsg](logger, 1, 10) err := ws.RemoveTopic(&wstools.Topic{ Name: "unknown", }) diff --git a/pkg/config/main.go b/pkg/config/main.go index b0057da..6b7aaa0 100644 --- a/pkg/config/main.go +++ b/pkg/config/main.go @@ -4,14 +4,19 @@ package config import ( "fmt" + "log/slog" "os" "strconv" "strings" + "github.com/XDoubleU/essentia/internal/shared" "github.com/joho/godotenv" ) -var dotEnvLoaded = false //nolint:gochecknoglobals //need this for tracking state +// Parser parses the config provided through environment variables. +type Parser struct { + logger *slog.Logger +} const ( // ProdEnv can be used as value when reading out the type of environment. @@ -24,73 +29,108 @@ const ( const errorMessage = "can't convert env var '%s' with value '%s' to %s" -// EnvStr extracts a string environment variable. -func EnvStr(key string, defaultValue string) string { - if !dotEnvLoaded { - _ = godotenv.Load() - dotEnvLoaded = true +// New returns a new Parser and loads environment variables that +// could be provided using a .env file (particularly useful during development). +func New(logger *slog.Logger) Parser { + _ = godotenv.Load() + + return Parser{ + logger: logger, } +} +func (c Parser) baseEnv(key string) string { value, exists := os.LookupEnv(key) if !exists { - return defaultValue + return "" } return value } -// EnvStrArray extracts a string -// array environment variable. The values should be separated by ','. -func EnvStrArray(key string, defaultValue []string) []string { - strVal := EnvStr(key, "") - if len(strVal) == 0 { - return defaultValue +func (c Parser) logValue(valType string, key string, value any) { + strVal, err := shared.AnyToString(value) + if err != nil { + panic(err) } - return strings.Split(strVal, ",") + c.logger.Info( + fmt.Sprintf("loaded env var '%s'='%s' with type '%s'", key, strVal, valType), + ) } -// EnvInt extracts an integer environment variable. -func EnvInt(key string, defaultValue int) int { - strVal := EnvStr(key, "") - if len(strVal) == 0 { - return defaultValue +// EnvStr extracts a string environment variable. +func (c Parser) EnvStr(key string, defaultValue string) string { + value := c.baseEnv(key) + if len(value) == 0 { + value = defaultValue } - intVal, err := strconv.Atoi(strVal) - if err != nil { - panic(fmt.Sprintf(errorMessage, key, strVal, "int")) + c.logValue("string", key, value) + return value +} + +// EnvStrArray extracts a string +// array environment variable. The values should be separated by ','. +func (c Parser) EnvStrArray(key string, defaultValue []string) []string { + value := defaultValue + + strVal := c.baseEnv(key) + if len(strVal) != 0 { + value = strings.Split(strVal, ",") } - return intVal + c.logValue("string array", key, value) + return value } -// EnvFloat extracts a float environment variable. -func EnvFloat(key string, defaultValue float64) float64 { - strVal := EnvStr(key, "") - if len(strVal) == 0 { - return defaultValue +// EnvInt extracts an integer environment variable. +func (c Parser) EnvInt(key string, defaultValue int) int { + value := defaultValue + + strVal := c.baseEnv(key) + if len(strVal) != 0 { + intVal, err := strconv.Atoi(strVal) + if err != nil { + panic(fmt.Sprintf(errorMessage, key, strVal, "int")) + } + value = intVal } - floatVal, err := strconv.ParseFloat(strVal, 64) - if err != nil { - panic(fmt.Sprintf(errorMessage, key, strVal, "float64")) + c.logValue("int", key, value) + return value +} + +// EnvFloat extracts a float environment variable. +func (c Parser) EnvFloat(key string, defaultValue float64) float64 { + value := defaultValue + + strVal := c.baseEnv(key) + if len(strVal) != 0 { + floatVal, err := strconv.ParseFloat(strVal, 64) + if err != nil { + panic(fmt.Sprintf(errorMessage, key, strVal, "float64")) + } + value = floatVal } - return floatVal + c.logValue("float64", key, value) + return value } // EnvBool extracts a boolean environment variable. -func EnvBool(key string, defaultValue bool) bool { - strVal := EnvStr(key, "") - if len(strVal) == 0 { - return defaultValue +func (c Parser) EnvBool(key string, defaultValue bool) bool { + value := defaultValue + + strVal := c.baseEnv(key) + if len(strVal) != 0 { + boolVal, err := strconv.ParseBool(strVal) + if err != nil { + panic(fmt.Sprintf(errorMessage, key, strVal, "bool")) + } + value = boolVal } - boolVal, err := strconv.ParseBool(strVal) - if err != nil { - panic(fmt.Sprintf(errorMessage, key, strVal, "bool")) - } - - return boolVal + c.logValue("bool", key, value) + return value } diff --git a/pkg/config/main_test.go b/pkg/config/main_test.go index 2ef2776..a894ea7 100644 --- a/pkg/config/main_test.go +++ b/pkg/config/main_test.go @@ -1,10 +1,13 @@ package config_test import ( + "bytes" + "log/slog" "strconv" "testing" "github.com/XDoubleU/essentia/pkg/config" + "github.com/XDoubleU/essentia/pkg/logging" "github.com/stretchr/testify/assert" ) @@ -15,11 +18,25 @@ func TestEnvStr(t *testing.T) { t.Setenv(existingKey, expected) - exists := config.EnvStr(existingKey, def) - notExists := config.EnvStr(nonExistingKey, def) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + exists := c.EnvStr(existingKey, def) + notExists := c.EnvStr(nonExistingKey, def) assert.Equal(t, exists, expected) assert.Equal(t, notExists, def) + + assert.Contains(t, buf.String(), "loaded env var 'key'='string' with type 'string'") + assert.Contains(t, buf.String(), "loaded env var 'non_key'='' with type 'string'") } func TestEnvStrArray(t *testing.T) { @@ -28,11 +45,33 @@ func TestEnvStrArray(t *testing.T) { t.Setenv(existingKey, rawExpected) - exists := config.EnvStrArray(existingKey, def) - notExists := config.EnvStrArray(nonExistingKey, def) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + exists := c.EnvStrArray(existingKey, def) + notExists := c.EnvStrArray(nonExistingKey, def) assert.Equal(t, exists, expected) assert.Equal(t, notExists, def) + + assert.Contains( + t, + buf.String(), + "loaded env var 'key'='string1,string2' with type 'string array'", + ) + assert.Contains( + t, + buf.String(), + "loaded env var 'non_key'='' with type 'string array'", + ) } func TestEnvInt(t *testing.T) { @@ -40,11 +79,25 @@ func TestEnvInt(t *testing.T) { t.Setenv(existingKey, strconv.Itoa(expected)) - exists := config.EnvInt(existingKey, def) - notExists := config.EnvInt(nonExistingKey, def) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + exists := c.EnvInt(existingKey, def) + notExists := c.EnvInt(nonExistingKey, def) assert.Equal(t, exists, expected) assert.Equal(t, notExists, def) + + assert.Contains(t, buf.String(), "loaded env var 'key'='14' with type 'int'") + assert.Contains(t, buf.String(), "loaded env var 'non_key'='0' with type 'int'") } func TestEnvIntWrong(t *testing.T) { @@ -52,10 +105,21 @@ func TestEnvIntWrong(t *testing.T) { t.Setenv(existingKey, expected) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + assert.PanicsWithValue( t, "can't convert env var 'key' with value 'string' to int", - func() { config.EnvInt(existingKey, def) }, + func() { c.EnvInt(existingKey, def) }, ) } @@ -64,11 +128,29 @@ func TestEnvFloat(t *testing.T) { t.Setenv(existingKey, strconv.FormatFloat(expected, 'f', -1, 64)) - exists := config.EnvFloat(existingKey, def) - notExists := config.EnvFloat(nonExistingKey, def) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + exists := c.EnvFloat(existingKey, def) + notExists := c.EnvFloat(nonExistingKey, def) assert.Equal(t, exists, expected) assert.Equal(t, notExists, def) + + assert.Contains(t, buf.String(), "loaded env var 'key'='14.00' with type 'float64'") + assert.Contains( + t, + buf.String(), + "loaded env var 'non_key'='0.00' with type 'float64'", + ) } func TestEnvFloatWrong(t *testing.T) { @@ -76,10 +158,21 @@ func TestEnvFloatWrong(t *testing.T) { t.Setenv(existingKey, expected) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + assert.PanicsWithValue( t, "can't convert env var 'key' with value 'string' to float64", - func() { config.EnvFloat(existingKey, def) }, + func() { c.EnvFloat(existingKey, def) }, ) } @@ -88,11 +181,29 @@ func TestEnvBool(t *testing.T) { t.Setenv(existingKey, strconv.FormatBool(expected)) - exists := config.EnvBool(existingKey, def) - notExists := config.EnvBool(nonExistingKey, def) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + exists := c.EnvBool(existingKey, def) + notExists := c.EnvBool(nonExistingKey, def) assert.Equal(t, exists, expected) assert.Equal(t, notExists, def) + + assert.Contains(t, buf.String(), "loaded env var 'key'='true' with type 'bool'") + assert.Contains( + t, + buf.String(), + "loaded env var 'non_key'='false' with type 'bool'", + ) } func TestEnvBoolWrong(t *testing.T) { @@ -100,9 +211,20 @@ func TestEnvBoolWrong(t *testing.T) { t.Setenv(existingKey, expected) + var buf bytes.Buffer + c := config.New( + slog.New( + logging.NewBufLogHandler( + &buf, + //nolint:exhaustruct //other fields are optional + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + assert.PanicsWithValue( t, "can't convert env var 'key' with value 'string' to bool", - func() { config.EnvBool(existingKey, def) }, + func() { c.EnvBool(existingKey, def) }, ) } diff --git a/pkg/database/postgres/interface.go b/pkg/database/postgres/interface.go index 17c07f5..a767c87 100644 --- a/pkg/database/postgres/interface.go +++ b/pkg/database/postgres/interface.go @@ -21,6 +21,8 @@ type DB interface { optionsAndArgs ...any, ) (pgx.Rows, error) QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row + SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults Begin(ctx context.Context) (pgx.Tx, error) + BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) Ping(ctx context.Context) error } diff --git a/pkg/database/postgres/pgx_sync_tx.go b/pkg/database/postgres/pgx_sync_tx.go index 2ad8570..3b181c6 100644 --- a/pkg/database/postgres/pgx_sync_tx.go +++ b/pkg/database/postgres/pgx_sync_tx.go @@ -6,7 +6,6 @@ import ( "github.com/XDoubleU/essentia/pkg/database" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" ) // PgxSyncTx uses [database.SyncTx] to make sure @@ -17,19 +16,14 @@ type PgxSyncTx struct { // PgxSyncRow is a concurrent wrapper for [pgx.Row]. type PgxSyncRow struct { - rows pgx.Rows - err error + syncTx *database.SyncTx[pgx.Tx] + row pgx.Row } // PgxSyncRows is a concurrent wrapper for [pgx.Rows]. type PgxSyncRows struct { - values [][]any - rawValues [][][]byte - err error - fieldDescriptions []pgconn.FieldDescription - commandTag pgconn.CommandTag - conn *pgx.Conn - i int + syncTx *database.SyncTx[pgx.Tx] + rows pgx.Rows } // CreatePgxSyncTx returns a [pgx.Tx] which works concurrently. @@ -62,128 +56,92 @@ func (tx *PgxSyncTx) Query( sql string, args ...any, ) (pgx.Rows, error) { - return database.WrapInSyncTx( + tx.syncTx.Mutex.Lock() + + rows, err := tx.syncTx.Tx.Query(ctx, sql, args...) + if err != nil { + return nil, err + } + + return &PgxSyncRows{ + syncTx: tx.syncTx, + rows: rows, + }, nil +} + +// SendBatch is used to wrap [pgx.Tx.QueryRow] in a [database.SyncTx]. +func (tx *PgxSyncTx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return database.WrapInSyncTxNoError( ctx, tx.syncTx, - func(ctx context.Context) (*PgxSyncRows, error) { - rows, err := tx.syncTx.Tx.Query(ctx, sql, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - var results [][]any - var rawResults [][][]byte - for rows.Next() { - var values []any - values, err = rows.Values() - if err != nil { - break - } - - temp := rows.RawValues() - rawValues := make([][]byte, len(temp)) - copy(rawValues, temp) - - results = append(results, values) - rawResults = append(rawResults, rawValues) - } - - if err == nil { - err = rows.Err() - } - - return &PgxSyncRows{ - values: results, - rawValues: rawResults, - err: err, - fieldDescriptions: rows.FieldDescriptions(), - commandTag: rows.CommandTag(), - conn: rows.Conn(), - i: -1, - }, nil + func(ctx context.Context) pgx.BatchResults { + return tx.syncTx.Tx.SendBatch(ctx, b) }, ) } -// Close doesn't do anything for [PgxSyncRows] as these are closed in [Query]. +// Close closes the opened [pgx.Rows]. func (rows *PgxSyncRows) Close() { + rows.syncTx.Unlock() + rows.rows.Close() } // CommandTag fetches the [pgconn.CommandTag]. func (rows *PgxSyncRows) CommandTag() pgconn.CommandTag { - return rows.commandTag + return rows.rows.CommandTag() } // Conn fetches the [pgx.Conn]. func (rows *PgxSyncRows) Conn() *pgx.Conn { - return rows.conn + return rows.rows.Conn() } // Err fetches any errors. func (rows *PgxSyncRows) Err() error { - return rows.err + return rows.rows.Err() } // FieldDescriptions fetches [pgconn.FieldDescription]s. func (rows *PgxSyncRows) FieldDescriptions() []pgconn.FieldDescription { - return rows.fieldDescriptions + return rows.rows.FieldDescriptions() } // Next continues to the next row of [PgxSyncRows] if there is one. func (rows *PgxSyncRows) Next() bool { - rows.i++ - return rows.i < len(rows.values) + return rows.rows.Next() } // RawValues fetches the raw values of the current row. func (rows *PgxSyncRows) RawValues() [][]byte { - return rows.rawValues[rows.i] + return rows.rows.RawValues() } // Scan scans the data of the current row into dest. func (rows *PgxSyncRows) Scan(dest ...any) error { - if err := rows.Err(); err != nil { - return err - } - - return pgx.ScanRow( - pgtype.NewMap(), - rows.FieldDescriptions(), - rows.RawValues(), - dest...) + return rows.rows.Scan(dest...) } // Values fetches the values of the current row. func (rows *PgxSyncRows) Values() ([]any, error) { - return rows.values[rows.i], nil + return rows.rows.Values() } // QueryRow is used to wrap [pgx.Tx.QueryRow] in a [database.SyncTx]. func (tx *PgxSyncTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { - rows, err := tx.Query(ctx, sql, args...) + tx.syncTx.Mutex.Lock() + + row := tx.syncTx.Tx.QueryRow(ctx, sql, args...) return &PgxSyncRow{ - rows: rows, - err: err, + syncTx: tx.syncTx, + row: row, } } // Scan scans the data of [PgxSyncRow] into dest. func (row *PgxSyncRow) Scan(dest ...any) error { - if row.err != nil { - return row.err - } - - if err := row.rows.Err(); err != nil { - return err - } - - if !row.rows.Next() { - return pgx.ErrNoRows - } - - return row.rows.Scan(dest...) + defer row.syncTx.Unlock() + return row.row.Scan(dest...) } // Ping is used to wrap [pgx.Tx.Conn.Ping] in a [database.SyncTx]. @@ -208,6 +166,31 @@ func (tx *PgxSyncTx) Begin(ctx context.Context) (pgx.Tx, error) { ) } +// BeginTx is used to wrap [pgx.Tx.BeginTx] in a [database.SyncTx]. +func (tx *PgxSyncTx) BeginTx( + ctx context.Context, + txOptions pgx.TxOptions, +) (pgx.Tx, error) { + return database.WrapInSyncTx( + ctx, + tx.syncTx, + func(ctx context.Context) (pgx.Tx, error) { + return tx.syncTx.Tx.Conn().BeginTx(ctx, txOptions) + }, + ) +} + +// Commit is used to wrap [pgx.Tx.Commit] in a [database.SyncTx]. +func (tx *PgxSyncTx) Commit(ctx context.Context) error { + return database.WrapInSyncTxNoError( + ctx, + tx.syncTx, + func(ctx context.Context) error { + return tx.syncTx.Tx.Commit(ctx) + }, + ) +} + // Rollback is used to wrap [pgx.Tx.Rollback] in a [database.SyncTx]. func (tx *PgxSyncTx) Rollback(ctx context.Context) error { return database.WrapInSyncTxNoError( diff --git a/pkg/database/postgres/pgx_sync_tx_test.go b/pkg/database/postgres/pgx_sync_tx_test.go index f397dec..3ebae26 100644 --- a/pkg/database/postgres/pgx_sync_tx_test.go +++ b/pkg/database/postgres/pgx_sync_tx_test.go @@ -2,6 +2,7 @@ package postgres_test import ( "context" + "sync" "testing" "time" @@ -32,6 +33,55 @@ func setup(t *testing.T) *postgres.PgxSyncTx { return postgres.CreatePgxSyncTx(context.Background(), db) } +func TestParallel(t *testing.T) { + tx := setup(t) + defer func() { err := tx.Rollback(context.Background()); assert.Nil(t, err) }() + + db := postgres.NewSpanDB(tx) + + _, err := db.Exec( + context.Background(), + "CREATE TABLE kv (key VARCHAR(255), value VARCHAR(255));", + ) + require.Nil(t, err) + + _, err = db.Exec( + context.Background(), + "INSERT INTO kv (key, value) VALUES ('key1', 'value1');", + ) + require.Nil(t, err) + + mu1 := sync.Mutex{} + mu2 := sync.Mutex{} + + mu1.Lock() + mu2.Lock() + + go queryRow(t, db, &mu1) + go queryRow(t, db, &mu2) + + mu1.Lock() + mu2.Lock() + + assert.True(t, true) +} + +func queryRow(t *testing.T, db postgres.DB, mu *sync.Mutex) { + for i := 0; i < 100; i++ { + var key string + var value string + err := db.QueryRow( + context.Background(), + "SELECT key, value FROM kv WHERE key = 'key1';", + ).Scan(&key, &value) + assert.Nil(t, err) + + time.Sleep(10 * time.Millisecond) + } + + mu.Unlock() +} + func TestPing(t *testing.T) { tx := setup(t) defer func() { err := tx.Rollback(context.Background()); assert.Nil(t, err) }() @@ -74,6 +124,7 @@ func TestQuery(t *testing.T) { rows, err := db.Query(context.Background(), "SELECT key, value FROM kv;") require.Nil(t, err) + defer rows.Close() results := make([][]string, 2) results[0] = make([]string, 2) diff --git a/pkg/database/postgres/spandb.go b/pkg/database/postgres/spandb.go index 4f56419..81ddf8a 100644 --- a/pkg/database/postgres/spandb.go +++ b/pkg/database/postgres/spandb.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "fmt" "github.com/XDoubleU/essentia/pkg/database" "github.com/jackc/pgx/v5" @@ -54,12 +55,37 @@ func (db *SpanDB) QueryRow( optionsAndArgs...) } +// SendBatch is used to wrap SendBatch in a [sentry.Span]. +func (db *SpanDB) SendBatch( + ctx context.Context, + b *pgx.Batch, +) pgx.BatchResults { + sql := "" + for i, query := range b.QueuedQueries { + sql += fmt.Sprintf("query %d: %s\n", i, query.SQL) + } + + span := database.StartSpan(ctx, db.dbName, sql) + defer span.Finish() + + return db.DB.SendBatch(ctx, b) +} + // Begin doesn't wrap Begin in a [sentry.Span] as // this makes little sense for starting a transaction. func (db *SpanDB) Begin(ctx context.Context) (pgx.Tx, error) { return db.DB.Begin(ctx) } +// BeginTx doesn't wrap BeginTx in a [sentry.Span] as +// this makes little sense for starting a transaction. +func (db *SpanDB) BeginTx( + ctx context.Context, + txOptions pgx.TxOptions, +) (pgx.Tx, error) { + return db.DB.BeginTx(ctx, txOptions) +} + // Ping doesn't wrap Ping in a [sentry.Span] as // this makes little sense for pinging the db. func (db *SpanDB) Ping(ctx context.Context) error { diff --git a/pkg/database/span.go b/pkg/database/span.go index eafec35..8a4288d 100644 --- a/pkg/database/span.go +++ b/pkg/database/span.go @@ -6,7 +6,8 @@ import ( "github.com/getsentry/sentry-go" ) -func startSpan(ctx context.Context, dbName string, sql string) *sentry.Span { +// StartSpan is used to start a [sentry.Span]. +func StartSpan(ctx context.Context, dbName string, sql string) *sentry.Span { span := sentry.StartSpan(ctx, "db.query", sentry.WithDescription(sql)) span.SetData("db.system", dbName) @@ -20,7 +21,7 @@ func WrapWithSpan[T any]( dbName string, queryFunc func(ctx context.Context, sql string, args ...any) (T, error), sql string, args ...any) (T, error) { - span := startSpan(ctx, dbName, sql) + span := StartSpan(ctx, dbName, sql) defer span.Finish() return queryFunc(ctx, sql, args...) @@ -34,7 +35,7 @@ func WrapWithSpanNoError[T any]( dbName string, queryFunc func(ctx context.Context, sql string, args ...any) T, sql string, args ...any) T { - span := startSpan(ctx, dbName, sql) + span := StartSpan(ctx, dbName, sql) defer span.Finish() return queryFunc(ctx, sql, args...) diff --git a/pkg/logging/main.go b/pkg/logging/main.go index fe47c06..ff653e1 100644 --- a/pkg/logging/main.go +++ b/pkg/logging/main.go @@ -2,6 +2,7 @@ package logging import ( + "bytes" "io" "log/slog" ) @@ -16,3 +17,9 @@ func ErrAttr(err error) slog.Attr { func NewNopLogger() *slog.Logger { return slog.New(slog.NewTextHandler(io.Discard, nil)) } + +// NewBufLogHandler provides a [slog.TextHandler] +// which logs to the provided [bytes.Buffer]. +func NewBufLogHandler(buf *bytes.Buffer, opts *slog.HandlerOptions) *slog.TextHandler { + return slog.NewTextHandler(buf, opts) +} diff --git a/pkg/parse/main_test.go b/pkg/parse/main_test.go index 9f475f8..2d14223 100644 --- a/pkg/parse/main_test.go +++ b/pkg/parse/main_test.go @@ -123,7 +123,7 @@ func TestArrayQueryParamFailedParseFunc(t *testing.T) { req, "queryParam", []int{1, 2}, - parse.IntFunc(false, true), + parse.Int(false, true), ) assert.Equal(t, []int{}, result) diff --git a/pkg/parse/parser_funcs.go b/pkg/parse/parser_funcs.go index 42f7408..f6c3538 100644 --- a/pkg/parse/parser_funcs.go +++ b/pkg/parse/parser_funcs.go @@ -12,6 +12,12 @@ import ( // ParserFunc is the expected format used for parsing data using any parsing function. type ParserFunc[T any] func(paramType string, paramName string, value string) (T, error) +// String is used to parse a parameter as string value. +// As all parameters are string by default this returns the original value. +func String(_ string, _ string, value string) (string, error) { + return value, nil +} + // UUID is used to parse a parameter as UUID value. // Technically this only validates if a string is a UUID. func UUID(paramType string, paramName string, value string) (string, error) { @@ -28,15 +34,15 @@ func UUID(paramType string, paramName string, value string) (string, error) { return uuidVal.String(), nil } -// IntFunc parses a parameter as [int]. -func IntFunc(isPositive bool, isZero bool) ParserFunc[int] { +// Int parses a parameter as [int]. +func Int(isPositive bool, isZero bool) ParserFunc[int] { return func(paramType string, paramName string, value string) (int, error) { return parseInt[int](isPositive, isZero, paramType, paramName, value, 0) } } -// Int64Func parses a parameter as [int64]. -func Int64Func(isPositive bool, isZero bool) ParserFunc[int64] { +// Int64 parses a parameter as [int64]. +func Int64(isPositive bool, isZero bool) ParserFunc[int64] { return func(paramType string, paramName string, value string) (int64, error) { //nolint:mnd // no magic number return parseInt[int64](isPositive, isZero, paramType, paramName, value, 64) @@ -83,9 +89,9 @@ func parseInt[T shared.IntType]( return T(result), nil } -// DateFunc parses a parameter as a date. +// Date parses a parameter as a date. // The parameter should match the required date layout. -func DateFunc(layout string) ParserFunc[time.Time] { +func Date(layout string) ParserFunc[time.Time] { return func(paramType string, paramName string, value string) (time.Time, error) { result, err := time.Parse(layout, value) if err != nil { diff --git a/pkg/parse/parser_funcs_test.go b/pkg/parse/parser_funcs_test.go index 1bcd18c..add3ecb 100644 --- a/pkg/parse/parser_funcs_test.go +++ b/pkg/parse/parser_funcs_test.go @@ -12,6 +12,16 @@ import ( "github.com/stretchr/testify/assert" ) +func TestURLParamString(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req.SetPathValue("pathValue", "value") + + result, err := parse.URLParam(req, "pathValue", parse.String) + + assert.Equal(t, "value", result) + assert.Equal(t, nil, err) +} + func TestURLParamUUIDOK(t *testing.T) { val, _ := uuid.NewV7() @@ -44,7 +54,7 @@ func TestURLParamInt64OK(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", fmt.Sprintf("%d", 1)) - result, err := parse.URLParam(req, "pathValue", parse.Int64Func(false, true)) + result, err := parse.URLParam(req, "pathValue", parse.Int64(false, true)) assert.Equal(t, int64(1), result) assert.Equal(t, nil, err) @@ -54,7 +64,7 @@ func TestURLParamIntOK(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", fmt.Sprintf("%d", 1)) - result, err := parse.URLParam(req, "pathValue", parse.IntFunc(false, true)) + result, err := parse.URLParam(req, "pathValue", parse.Int(false, true)) assert.Equal(t, 1, result) assert.Equal(t, nil, err) @@ -64,7 +74,7 @@ func TestURLParamIntNOK(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", "notint") - result, err := parse.URLParam(req, "pathValue", parse.IntFunc(false, true)) + result, err := parse.URLParam(req, "pathValue", parse.Int(false, true)) assert.Equal(t, 0, result) assert.Equal( @@ -80,7 +90,7 @@ func TestURLParamIntLTZero(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", fmt.Sprintf("%d", -1)) - result, err := parse.URLParam(req, "pathValue", parse.IntFunc(true, true)) + result, err := parse.URLParam(req, "pathValue", parse.Int(true, true)) assert.Equal(t, 0, result) assert.Equal( @@ -96,7 +106,7 @@ func TestURLParamIntZero(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", fmt.Sprintf("%d", 0)) - result, err := parse.URLParam(req, "pathValue", parse.IntFunc(true, false)) + result, err := parse.URLParam(req, "pathValue", parse.Int(true, false)) assert.Equal(t, 0, result) assert.Equal( @@ -112,7 +122,7 @@ func TestURLParamDateOK(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", datetime) - result, err := parse.URLParam(req, "pathValue", parse.DateFunc("2006-01-02")) + result, err := parse.URLParam(req, "pathValue", parse.Date("2006-01-02")) expected, _ := time.Parse("2006-01-02", datetime) assert.Equal(t, expected, result) @@ -125,7 +135,7 @@ func TestURLParamDateNOK(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.SetPathValue("pathValue", datetime) - result, err := parse.URLParam(req, "pathValue", parse.DateFunc("2006-01-02")) + result, err := parse.URLParam(req, "pathValue", parse.Date("2006-01-02")) expected, _ := time.Parse("2006-01-02", datetime) assert.Equal(t, expected, result) diff --git a/pkg/sentry/goroutine.go b/pkg/sentry/goroutine.go index 141506b..2612604 100644 --- a/pkg/sentry/goroutine.go +++ b/pkg/sentry/goroutine.go @@ -3,17 +3,19 @@ package sentry import ( "context" "fmt" + "log/slog" "net/http" "github.com/getsentry/sentry-go" ) -// GoRoutineErrorHandler makes sure a -// go routine and its errors are captured by Sentry. -func GoRoutineErrorHandler( +// GoRoutineWrapper wraps a go routine with +// Sentry logic for error and performance tracking. +func GoRoutineWrapper( ctx context.Context, + logger *slog.Logger, name string, - f func(ctx context.Context) error, + f func(ctx context.Context, logger *slog.Logger) error, ) { name = fmt.Sprintf("GO ROUTINE %s", name) @@ -28,7 +30,7 @@ func GoRoutineErrorHandler( transaction.Status = sentry.HTTPtoSpanStatus(http.StatusOK) defer transaction.Finish() - err := f(transaction.Context()) + err := f(transaction.Context(), logger) if err != nil { transaction.Status = sentry.HTTPtoSpanStatus(http.StatusInternalServerError) diff --git a/pkg/sentry/goroutine_test.go b/pkg/sentry/goroutine_test.go index b713bc1..df33af6 100644 --- a/pkg/sentry/goroutine_test.go +++ b/pkg/sentry/goroutine_test.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "log/slog" "sync" "testing" + "github.com/XDoubleU/essentia/pkg/logging" sentrytools "github.com/XDoubleU/essentia/pkg/sentry" "github.com/getsentry/sentry-go" "github.com/stretchr/testify/assert" @@ -15,9 +17,13 @@ import ( func TestSentryErrorHandler(t *testing.T) { name := "test" - testFunc := func(ctx context.Context) error { + logger := logging.NewNopLogger() + + testFunc := func(ctx context.Context, logger *slog.Logger) error { transaction := sentry.TransactionFromContext(ctx) + logger.Debug("started execution") + assert.Equal(t, fmt.Sprintf("GO ROUTINE %s", name), transaction.Name) assert.Equal(t, "go.routine", transaction.Op) @@ -28,8 +34,9 @@ func TestSentryErrorHandler(t *testing.T) { wg.Add(1) go func() { - sentrytools.GoRoutineErrorHandler( + sentrytools.GoRoutineWrapper( context.Background(), + logger, name, testFunc, ) diff --git a/pkg/sentry/loghandler.go b/pkg/sentry/loghandler.go index 940e94a..1b799b5 100644 --- a/pkg/sentry/loghandler.go +++ b/pkg/sentry/loghandler.go @@ -14,8 +14,12 @@ import ( type LogHandler struct { level slog.Level handler slog.Handler - attrs []slog.Attr - groups []string + goas []groupOrAttrs +} + +type groupOrAttrs struct { + group string // group name if non-empty + attrs []slog.Attr // attrs if non-empty } // NewLogHandler returns a new [SentryLogHandler]. @@ -28,9 +32,8 @@ func NewLogHandler(env string, handler slog.Handler) slog.Handler { return &LogHandler{ handler: handler, - attrs: []slog.Attr{}, - groups: []string{}, level: level, + goas: []groupOrAttrs{}, } } @@ -42,34 +45,65 @@ func (l *LogHandler) Enabled(_ context.Context, level slog.Level) bool { // WithAttrs adds [[]slog.Attr] to a [SentryLogHandler]. func (l *LogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { - //nolint:exhaustruct //other fields are optional - return &LogHandler{ - attrs: append(l.attrs, attrs...), - groups: l.groups, + if len(attrs) == 0 { + return l } + + l.handler = l.handler.WithAttrs(attrs) + return l.withGroupOrAttrs(groupOrAttrs{group: "", attrs: attrs}) } // WithGroup adds a group to a [SentryLogHandler]. func (l *LogHandler) WithGroup(name string) slog.Handler { - //nolint:exhaustruct //other fields are optional - return &LogHandler{ - attrs: l.attrs, - groups: append(l.groups, name), + if name == "" { + return l } + + l.handler = l.handler.WithGroup(name) + return l.withGroupOrAttrs(groupOrAttrs{group: name, attrs: []slog.Attr{}}) +} + +func (l *LogHandler) withGroupOrAttrs(goa groupOrAttrs) slog.Handler { + l2 := *l + l2.goas = make([]groupOrAttrs, len(l.goas)+1) + copy(l2.goas, l.goas) + l2.goas[len(l2.goas)-1] = goa + return &l2 } // Handle handles a [slog.Record] by a [SentryLogHandler]. func (l *LogHandler) Handle(ctx context.Context, record slog.Record) error { if record.Level == slog.LevelError { - sendErrorToSentry(ctx, recordToError(record)) + l.sendErrorToSentry(ctx, recordToError(record)) } return l.handler.Handle(ctx, record) } -func sendErrorToSentry(ctx context.Context, err error) { +func (l *LogHandler) sendErrorToSentry(ctx context.Context, err error) { if hub := sentry.GetHubFromContext(ctx); hub != nil { hub.WithScope(func(scope *sentry.Scope) { + prefix := "" + + for _, goa := range l.goas { + temporaryPrefix := prefix + if goa.group != "" { + temporaryPrefix = fmt.Sprintf("%s.", goa.group) + } + + if len(goa.attrs) == 0 { + prefix = temporaryPrefix + continue + } + + for _, attr := range goa.attrs { + scope.SetTag( + fmt.Sprintf("%s%s", temporaryPrefix, attr.Key), + attr.Value.String(), + ) + } + } + scope.SetLevel(sentry.LevelError) hub.CaptureException(err) }) @@ -77,12 +111,5 @@ func sendErrorToSentry(ctx context.Context, err error) { } func recordToError(record slog.Record) error { - err := record.Message - - record.Attrs(func(a slog.Attr) bool { - err += fmt.Sprintf(" %s=%s", a.Key, a.Value) - return true - }) - - return errors.New(err) + return errors.New(record.Message) } diff --git a/pkg/sentry/loghandler_test.go b/pkg/sentry/loghandler_test.go index a7ff652..5b3bee3 100644 --- a/pkg/sentry/loghandler_test.go +++ b/pkg/sentry/loghandler_test.go @@ -19,7 +19,10 @@ func TestLogHandlerDev(t *testing.T) { sentry.NewLogHandler( config.DevEnv, //nolint:exhaustruct //other fields are optional - slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug}), + logging.NewBufLogHandler( + &buf, + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), ), ) @@ -27,3 +30,25 @@ func TestLogHandlerDev(t *testing.T) { assert.Contains(t, buf.String(), "level=ERROR msg=test error=testerror") } + +func TestLogHandlerWith(t *testing.T) { + var buf bytes.Buffer + + logger := slog.New( + sentry.NewLogHandler( + config.DevEnv, + //nolint:exhaustruct //other fields are optional + logging.NewBufLogHandler( + &buf, + &slog.HandlerOptions{Level: slog.LevelDebug}, + ), + ), + ) + + logger = logger.With(slog.String("source", "test")) + + logger.Error("test", logging.ErrAttr(errors.New("testerror"))) + + test := buf.String() + assert.Contains(t, test, "level=ERROR msg=test source=test error=testerror") +} diff --git a/pkg/test/api_helpers_test.go b/pkg/test/api_helpers_test.go index c38cd23..edec29e 100644 --- a/pkg/test/api_helpers_test.go +++ b/pkg/test/api_helpers_test.go @@ -15,7 +15,7 @@ func paginatedEndpointHandler(w http.ResponseWriter, r *http.Request) { pageSize := 2 data := []string{"1", "2", "3"} - page, err := parse.RequiredQueryParam(r, "page", parse.IntFunc(true, false)) + page, err := parse.RequiredQueryParam(r, "page", parse.Int(true, false)) if err != nil { httptools.BadRequestResponse(w, r, errors.NewBadRequestError(err)) return diff --git a/pkg/test/websocket_test.go b/pkg/test/websocket_test.go index 7c7b685..f57d791 100644 --- a/pkg/test/websocket_test.go +++ b/pkg/test/websocket_test.go @@ -7,6 +7,7 @@ import ( "testing" wstools "github.com/XDoubleU/essentia/pkg/communication/ws" + "github.com/XDoubleU/essentia/pkg/logging" "github.com/XDoubleU/essentia/pkg/test" "github.com/XDoubleU/essentia/pkg/validate" "github.com/stretchr/testify/assert" @@ -22,8 +23,9 @@ type TestSubscribeMsg struct { TopicName string `json:"topicName"` } -func (s TestSubscribeMsg) Validate() *validate.Validator { - return validate.New() +func (s TestSubscribeMsg) Validate() (bool, map[string]string) { + v := validate.New() + return v.Valid(), v.Errors() } func (s TestSubscribeMsg) Topic() string { @@ -33,7 +35,9 @@ func (s TestSubscribeMsg) Topic() string { func setup(t *testing.T) (http.Handler, *wstools.Topic) { t.Helper() + logger := logging.NewNopLogger() ws := wstools.CreateWebSocketHandler[TestSubscribeMsg]( + logger, 1, 10, ) diff --git a/pkg/time/main.go b/pkg/time/main.go index 534167f..5b649e6 100644 --- a/pkg/time/main.go +++ b/pkg/time/main.go @@ -29,20 +29,20 @@ func EndOfDay(dateTime time.Time) time.Time { return output } -// NowTimeZoneIndependent returns the time in the +// LocationIndependentTime returns the provided time in the // provided time zone but forces the time zone to UTC. -func NowTimeZoneIndependent(locationTimeZone string) time.Time { +func LocationIndependentTime(t time.Time, locationTimeZone string) time.Time { timeZone, _ := time.LoadLocation(locationTimeZone) utcTimeZone, _ := time.LoadLocation("UTC") - now := time.Now().In(timeZone) + t = t.In(timeZone) return time.Date( - now.Year(), - now.Month(), - now.Day(), - now.Hour(), - now.Minute(), - now.Second(), - now.Nanosecond(), + t.Year(), + t.Month(), + t.Day(), + t.Hour(), + t.Minute(), + t.Second(), + t.Nanosecond(), utcTimeZone, ) } diff --git a/pkg/time/main_test.go b/pkg/time/main_test.go index 2d85e07..d0d944c 100644 --- a/pkg/time/main_test.go +++ b/pkg/time/main_test.go @@ -34,11 +34,11 @@ func TestEndOfDay(t *testing.T) { assert.Equal(t, endOfDay, timetools.EndOfDay(now)) } -func TestNowTimeZoneIndependent(t *testing.T) { +func TestLocationIndependentTime(t *testing.T) { now := time.Now() utcTimeZone, _ := time.LoadLocation("UTC") - result := timetools.NowTimeZoneIndependent(now.Location().String()) + result := timetools.LocationIndependentTime(now, now.Location().String()) expected := time.Date( now.Year(), @@ -47,7 +47,7 @@ func TestNowTimeZoneIndependent(t *testing.T) { now.Hour(), now.Minute(), now.Second(), - result.Nanosecond(), + now.Nanosecond(), utcTimeZone, ) diff --git a/pkg/validate/main.go b/pkg/validate/main.go index a737933..965f3f5 100644 --- a/pkg/validate/main.go +++ b/pkg/validate/main.go @@ -4,35 +4,55 @@ package validate // ValidatedType is implemented by any struct with a Validate method. type ValidatedType interface { - Validate() *Validator + Validate() (bool, map[string]string) } // Validator is used to validate contents // of structs using [Check]. type Validator struct { - Errors map[string]string + errors map[string]string } // New creates a new [Validator]. func New() *Validator { - return &Validator{Errors: make(map[string]string)} + return &Validator{errors: make(map[string]string)} } // Valid checks if a [Validator] has any errors. func (v *Validator) Valid() bool { - return len(v.Errors) == 0 + return len(v.errors) == 0 +} + +// Errors returns the [Validator] errors. +func (v *Validator) Errors() map[string]string { + return v.errors } func (v *Validator) addError(key, message string) { - if _, exists := v.Errors[key]; !exists { - v.Errors[key] = message + if _, exists := v.errors[key]; !exists { + v.errors[key] = message } } // Check checks if value passes the validatorFunc. // The provided key is used for creating the errors map of the [Validator]. -func Check[T any](v *Validator, value T, validatorFunc ValidatorFunc[T], key string) { +func Check[T any](v *Validator, key string, value T, validatorFunc ValidatorFunc[T]) { if result, message := validatorFunc(value); !result { v.addError(key, message) } } + +// CheckOptional checks if value passes the validatorFunc, when the value is provided. +// The provided key is used for creating the errors map of the [Validator]. +func CheckOptional[T any]( + v *Validator, + key string, + value *T, + validatorFunc ValidatorFunc[T], +) { + if value == nil { + return + } + + Check(v, key, *value, validatorFunc) +} diff --git a/pkg/validate/validator_funcs.go b/pkg/validate/validator_funcs.go index 8d7c719..33a233e 100644 --- a/pkg/validate/validator_funcs.go +++ b/pkg/validate/validator_funcs.go @@ -2,6 +2,7 @@ package validate import ( "fmt" + "slices" "time" "github.com/XDoubleU/essentia/internal/shared" @@ -15,15 +16,15 @@ func IsNotEmpty(value string) (bool, string) { return value != "", "must be provided" } -// IsGreaterThanFunc checks if the provided value2 > value1. -func IsGreaterThanFunc[T shared.IntType](value1 T) ValidatorFunc[T] { +// IsGreaterThan checks if the provided value2 > value1. +func IsGreaterThan[T shared.IntType](value1 T) ValidatorFunc[T] { return func(value2 T) (bool, string) { return value2 > value1, fmt.Sprintf("must be greater than %d", value1) } } -// IsGreaterThanOrEqualFunc checks if the provided value2 >= value1. -func IsGreaterThanOrEqualFunc[T shared.IntType](value1 T) ValidatorFunc[T] { +// IsGreaterThanOrEqual checks if the provided value2 >= value1. +func IsGreaterThanOrEqual[T shared.IntType](value1 T) ValidatorFunc[T] { return func(value2 T) (bool, string) { return value2 >= value1, fmt.Sprintf( @@ -33,15 +34,15 @@ func IsGreaterThanOrEqualFunc[T shared.IntType](value1 T) ValidatorFunc[T] { } } -// IsLesserThanFunc checks if the provided value2 < value1. -func IsLesserThanFunc[T shared.IntType](value1 T) ValidatorFunc[T] { +// IsLesserThan checks if the provided value2 < value1. +func IsLesserThan[T shared.IntType](value1 T) ValidatorFunc[T] { return func(value2 T) (bool, string) { return value2 < value1, fmt.Sprintf("must be lesser than %d", value1) } } -// IsLesserThanOrEqualFunc checks if the provided value2 <= value1. -func IsLesserThanOrEqualFunc[T shared.IntType](value1 T) ValidatorFunc[T] { +// IsLesserThanOrEqual checks if the provided value2 <= value1. +func IsLesserThanOrEqual[T shared.IntType](value1 T) ValidatorFunc[T] { return func(value2 T) (bool, string) { return value2 <= value1, fmt.Sprintf( @@ -56,3 +57,10 @@ func IsValidTimeZone(value string) (bool, string) { _, err := time.LoadLocation(value) return err == nil, "must be a valid IANA value" } + +// IsInSlice checks if the provided value is part of the provided slice. +func IsInSlice[T comparable](slice []T) ValidatorFunc[T] { + return func(value T) (bool, string) { + return slices.Contains(slice, value), "must be a valid value" + } +} diff --git a/pkg/validate/validator_test.go b/pkg/validate/validator_test.go index 0e8341d..fdcfa21 100644 --- a/pkg/validate/validator_test.go +++ b/pkg/validate/validator_test.go @@ -1,3 +1,4 @@ +//nolint:exhaustruct //on purpose package validate_test import ( @@ -8,161 +9,187 @@ import ( ) type TestStruct struct { - strVal string - intVal int - int64Val int64 - tzVal string + StrVal string `json:"strVal"` + IntVal int `json:"intVal"` + Int64Val int64 `json:"int64Val"` + TzVal string `json:"tzVal"` + OptStrVal *string `json:"optStrVal"` } -func (ts TestStruct) Validate() *validate.Validator { +func (ts *TestStruct) Validate() (bool, map[string]string) { v := validate.New() - validate.Check(v, ts.strVal, validate.IsNotEmpty, "strVal") + validate.Check(v, "strVal", ts.StrVal, validate.IsNotEmpty) - validate.Check(v, ts.intVal, validate.IsGreaterThanFunc(-1), "intVal") - validate.Check(v, ts.intVal, validate.IsGreaterThanOrEqualFunc(1), "intVal") - validate.Check(v, ts.intVal, validate.IsLesserThanFunc(4), "intVal") - validate.Check(v, ts.intVal, validate.IsLesserThanOrEqualFunc(2), "intVal") + validate.Check(v, "intVal", ts.IntVal, validate.IsGreaterThan(-1)) + validate.Check(v, "intVal", ts.IntVal, validate.IsGreaterThanOrEqual(1)) + validate.Check(v, "intVal", ts.IntVal, validate.IsLesserThan(4)) + validate.Check(v, "intVal", ts.IntVal, validate.IsLesserThanOrEqual(2)) - validate.Check(v, ts.int64Val, validate.IsGreaterThanFunc(int64(-1)), "int64Val") + validate.Check(v, "int64Val", ts.Int64Val, validate.IsGreaterThan(int64(-1))) validate.Check( v, - ts.int64Val, - validate.IsGreaterThanOrEqualFunc(int64(1)), "int64Val", + ts.Int64Val, + validate.IsGreaterThanOrEqual(int64(1)), ) - validate.Check(v, ts.int64Val, validate.IsLesserThanFunc(int64(4)), "int64Val") + validate.Check(v, "int64Val", ts.Int64Val, validate.IsLesserThan(int64(4))) validate.Check( v, - ts.int64Val, - validate.IsLesserThanOrEqualFunc(int64(2)), "int64Val", + ts.Int64Val, + validate.IsLesserThanOrEqual(int64(2)), ) - validate.Check(v, ts.tzVal, validate.IsValidTimeZone, "tzVal") + validate.Check(v, "tzVal", ts.TzVal, validate.IsValidTimeZone) - return v + validate.CheckOptional( + v, + "optStrVal", + ts.OptStrVal, + validate.IsInSlice([]string{"allowed"}), + ) + + return v.Valid(), v.Errors() } func TestAllOk(t *testing.T) { + val := "allowed" + ts := TestStruct{ - strVal: "hello", - intVal: 1, - int64Val: 1, - tzVal: "Europe/Brussels", + StrVal: "hello", + IntVal: 1, + Int64Val: 1, + TzVal: "Europe/Brussels", + OptStrVal: &val, } - assert.True(t, ts.Validate().Valid()) + valid, errors := ts.Validate() + assert.True(t, valid) + assert.Equal(t, 0, len(errors)) } func TestIsEmpty(t *testing.T) { ts := TestStruct{ - strVal: "", - intVal: 1, - int64Val: 1, - tzVal: "Europe/Brussels", + StrVal: "", + IntVal: 1, + Int64Val: 1, + TzVal: "Europe/Brussels", } - errors := map[string]string{ + expectedErrors := map[string]string{ "strVal": "must be provided", } - v := ts.Validate() - - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) } func TestIsNotGT(t *testing.T) { ts := TestStruct{ - strVal: "hello", - intVal: -1, - int64Val: -1, - tzVal: "Europe/Brussels", + StrVal: "hello", + IntVal: -1, + Int64Val: -1, + TzVal: "Europe/Brussels", } - errors := map[string]string{ + expectedErrors := map[string]string{ "intVal": "must be greater than -1", "int64Val": "must be greater than -1", } - v := ts.Validate() - - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) } func TestIsNotGTE(t *testing.T) { ts := TestStruct{ - strVal: "hello", - intVal: 0, - int64Val: 0, - tzVal: "Europe/Brussels", + StrVal: "hello", + IntVal: 0, + Int64Val: 0, + TzVal: "Europe/Brussels", } - errors := map[string]string{ + expectedErrors := map[string]string{ "intVal": "must be greater than or equal to 1", "int64Val": "must be greater than or equal to 1", } - v := ts.Validate() - - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) } func TestIsNotLT(t *testing.T) { ts := TestStruct{ - strVal: "hello", - intVal: 4, - int64Val: 4, - tzVal: "Europe/Brussels", + StrVal: "hello", + IntVal: 4, + Int64Val: 4, + TzVal: "Europe/Brussels", } - errors := map[string]string{ + expectedErrors := map[string]string{ "intVal": "must be lesser than 4", "int64Val": "must be lesser than 4", } - v := ts.Validate() - - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) } func TestIsNotLTE(t *testing.T) { ts := TestStruct{ - strVal: "hello", - intVal: 3, - int64Val: 3, - tzVal: "Europe/Brussels", + StrVal: "hello", + IntVal: 3, + Int64Val: 3, + TzVal: "Europe/Brussels", } - errors := map[string]string{ + expectedErrors := map[string]string{ "intVal": "must be lesser than or equal to 2", "int64Val": "must be lesser than or equal to 2", } - v := ts.Validate() - - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) } func TestIsNotValidTz(t *testing.T) { ts := TestStruct{ - strVal: "hello", - intVal: 1, - int64Val: 1, - tzVal: "AAAAAAAAAAAAAAAAAAAAAA", + StrVal: "hello", + IntVal: 1, + Int64Val: 1, + TzVal: "AAAAAAAAAAAAAAAAAAAAAA", } - errors := map[string]string{ + expectedErrors := map[string]string{ "tzVal": "must be a valid IANA value", } - v := ts.Validate() + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) +} + +func TestIsNotInSlice(t *testing.T) { + val := "notallowed" + ts := TestStruct{ + StrVal: "hello", + IntVal: 1, + Int64Val: 1, + TzVal: "Europe/Brussels", + OptStrVal: &val, + } + + expectedErrors := map[string]string{ + "optStrVal": "must be a valid value", + } - assert.False(t, v.Valid()) - assert.Equal(t, errors, v.Errors) + valid, errors := ts.Validate() + assert.False(t, valid) + assert.Equal(t, expectedErrors, errors) }