From dcfc48b8bb42a8c29df70374628b3c52f94424a7 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Thu, 26 Dec 2024 14:44:42 +0100 Subject: [PATCH] Refactor commands (#644) * Refactor global variables into a global struct * Update tests * Use param instead of global variable * Update tests * Fix formatting * Add a hack to fix the tests * Ignore raft files * Fix tests * Remove global variables * Use a global app variable for test only and assign it only in tests * Add stopGracefully as a method to the GatewayDInstance struct * Move signal handler to the top to avoid half-states on exit * Add a new function for creating an instance of GatewayDInstance using parsed flags * Remove duplicate internal functions * Update tests * Use exported functions instead of internal variables * Remove global variables * Update tests * Fix linter error * Ignore dupl linter * Fix tests with the actual log output * Clean up after test * Remove global variables * Skip path slip verification * Fix tests * Remove backup * Another try to fix the test * Fix missing/unknown behavior * Add a pull-only test * Declare variable before assignment * Fix plugin install behavior * Fix plugin install with no overwrite * Revert changes * Rename variable * Use filepath to join paths * Remove unnecessary file * Ignore plugins file if exists * Remove duplicate code * Reset the pull-only flag * Check if the server is properly closed before erroring out * Refactor run command into a separate file * Fix bug in handling early exit * Move left-over functions * Add comments * Fix linter errors * Fix missing log message and span * Handle errors when stopping the listener for gRPC server * Graceful shutdown of gRPC server * Revert changes to path * Use local variable * Replace all context.TODO with context.Background * Split exit codes into a separate file * Remove unused constant and renumber others * Use exported function instead of internal variables * Ignore linter errors * Rename variable and comment to match the behavior --- .gitignore | 6 +- .golangci.yaml | 1 + api/api_helpers_test.go | 6 +- api/api_test.go | 38 +- api/grpc_server.go | 26 +- api/grpc_server_test.go | 9 +- api/healthcheck_test.go | 27 +- api/http_server.go | 24 +- api/http_server_test.go | 45 +- cmd/config_init.go | 16 +- cmd/config_lint.go | 9 +- cmd/configs.go | 20 +- cmd/gatewayd_app.go | 1143 ++++++++++++++++++++++++++++++++++ cmd/plugin_init.go | 14 +- cmd/plugin_init_test.go | 3 +- cmd/plugin_install.go | 126 ++-- cmd/plugin_install_test.go | 82 ++- cmd/plugin_lint.go | 10 +- cmd/plugin_list.go | 23 +- cmd/plugin_list_test.go | 2 + cmd/plugin_scaffold.go | 14 +- cmd/plugin_scaffold_test.go | 27 +- cmd/run.go | 1144 ++++------------------------------- cmd/run_test.go | 110 ++-- config/getters_test.go | 8 +- errors/errors.go | 9 - errors/exit_codes.go | 11 + network/server_test.go | 7 +- 28 files changed, 1602 insertions(+), 1358 deletions(-) create mode 100644 cmd/gatewayd_app.go create mode 100644 errors/exit_codes.go diff --git a/.gitignore b/.gitignore index afd2692a..6c193a0b 100644 --- a/.gitignore +++ b/.gitignore @@ -43,4 +43,8 @@ gatewayd-files/ cmd/gatewayd-plugin-cache-linux-amd64-* tempo-data -raft/node* +# Raft files +raft/node*/ + +# Accidental installation of plugin +plugins/gatewayd-plugin-cache diff --git a/.golangci.yaml b/.golangci.yaml index 60a34508..b3c826c8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -4,6 +4,7 @@ linters: enable-all: true disable: - cyclop + - dupl - wsl - godox - gochecknoglobals diff --git a/api/api_helpers_test.go b/api/api_helpers_test.go index 7e5bba9b..381cb55c 100644 --- a/api/api_helpers_test.go +++ b/api/api_helpers_test.go @@ -12,7 +12,7 @@ import ( ) // getAPIConfig returns a new API configuration with all the necessary components. -func getAPIConfig() *API { +func getAPIConfig(httpAddr, grpcAddr string) *API { logger := zerolog.New(nil) defaultPool := pool.NewPool(context.Background(), config.DefaultPoolSize) pluginReg := plugin.NewRegistry( @@ -60,8 +60,8 @@ func getAPIConfig() *API { return &API{ Options: &Options{ GRPCNetwork: "tcp", - GRPCAddress: "localhost:19090", - HTTPAddress: "localhost:18080", + GRPCAddress: grpcAddr, + HTTPAddress: httpAddr, Logger: logger, Servers: servers, }, diff --git a/api/api_test.go b/api/api_test.go index 0804a9c3..e948712f 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -36,9 +36,9 @@ func TestGetVersion(t *testing.T) { func TestGetGlobalConfig(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), + conf := config.NewConfig(context.Background(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - gerr := conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.Background()) require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) @@ -61,9 +61,9 @@ func TestGetGlobalConfig(t *testing.T) { func TestGetGlobalConfigWithGroupName(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), + conf := config.NewConfig(context.Background(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - gerr := conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.Background()) require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) @@ -94,9 +94,9 @@ func TestGetGlobalConfigWithGroupName(t *testing.T) { func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), + conf := config.NewConfig(context.Background(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - gerr := conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.Background()) require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) @@ -112,9 +112,9 @@ func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) { func TestGetPluginConfig(t *testing.T) { // Load config from the default config file. - conf := config.NewConfig(context.TODO(), + conf := config.NewConfig(context.Background(), config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - gerr := conf.InitConfig(context.TODO()) + gerr := conf.InitConfig(context.Background()) require.Nil(t, gerr) assert.NotEmpty(t, conf.Global) @@ -141,7 +141,7 @@ func TestGetPlugins(t *testing.T) { Logger: zerolog.Logger{}, }) pluginRegistry := plugin.NewRegistry( - context.TODO(), + context.Background(), plugin.Registry{ ActRegistry: actRegistry, Compatibility: config.Loose, @@ -196,7 +196,7 @@ func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) { Logger: zerolog.Logger{}, }) pluginRegistry := plugin.NewRegistry( - context.TODO(), + context.Background(), plugin.Registry{ ActRegistry: actRegistry, Compatibility: config.Loose, @@ -218,7 +218,7 @@ func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) { func TestPools(t *testing.T) { api := API{ Pools: map[string]map[string]*pool.Pool{ - config.Default: {config.DefaultConfigurationBlock: pool.NewPool(context.TODO(), config.EmptyPoolCapacity)}, + config.Default: {config.DefaultConfigurationBlock: pool.NewPool(context.Background(), config.EmptyPoolCapacity)}, }, ctx: context.Background(), } @@ -253,13 +253,13 @@ func TestGetProxies(t *testing.T) { Network: config.DefaultNetwork, Address: postgresAddress, } - client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) + client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil) require.NotNil(t, client) - newPool := pool.NewPool(context.TODO(), 1) + newPool := pool.NewPool(context.Background(), 1) assert.Nil(t, newPool.Put(client.ID, client)) proxy := network.NewProxy( - context.TODO(), + context.Background(), network.Proxy{ AvailableConnections: newPool, HealthCheckPeriod: config.DefaultHealthCheckPeriod, @@ -305,13 +305,13 @@ func TestGetServers(t *testing.T) { Network: config.DefaultNetwork, Address: postgresAddress, } - client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) - newPool := pool.NewPool(context.TODO(), 1) + client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil) + newPool := pool.NewPool(context.Background(), 1) require.NotNil(t, newPool) assert.Nil(t, newPool.Put(client.ID, client)) proxy := network.NewProxy( - context.TODO(), + context.Background(), network.Proxy{ AvailableConnections: newPool, HealthCheckPeriod: config.DefaultHealthCheckPeriod, @@ -336,7 +336,7 @@ func TestGetServers(t *testing.T) { }) pluginRegistry := plugin.NewRegistry( - context.TODO(), + context.Background(), plugin.Registry{ ActRegistry: actRegistry, Compatibility: config.Loose, @@ -346,7 +346,7 @@ func TestGetServers(t *testing.T) { ) server := network.NewServer( - context.TODO(), + context.Background(), network.Server{ Network: config.DefaultNetwork, Address: postgresAddress, diff --git a/api/grpc_server.go b/api/grpc_server.go index a12744db..9a2eb824 100644 --- a/api/grpc_server.go +++ b/api/grpc_server.go @@ -2,6 +2,7 @@ package api import ( "context" + "errors" "net" v1 "github.com/gatewayd-io/gatewayd/api/v1" @@ -41,18 +42,23 @@ func NewGRPCServer(ctx context.Context, server GRPCServer) *GRPCServer { // Start starts the gRPC server. func (s *GRPCServer) Start() { - s.start(s.API, s.grpcServer, s.listener) + if err := s.grpcServer.Serve(s.listener); err != nil && !errors.Is(err, net.ErrClosed) { + s.API.Options.Logger.Err(err).Msg("failed to start gRPC API") + } } // Shutdown shuts down the gRPC server. -func (s *GRPCServer) Shutdown(_ context.Context) { - s.shutdown(s.grpcServer) +func (s *GRPCServer) Shutdown(context.Context) { + if err := s.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + s.API.Options.Logger.Err(err).Msg("failed to close listener") + } + s.grpcServer.GracefulStop() } // createGRPCAPI creates a new gRPC API server and listener. func createGRPCAPI(api *API, healthchecker *HealthChecker) (*grpc.Server, net.Listener) { listener, err := net.Listen(api.Options.GRPCNetwork, api.Options.GRPCAddress) - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { api.Options.Logger.Err(err).Msg("failed to start gRPC API") return nil, nil } @@ -64,15 +70,3 @@ func createGRPCAPI(api *API, healthchecker *HealthChecker) (*grpc.Server, net.Li return grpcServer, listener } - -// start starts the gRPC API. -func (s *GRPCServer) start(api *API, grpcServer *grpc.Server, listener net.Listener) { - if err := grpcServer.Serve(listener); err != nil { - api.Options.Logger.Err(err).Msg("failed to start gRPC API") - } -} - -// shutdown shuts down the gRPC API. -func (s *GRPCServer) shutdown(grpcServer *grpc.Server) { - grpcServer.GracefulStop() -} diff --git a/api/grpc_server_test.go b/api/grpc_server_test.go index 16f2c449..f7c0c95e 100644 --- a/api/grpc_server_test.go +++ b/api/grpc_server_test.go @@ -14,7 +14,10 @@ import ( // Test_GRPC_Server tests the gRPC server. func Test_GRPC_Server(t *testing.T) { - api := getAPIConfig() + api := getAPIConfig( + "localhost:18081", + "localhost:19091", + ) healthchecker := &HealthChecker{Servers: api.Servers} grpcServer := NewGRPCServer( context.Background(), GRPCServer{API: api, HealthChecker: healthchecker}) @@ -25,7 +28,7 @@ func Test_GRPC_Server(t *testing.T) { }(grpcServer) grpcClient, err := grpc.NewClient( - "localhost:19090", grpc.WithTransportCredentials(insecure.NewCredentials())) + "localhost:19091", grpc.WithTransportCredentials(insecure.NewCredentials())) assert.Nil(t, err) defer grpcClient.Close() @@ -36,5 +39,5 @@ func Test_GRPC_Server(t *testing.T) { assert.Equal(t, config.Version, resp.GetVersion()) assert.Equal(t, config.VersionInfo(), resp.GetVersionInfo()) - grpcServer.Shutdown(context.Background()) + grpcServer.Shutdown(nil) //nolint:staticcheck } diff --git a/api/healthcheck_test.go b/api/healthcheck_test.go index 74b0988f..d304b14f 100644 --- a/api/healthcheck_test.go +++ b/api/healthcheck_test.go @@ -24,13 +24,13 @@ func Test_Healthchecker(t *testing.T) { Network: config.DefaultNetwork, Address: postgresAddress, } - client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil) - newPool := pool.NewPool(context.TODO(), 1) + client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil) + newPool := pool.NewPool(context.Background(), 1) require.NotNil(t, newPool) assert.Nil(t, newPool.Put(client.ID, client)) proxy := network.NewProxy( - context.TODO(), + context.Background(), network.Proxy{ AvailableConnections: newPool, HealthCheckPeriod: config.DefaultHealthCheckPeriod, @@ -55,7 +55,7 @@ func Test_Healthchecker(t *testing.T) { }) pluginRegistry := plugin.NewRegistry( - context.TODO(), + context.Background(), plugin.Registry{ ActRegistry: actRegistry, Compatibility: config.Loose, @@ -75,10 +75,10 @@ func Test_Healthchecker(t *testing.T) { }() server := network.NewServer( - context.TODO(), + context.Background(), network.Server{ Network: config.DefaultNetwork, - Address: postgresAddress, + Address: "127.0.0.1:15432", TickInterval: config.DefaultTickInterval, Options: network.Option{ EnableTicker: false, @@ -99,13 +99,14 @@ func Test_Healthchecker(t *testing.T) { raftNode: raftHelper.Node, } assert.NotNil(t, healthchecker) - hcr, err := healthchecker.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{}) + hcr, err := healthchecker.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{}) assert.NoError(t, err) assert.NotNil(t, hcr) assert.Equal(t, grpc_health_v1.HealthCheckResponse_NOT_SERVING, hcr.GetStatus()) go func(t *testing.T, server *network.Server) { t.Helper() + if err := server.Run(); err != nil { t.Errorf("server.Run() error = %v", err) } @@ -113,7 +114,7 @@ func Test_Healthchecker(t *testing.T) { time.Sleep(1 * time.Second) // Test for SERVING status - hcr, err = healthchecker.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{}) + hcr, err = healthchecker.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{}) assert.NoError(t, err) assert.NotNil(t, hcr) assert.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, hcr.GetStatus()) @@ -121,4 +122,14 @@ func Test_Healthchecker(t *testing.T) { err = healthchecker.Watch(&grpc_health_v1.HealthCheckRequest{}, nil) assert.Error(t, err) assert.Equal(t, "rpc error: code = Unimplemented desc = not implemented", err.Error()) + + server.Shutdown() + pluginRegistry.Shutdown() + + // Wait for the server to stop. + <-time.After(100 * time.Millisecond) + + // check server status and connections + assert.False(t, server.IsRunning()) + assert.Zero(t, server.CountConnections()) } diff --git a/api/http_server.go b/api/http_server.go index 596214b3..c516e00d 100644 --- a/api/http_server.go +++ b/api/http_server.go @@ -40,12 +40,17 @@ func NewHTTPServer(options *Options) *HTTPServer { // Start starts the HTTP server. func (s *HTTPServer) Start() { - s.start(s.options, s.httpServer) + // Start HTTP server (and proxy calls to gRPC server endpoint) + if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.options.Logger.Err(err).Msg("failed to start HTTP API") + } } // Shutdown shuts down the HTTP server. func (s *HTTPServer) Shutdown(ctx context.Context) { - s.shutdown(ctx, s.httpServer, s.logger) + if err := s.httpServer.Shutdown(ctx); err != nil { + s.logger.Err(err).Msg("failed to shutdown HTTP API") + } } // CreateHTTPAPI creates a new HTTP API. @@ -113,18 +118,3 @@ func createHTTPAPI(options *Options) *http.Server { return server } - -// start starts the HTTP API. -func (s *HTTPServer) start(options *Options, server *http.Server) { - // Start HTTP server (and proxy calls to gRPC server endpoint) - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - options.Logger.Err(err).Msg("failed to start HTTP API") - } -} - -// shutdown shuts down the HTTP API. -func (s *HTTPServer) shutdown(ctx context.Context, server *http.Server, logger zerolog.Logger) { - if err := server.Shutdown(ctx); err != nil { - logger.Err(err).Msg("failed to shutdown HTTP API") - } -} diff --git a/api/http_server_test.go b/api/http_server_test.go index a28a2c93..bc6092fc 100644 --- a/api/http_server_test.go +++ b/api/http_server_test.go @@ -3,7 +3,6 @@ package api import ( "context" "encoding/json" - "fmt" "io" "net/http" "testing" @@ -16,45 +15,27 @@ import ( // Test_HTTP_Server tests the HTTP to gRPC gateway. func Test_HTTP_Server(t *testing.T) { - api := getAPIConfig() + api := getAPIConfig( + "localhost:18082", + "localhost:19092", + ) healthchecker := &HealthChecker{Servers: api.Servers} grpcServer := NewGRPCServer( context.Background(), GRPCServer{API: api, HealthChecker: healthchecker}) - require.NotNil(t, grpcServer, "gRPC server should not be nil") + go grpcServer.Start() + require.NotNil(t, grpcServer) httpServer := NewHTTPServer(api.Options) - require.NotNil(t, httpServer, "HTTP server should not be nil") - - // Start gRPC server with error handling - errChan := make(chan error, 1) - go func(grpcServer *GRPCServer) { - errChan <- func() error { - defer func() { - if r := recover(); r != nil { - errChan <- fmt.Errorf("gRPC server panicked: %v", r) - } - }() - grpcServer.Start() - return nil - }() - }(grpcServer) - - go func(httpServer *HTTPServer) { - httpServer.Start() - }(httpServer) + go httpServer.Start() + require.NotNil(t, httpServer) - // Wait for potential startup errors - select { - case err := <-errChan: - require.NoError(t, err, "gRPC server failed to start") - case <-time.After(1 * time.Second): // Wait for the servers to start - } + time.Sleep(time.Second) // Check version via the gRPC server. req, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - "http://localhost:18080/v1/GatewayDPluginService/Version", + "http://localhost:18082/v1/GatewayDPluginService/Version", nil, ) require.NoError(t, err) @@ -73,7 +54,7 @@ func Test_HTTP_Server(t *testing.T) { req, err = http.NewRequestWithContext( context.Background(), http.MethodGet, - "http://localhost:18080/healthz", + "http://localhost:18082/healthz", nil, ) require.NoError(t, err) @@ -90,7 +71,7 @@ func Test_HTTP_Server(t *testing.T) { req, err = http.NewRequestWithContext( context.Background(), http.MethodGet, - "http://localhost:18080/version", + "http://localhost:18082/version", nil, ) require.NoError(t, err) @@ -104,6 +85,6 @@ func Test_HTTP_Server(t *testing.T) { assert.Equal(t, len(config.Version), len(respBodyBytes)) assert.Equal(t, config.Version, string(respBodyBytes)) - grpcServer.Shutdown(context.Background()) + grpcServer.Shutdown(nil) //nolint:staticcheck httpServer.Shutdown(context.Background()) } diff --git a/cmd/config_init.go b/cmd/config_init.go index 57c33dd9..e2064ac2 100644 --- a/cmd/config_init.go +++ b/cmd/config_init.go @@ -6,13 +6,15 @@ import ( "github.com/spf13/cobra" ) -var force bool - // configInitCmd represents the plugin init command. var configInitCmd = &cobra.Command{ Use: "init", Short: "Create or overwrite the GatewayD global config", Run: func(cmd *cobra.Command, _ []string) { + force, _ := cmd.Flags().GetBool("force") + enableSentry, _ := cmd.Flags().GetBool("sentry") + globalConfigFile, _ := cmd.Flags().GetString("config") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -39,12 +41,10 @@ var configInitCmd = &cobra.Command{ func init() { configCmd.AddCommand(configInitCmd) - configInitCmd.Flags().BoolVarP( - &force, "force", "f", false, "Force overwrite of existing config file") - configInitCmd.Flags().StringVarP( - &globalConfigFile, // Already exists in run.go + configInitCmd.Flags().BoolP( + "force", "f", false, "Force overwrite of existing config file") + configInitCmd.Flags().StringP( "config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename), "Global config file") - configInitCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + configInitCmd.Flags().Bool("sentry", true, "Enable Sentry") } diff --git a/cmd/config_lint.go b/cmd/config_lint.go index fa52e314..fb7e0de5 100644 --- a/cmd/config_lint.go +++ b/cmd/config_lint.go @@ -14,6 +14,9 @@ var configLintCmd = &cobra.Command{ Use: "lint", Short: "Lint the GatewayD global config", Run: func(cmd *cobra.Command, _ []string) { + enableSentry, _ := cmd.Flags().GetBool("sentry") + globalConfigFile, _ := cmd.Flags().GetString("config") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -44,10 +47,8 @@ var configLintCmd = &cobra.Command{ func init() { configCmd.AddCommand(configLintCmd) - configLintCmd.Flags().StringVarP( - &globalConfigFile, // Already exists in run.go + configLintCmd.Flags().StringP( "config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename), "Global config file") - configLintCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + configLintCmd.Flags().Bool("sentry", true, "Enable Sentry") } diff --git a/cmd/configs.go b/cmd/configs.go index 5baf9185..7152ca2f 100644 --- a/cmd/configs.go +++ b/cmd/configs.go @@ -27,7 +27,7 @@ func generateConfig( GlobalKoanf: koanf.New("."), PluginKoanf: koanf.New("."), } - if err := conf.LoadDefaults(context.TODO()); err != nil { + if err := conf.LoadDefaults(context.Background()); err != nil { logger.Fatal(err) } @@ -73,28 +73,28 @@ func lintConfig(fileType configFileType, configFile string) *gerr.GatewayDError var conf *config.Config switch fileType { case Global: - conf = config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: configFile}) - if err := conf.LoadDefaults(context.TODO()); err != nil { + conf = config.NewConfig(context.Background(), config.Config{GlobalConfigFile: configFile}) + if err := conf.LoadDefaults(context.Background()); err != nil { return err } - if err := conf.LoadGlobalConfigFile(context.TODO()); err != nil { + if err := conf.LoadGlobalConfigFile(context.Background()); err != nil { return err } - if err := conf.ConvertKeysToLowercase(context.TODO()); err != nil { + if err := conf.ConvertKeysToLowercase(context.Background()); err != nil { return err } - if err := conf.UnmarshalGlobalConfig(context.TODO()); err != nil { + if err := conf.UnmarshalGlobalConfig(context.Background()); err != nil { return err } case Plugins: - conf = config.NewConfig(context.TODO(), config.Config{PluginConfigFile: configFile}) - if err := conf.LoadDefaults(context.TODO()); err != nil { + conf = config.NewConfig(context.Background(), config.Config{PluginConfigFile: configFile}) + if err := conf.LoadDefaults(context.Background()); err != nil { return err } - if err := conf.LoadPluginConfigFile(context.TODO()); err != nil { + if err := conf.LoadPluginConfigFile(context.Background()); err != nil { return err } - if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil { + if err := conf.UnmarshalPluginConfig(context.Background()); err != nil { return err } default: diff --git a/cmd/gatewayd_app.go b/cmd/gatewayd_app.go new file mode 100644 index 00000000..fdd4c49d --- /dev/null +++ b/cmd/gatewayd_app.go @@ -0,0 +1,1143 @@ +package cmd + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/signal" + "runtime" + "strconv" + "sync/atomic" + "time" + + "github.com/NYTimes/gziphandler" + sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" + sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" + v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" + "github.com/gatewayd-io/gatewayd/act" + "github.com/gatewayd-io/gatewayd/api" + "github.com/gatewayd-io/gatewayd/config" + gerr "github.com/gatewayd-io/gatewayd/errors" + "github.com/gatewayd-io/gatewayd/logging" + "github.com/gatewayd-io/gatewayd/metrics" + "github.com/gatewayd-io/gatewayd/network" + "github.com/gatewayd-io/gatewayd/plugin" + "github.com/gatewayd-io/gatewayd/pool" + "github.com/gatewayd-io/gatewayd/raft" + usage "github.com/gatewayd-io/gatewayd/usagereport/v1" + "github.com/getsentry/sentry-go" + "github.com/go-co-op/gocron" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/redis/go-redis/v9" + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +var _ io.Writer = &cobraCmdWriter{} + +type cobraCmdWriter struct { + *cobra.Command +} + +func (c *cobraCmdWriter) Write(p []byte) (int, error) { + c.Print(string(p)) + return len(p), nil +} + +var UsageReportURL = "localhost:59091" + +const ( + DefaultMetricsServerProbeTimeout = 5 * time.Second +) + +type GatewayDApp struct { + EnableTracing bool + EnableSentry bool + EnableLinting bool + EnableUsageReport bool + DevMode bool + CollectorURL string + PluginConfigFile string + GlobalConfigFile string + + conf *config.Config + pluginRegistry *plugin.Registry + actRegistry *act.Registry + metricsServer *http.Server + metricsMerger *metrics.Merger + httpServer *api.HTTPServer + grpcServer *api.GRPCServer + + loggers map[string]zerolog.Logger + pools map[string]map[string]*pool.Pool + clients map[string]map[string]*config.Client + proxies map[string]map[string]*network.Proxy + servers map[string]*network.Server + healthCheckScheduler *gocron.Scheduler + stopChan chan struct{} + ranStopGracefully *atomic.Bool +} + +// NewGatewayDApp creates a new GatewayDApp instance. +func NewGatewayDApp(cmd *cobra.Command) *GatewayDApp { + app := GatewayDApp{ + loggers: make(map[string]zerolog.Logger), + pools: make(map[string]map[string]*pool.Pool), + clients: make(map[string]map[string]*config.Client), + proxies: make(map[string]map[string]*network.Proxy), + servers: make(map[string]*network.Server), + healthCheckScheduler: gocron.NewScheduler(time.UTC), + stopChan: make(chan struct{}), + ranStopGracefully: &atomic.Bool{}, + } + app.EnableTracing, _ = cmd.Flags().GetBool("enable-tracing") + app.EnableSentry, _ = cmd.Flags().GetBool("enable-sentry") + app.EnableUsageReport, _ = cmd.Flags().GetBool("enable-usage-report") + app.EnableLinting, _ = cmd.Flags().GetBool("enable-linting") + app.DevMode, _ = cmd.Flags().GetBool("dev") + app.CollectorURL, _ = cmd.Flags().GetString("collector-url") + app.GlobalConfigFile, _ = cmd.Flags().GetString("config") + app.PluginConfigFile, _ = cmd.Flags().GetString("plugin-config") + return &app +} + +// loadConfig loads global and plugin configuration. +func (app *GatewayDApp) loadConfig(runCtx context.Context) error { + app.conf = config.NewConfig(runCtx, + config.Config{ + GlobalConfigFile: app.GlobalConfigFile, + PluginConfigFile: app.PluginConfigFile, + }, + ) + if err := app.conf.InitConfig(runCtx); err != nil { + return err + } + return nil +} + +// createLoggers creates loggers from the config. +func (app *GatewayDApp) createLoggers( + runCtx context.Context, cmd *cobra.Command, +) zerolog.Logger { + // Use cobra command cmd instead of os.Stdout for the console output. + cmdLogger := &cobraCmdWriter{cmd} + + // Create a logger for each tenant. + for name, cfg := range app.conf.Global.Loggers { + app.loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{ + Output: cfg.GetOutput(), + ConsoleOut: cmdLogger, + Level: config.If( + config.Exists(config.LogLevels, cfg.Level), + config.LogLevels[cfg.Level], + config.LogLevels[config.DefaultLogLevel], + ), + TimeFormat: config.If( + config.Exists(config.TimeFormats, cfg.TimeFormat), + config.TimeFormats[cfg.TimeFormat], + config.TimeFormats[config.DefaultTimeFormat], + ), + ConsoleTimeFormat: config.If( + config.Exists( + config.ConsoleTimeFormats, cfg.ConsoleTimeFormat), + config.ConsoleTimeFormats[cfg.ConsoleTimeFormat], + config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat], + ), + NoColor: cfg.NoColor, + FileName: cfg.FileName, + MaxSize: cfg.MaxSize, + MaxBackups: cfg.MaxBackups, + MaxAge: cfg.MaxAge, + Compress: cfg.Compress, + LocalTime: cfg.LocalTime, + SyslogPriority: cfg.GetSyslogPriority(), + RSyslogNetwork: cfg.RSyslogNetwork, + RSyslogAddress: cfg.RSyslogAddress, + Name: name, + }) + } + return app.loggers[config.Default] +} + +// createActRegistry creates a new act registry given +// the built-in signals, policies, and actions. +func (app *GatewayDApp) createActRegistry(logger zerolog.Logger) error { + // Create a new act registry given the built-in signals, policies, and actions. + var publisher *act.Publisher + if app.conf.Plugin.ActionRedis.Enabled { + rdb := redis.NewClient(&redis.Options{ + Addr: app.conf.Plugin.ActionRedis.Address, + }) + var err error + publisher, err = act.NewPublisher(act.Publisher{ + Logger: logger, + RedisDB: rdb, + ChannelName: app.conf.Plugin.ActionRedis.Channel, + }) + if err != nil { + logger.Error().Err(err).Msg("Failed to create publisher for act registry") + return err //nolint:wrapcheck + } + logger.Info().Msg("Created Redis publisher for Act registry") + } + + app.actRegistry = act.NewActRegistry( + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: app.conf.Plugin.DefaultPolicy, + PolicyTimeout: app.conf.Plugin.PolicyTimeout, + DefaultActionTimeout: app.conf.Plugin.ActionTimeout, + TaskPublisher: publisher, + Logger: logger, + }, + ) + + return nil +} + +// loadPolicies loads policies from the configuration file and +// adds them to the registry. +func (app *GatewayDApp) loadPolicies(logger zerolog.Logger) error { + // Load policies from the configuration file and add them to the registry. + for _, plc := range app.conf.Plugin.Policies { + policy, err := sdkAct.NewPolicy( + plc.Name, plc.Policy, plc.Metadata, + ) + if err != nil || policy == nil { + logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy") + return err //nolint:wrapcheck + } + app.actRegistry.Add(policy) + } + + return nil +} + +// createPluginRegistry creates a new plugin registry. +func (app *GatewayDApp) createPluginRegistry(runCtx context.Context, logger zerolog.Logger) { + // Create a new plugin registry. + // The plugins are loaded and hooks registered before the configuration is loaded. + app.pluginRegistry = plugin.NewRegistry( + runCtx, + plugin.Registry{ + ActRegistry: app.actRegistry, + Compatibility: config.If( + config.Exists( + config.CompatibilityPolicies, app.conf.Plugin.CompatibilityPolicy, + ), + config.CompatibilityPolicies[app.conf.Plugin.CompatibilityPolicy], + config.DefaultCompatibilityPolicy), + Logger: logger, + DevMode: app.DevMode, + }, + ) +} + +// startMetricsMerger starts the metrics merger if enabled. +func (app *GatewayDApp) startMetricsMerger(runCtx context.Context, logger zerolog.Logger) { + // Start the metrics merger if enabled. + if app.conf.Plugin.EnableMetricsMerger { + app.metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{ + MetricsMergerPeriod: app.conf.Plugin.MetricsMergerPeriod, + Logger: logger, + }) + app.pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) { + if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled { + app.metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"]) + logger.Debug().Str("plugin", plugin.ID.Name).Msg( + "Added plugin to metrics merger") + } + }) + app.metricsMerger.Start() //nolint:contextcheck + } +} + +// startHealthCheckScheduler starts the health check scheduler if enabled. +func (app *GatewayDApp) startHealthCheckScheduler( + runCtx, ctx context.Context, span trace.Span, logger zerolog.Logger, +) { + // Ping the plugins to check if they are alive, and remove them if they are not. + startDelay := time.Now().Add(app.conf.Plugin.HealthCheckPeriod) + if _, err := app.healthCheckScheduler.Every( + app.conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do( + func() { + _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") + defer span.End() + + var plugins []string + app.pluginRegistry.ForEach( + func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) { + if err := plugin.Ping(); err != nil { + span.RecordError(err) + logger.Error().Err(err).Msg("Failed to ping plugin") + if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { + app.metricsMerger.Remove(pluginId.Name) + } + app.pluginRegistry.Remove(pluginId) + + if !app.conf.Plugin.ReloadOnCrash { + return // Do not reload the plugins. + } + + // Reload the plugins and register their hooks upon crash. + logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin") + pluginConfig := app.conf.Plugin.GetPlugins(pluginId.Name) + if pluginConfig != nil { + app.pluginRegistry.LoadPlugins(runCtx, pluginConfig, app.conf.Plugin.StartTimeout) + } + } else { + logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin") + plugins = append(plugins, pluginId.Name) + } + }) + span.SetAttributes(attribute.StringSlice("plugins", plugins)) + }); err != nil { + logger.Error().Err(err).Msg("Failed to start plugin health check scheduler") + span.RecordError(err) + } + + // Start the health check scheduler only if there are plugins. + if app.pluginRegistry.Size() > 0 { + logger.Info().Str( + "healthCheckPeriod", app.conf.Plugin.HealthCheckPeriod.String(), + ).Msg("Starting plugin health check scheduler") + app.healthCheckScheduler.StartAsync() + } +} + +// onConfigLoaded runs the OnConfigLoaded hook and +// merges the global config with the one from the plugins. +func (app *GatewayDApp) onConfigLoaded( + runCtx context.Context, span trace.Span, logger zerolog.Logger, +) error { + // Set the plugin timeout context. + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin. + // The plugins can modify the config and return it. + updatedGlobalConfig, err := app.pluginRegistry.Run( //nolint:contextcheck + pluginTimeoutCtx, app.conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") + span.RecordError(err) + } + if updatedGlobalConfig != nil { + updatedGlobalConfig = app.pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) //nolint:contextcheck + } + + // If the config was modified by the plugins, merge it with the one loaded from the file. + // Only global configuration is merged, which means that plugins cannot modify the plugin + // configurations. + if updatedGlobalConfig != nil { + // Merge the config with the one loaded from the file (in memory). + // The changes won't be persisted to disk. + if err := app.conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil { + logger.Error().Err(err).Msg("Failed to merge global config") + span.RecordError(err) + return err + } + } + + return nil +} + +// startMetricsServer starts the metrics server if enabled. +func (app *GatewayDApp) startMetricsServer( + runCtx context.Context, logger zerolog.Logger, +) error { + // Start the metrics server if enabled. + // TODO: Start multiple metrics servers. For now, only one default is supported. + // I should first find a use case for those multiple metrics servers. + _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server") + defer span.End() + + metricsConfig := app.conf.Global.Metrics[config.Default] + + // TODO: refactor this to a separate function. + if !metricsConfig.Enabled { + logger.Info().Msg("Metrics server is disabled") + return nil + } + + scheme := "http://" + if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" { + scheme = "https://" + } + + fqdn, err := url.Parse(scheme + metricsConfig.Address) + if err != nil { + logger.Error().Err(err).Msg("Failed to parse metrics address") + span.RecordError(err) + return err //nolint:wrapcheck + } + + address, err := url.JoinPath(fqdn.String(), metricsConfig.Path) + if err != nil { + logger.Error().Err(err).Msg("Failed to parse metrics path") + span.RecordError(err) + return err //nolint:wrapcheck + } + + // Merge the metrics from the plugins with the ones from GatewayD. + mergedMetricsHandler := func(next http.Handler) http.Handler { + handler := func(responseWriter http.ResponseWriter, request *http.Request) { + if _, err := responseWriter.Write(app.metricsMerger.OutputMetrics); err != nil { + logger.Error().Err(err).Msg("Failed to write metrics") + span.RecordError(err) + sentry.CaptureException(err) + } + // The WriteHeader method intentionally does nothing, to prevent a bug + // in the merging metrics that causes the headers to be written twice, + // which results in an error: "http: superfluous response.WriteHeader call". + next.ServeHTTP( + &metrics.HeaderBypassResponseWriter{ + ResponseWriter: responseWriter, + }, + request) + } + return http.HandlerFunc(handler) + } + + handler := func() http.Handler { + return promhttp.InstrumentMetricHandler( + prometheus.DefaultRegisterer, + promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ + DisableCompression: true, + }), + ) + }() + + mux := http.NewServeMux() + mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { + // Serve a static page with a link to the metrics endpoint. + if _, err := responseWriter.Write([]byte(fmt.Sprintf( + `GatewayD Prometheus Metrics ServerMetrics`, + address, + ))); err != nil { + logger.Error().Err(err).Msg("Failed to write metrics") + span.RecordError(err) + sentry.CaptureException(err) + } + }) + + if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { + handler = mergedMetricsHandler(handler) + } + + readHeaderTimeout := config.If( + metricsConfig.ReadHeaderTimeout > 0, + metricsConfig.ReadHeaderTimeout, + config.DefaultReadHeaderTimeout, + ) + + // Check if the metrics server is already running before registering the handler + ctx, cancel := context.WithTimeout(context.Background(), DefaultMetricsServerProbeTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil) //nolint:contextcheck + if err != nil { + logger.Error().Err(err).Msg("Failed to create request to check metrics server") + span.RecordError(err) + } + + if resp, err := http.DefaultClient.Do(req); err != nil { + // The timeout handler limits the nested handlers from running for too long. + mux.Handle( + metricsConfig.Path, + http.TimeoutHandler( + gziphandler.GzipHandler(handler), + readHeaderTimeout, + "The request timed out while fetching the metrics", + ), + ) + } else { + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + logger.Warn().Msg("Metrics server is already running, consider changing the port") + span.RecordError(err) + } + + // Create a new metrics server. + timeout := config.If( + metricsConfig.Timeout > 0, + metricsConfig.Timeout, + config.DefaultMetricsServerTimeout, + ) + app.metricsServer = &http.Server{ + Addr: metricsConfig.Address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + IdleTimeout: timeout, + } + + logger.Info().Fields(map[string]any{ + "address": address, + "timeout": timeout.String(), + "readHeaderTimeout": readHeaderTimeout.String(), + }).Msg("Metrics are exposed") + + if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" { + // Set up TLS. + app.metricsServer.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + }, + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + } + app.metricsServer.TLSNextProto = make( + map[string]func(*http.Server, *tls.Conn, http.Handler)) + logger.Debug().Msg("Metrics server is running with TLS") + + // Start the metrics server with TLS. + if err = app.metricsServer.ListenAndServeTLS( + metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) { + logger.Error().Err(err).Msg("Failed to start metrics server") + span.RecordError(err) + } + } else { + // Start the metrics server without TLS. + if err = app.metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + logger.Error().Err(err).Msg("Failed to start metrics server") + span.RecordError(err) + } + } + + return nil +} + +// onNewLogger runs the OnNewLogger hook. +func (app *GatewayDApp) onNewLogger( + span trace.Span, logger zerolog.Logger, +) { + // This is a notification hook, so we don't care about the result. + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("loggers").(map[string]any); ok { + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + } else { + logger.Error().Msg("Failed to get loggers from config") + } +} + +// createPoolAndClients creates pools of connections and clients. +func (app *GatewayDApp) createPoolAndClients( + runCtx context.Context, span trace.Span, +) error { + // Create and initialize pools of connections. + for configGroupName, configGroup := range app.conf.Global.Pools { + for configBlockName, cfg := range configGroup { + logger := app.loggers[configGroupName] + // Check if the pool size is greater than zero. + currentPoolSize := config.If( + cfg.Size > 0, + // Check if the pool size is greater than the minimum pool size. + config.If( + cfg.Size > config.MinimumPoolSize, + cfg.Size, + config.MinimumPoolSize, + ), + config.DefaultPoolSize, + ) + + if _, ok := app.pools[configGroupName]; !ok { + app.pools[configGroupName] = make(map[string]*pool.Pool) + } + app.pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize) + + span.AddEvent("Create pool", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.Int("size", currentPoolSize), + )) + + if _, ok := app.clients[configGroupName]; !ok { + app.clients[configGroupName] = make(map[string]*config.Client) + } + + // Get client config from the config file. + if clientConfig, ok := app.conf.Global.Clients[configGroupName][configBlockName]; !ok { + // This ensures that the default client config is used if the pool name is not + // found in the clients section. + app.clients[configGroupName][configBlockName] = app.conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] //nolint:lll + } else { + // Merge the default client config with the one from the pool. + app.clients[configGroupName][configBlockName] = clientConfig + } + + // Fill the missing and zero values with the default ones. + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If( + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0, + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod, + config.DefaultTCPKeepAlivePeriod, + ) + app.clients[configGroupName][configBlockName].ReceiveDeadline = config.If( + app.clients[configGroupName][configBlockName].ReceiveDeadline > 0, + app.clients[configGroupName][configBlockName].ReceiveDeadline, + config.DefaultReceiveDeadline, + ) + app.clients[configGroupName][configBlockName].ReceiveTimeout = config.If( + app.clients[configGroupName][configBlockName].ReceiveTimeout > 0, + app.clients[configGroupName][configBlockName].ReceiveTimeout, + config.DefaultReceiveTimeout, + ) + app.clients[configGroupName][configBlockName].SendDeadline = config.If( + app.clients[configGroupName][configBlockName].SendDeadline > 0, + app.clients[configGroupName][configBlockName].SendDeadline, + config.DefaultSendDeadline, + ) + app.clients[configGroupName][configBlockName].ReceiveChunkSize = config.If( + app.clients[configGroupName][configBlockName].ReceiveChunkSize > 0, + app.clients[configGroupName][configBlockName].ReceiveChunkSize, + config.DefaultChunkSize, + ) + app.clients[configGroupName][configBlockName].DialTimeout = config.If( + app.clients[configGroupName][configBlockName].DialTimeout > 0, + app.clients[configGroupName][configBlockName].DialTimeout, + config.DefaultDialTimeout, + ) + + // Add clients to the pool. + for range currentPoolSize { + clientConfig := app.clients[configGroupName][configBlockName] + clientConfig.GroupName = configGroupName + clientConfig.BlockName = configBlockName + client := network.NewClient( + runCtx, clientConfig, logger, + network.NewRetry( + network.Retry{ + Retries: clientConfig.Retries, + Backoff: config.If( + clientConfig.Backoff > 0, + clientConfig.Backoff, + config.DefaultBackoff, + ), + BackoffMultiplier: clientConfig.BackoffMultiplier, + DisableBackoffCaps: clientConfig.DisableBackoffCaps, + Logger: app.loggers[configBlockName], + }, + ), + ) + + if client == nil { + return errors.New("failed to create client, please check the configuration") + } + + eventOptions := trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("group", configGroupName), + attribute.String("network", client.Network), + attribute.String("address", client.Address), + attribute.Int("receiveChunkSize", client.ReceiveChunkSize), + attribute.String("receiveDeadline", client.ReceiveDeadline.String()), + attribute.String("receiveTimeout", client.ReceiveTimeout.String()), + attribute.String("sendDeadline", client.SendDeadline.String()), + attribute.String("dialTimeout", client.DialTimeout.String()), + attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), + attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), + attribute.String("localAddress", client.LocalAddr()), + attribute.String("remoteAddress", client.RemoteAddr()), + attribute.Int("retries", clientConfig.Retries), + attribute.String("backoff", client.Retry().Backoff.String()), + attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), + attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), + ) + if client.ID != "" { + eventOptions = trace.WithAttributes( + attribute.String("id", client.ID), + ) + } + + span.AddEvent("Create client", eventOptions) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + clientCfg := map[string]any{ + "id": client.ID, + "name": configBlockName, + "group": configGroupName, + "network": client.Network, + "address": client.Address, + "receiveChunkSize": client.ReceiveChunkSize, + "receiveDeadline": client.ReceiveDeadline.String(), + "receiveTimeout": client.ReceiveTimeout.String(), + "sendDeadline": client.SendDeadline.String(), + "dialTimeout": client.DialTimeout.String(), + "tcpKeepAlive": client.TCPKeepAlive, + "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), + "localAddress": client.LocalAddr(), + "remoteAddress": client.RemoteAddr(), + "retries": clientConfig.Retries, + "backoff": client.Retry().Backoff.String(), + "backoffMultiplier": clientConfig.BackoffMultiplier, + "disableBackoffCaps": clientConfig.DisableBackoffCaps, + } + result, err := app.pluginRegistry.Run( //nolint:contextcheck + pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + + err = app.pools[configGroupName][configBlockName].Put(client.ID, client) + if err != nil { + logger.Error().Err(err).Msg("Failed to add client to the pool") + span.RecordError(err) + } + } + + // Verify that the pool is properly populated. + logger.Info().Fields(map[string]any{ + "name": configBlockName, + "count": strconv.Itoa(app.pools[configGroupName][configBlockName].Size()), + }).Msg("There are clients available in the pool") + + if app.pools[configGroupName][configBlockName].Size() != currentPoolSize { + logger.Error().Msg( + "The pool size is incorrect, either because " + + "the clients cannot connect due to no network connectivity " + + "or the server is not running. exiting...") + app.pluginRegistry.Shutdown() + return errors.New("failed to initialize pool, please check the configuration") + } + + // Run the OnNewPool hook. + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + result, err := app.pluginRegistry.Run( //nolint:contextcheck + pluginTimeoutCtx, + map[string]any{"name": configBlockName, "size": currentPoolSize}, + v1.HookName_HOOK_NAME_ON_NEW_POOL) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + } + } + + return nil +} + +// createProxies creates proxies. +func (app *GatewayDApp) createProxies(runCtx context.Context, span trace.Span) { + // Create and initialize prefork proxies with each pool of clients. + for configGroupName, configGroup := range app.conf.Global.Proxies { + for configBlockName, cfg := range configGroup { + logger := app.loggers[configGroupName] + clientConfig := app.clients[configGroupName][configBlockName] + + // Fill the missing and zero value with the default one. + cfg.HealthCheckPeriod = config.If( + cfg.HealthCheckPeriod > 0, + cfg.HealthCheckPeriod, + config.DefaultHealthCheckPeriod, + ) + + if _, ok := app.proxies[configGroupName]; !ok { + app.proxies[configGroupName] = make(map[string]*network.Proxy) + } + + app.proxies[configGroupName][configBlockName] = network.NewProxy( + runCtx, + network.Proxy{ + GroupName: configGroupName, + BlockName: configBlockName, + AvailableConnections: app.pools[configGroupName][configBlockName], + PluginRegistry: app.pluginRegistry, + HealthCheckPeriod: cfg.HealthCheckPeriod, + ClientConfig: clientConfig, + Logger: logger, + PluginTimeout: app.conf.Plugin.Timeout, + }, + ) + + span.AddEvent("Create proxy", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()), + )) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("proxies").(map[string]any); ok { + result, err := app.pluginRegistry.Run( //nolint:contextcheck + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + } else { + logger.Error().Msg("Failed to get proxy from config") + } + } + } +} + +// createServers creates servers. +func (app *GatewayDApp) createServers( + runCtx context.Context, span trace.Span, raftNode *raft.Node, +) { + // Create and initialize servers. + for name, cfg := range app.conf.Global.Servers { + logger := app.loggers[name] + + var serverProxies []network.IProxy + for _, proxy := range app.proxies[name] { + serverProxies = append(serverProxies, proxy) + } + + app.servers[name] = network.NewServer( + runCtx, + network.Server{ + GroupName: name, + Network: cfg.Network, + Address: cfg.Address, + TickInterval: config.If( + cfg.TickInterval > 0, + cfg.TickInterval, + config.DefaultTickInterval, + ), + Options: network.Option{ + // Can be used to send keepalive messages to the client. + EnableTicker: cfg.EnableTicker, + }, + Proxies: serverProxies, + Logger: logger, + PluginRegistry: app.pluginRegistry, + PluginTimeout: app.conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, + LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, + LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules, + LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash, + RaftNode: raftNode, + }, + ) + + span.AddEvent("Create server", trace.WithAttributes( + attribute.String("name", name), + attribute.String("network", cfg.Network), + attribute.String("address", cfg.Address), + attribute.String("tickInterval", cfg.TickInterval.String()), + attribute.String("pluginTimeout", app.conf.Plugin.Timeout.String()), + attribute.Bool("enableTLS", cfg.EnableTLS), + attribute.String("certFile", cfg.CertFile), + attribute.String("keyFile", cfg.KeyFile), + attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()), + )) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("servers").(map[string]any); ok { + result, err := app.pluginRegistry.Run( //nolint:contextcheck + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + } else { + logger.Error().Msg("Failed to get the servers configuration") + } + } +} + +// startAPIServers starts the API servers. +func (app *GatewayDApp) startAPIServers( + runCtx context.Context, logger zerolog.Logger, raftNode *raft.Node, +) { + // Start the HTTP and gRPC APIs. + if !app.conf.Global.API.Enabled { + logger.Info().Msg("API is not enabled, skipping") + return + } + + apiOptions := api.Options{ + Logger: logger, + GRPCNetwork: app.conf.Global.API.GRPCNetwork, + GRPCAddress: app.conf.Global.API.GRPCAddress, + HTTPAddress: app.conf.Global.API.HTTPAddress, + Servers: app.servers, + RaftNode: raftNode, + } + + apiObj := &api.API{ + Options: &apiOptions, + Config: app.conf, + PluginRegistry: app.pluginRegistry, + Pools: app.pools, + Proxies: app.proxies, + Servers: app.servers, + } + app.grpcServer = api.NewGRPCServer( + runCtx, + api.GRPCServer{ + API: apiObj, + HealthChecker: &api.HealthChecker{Servers: app.servers}, + }, + ) + if app.grpcServer != nil { + go app.grpcServer.Start() + logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API") + + app.httpServer = api.NewHTTPServer(&apiOptions) //nolint:contextcheck + go app.httpServer.Start() + + logger.Info().Fields( + map[string]any{ + "network": apiOptions.GRPCNetwork, + "address": apiOptions.GRPCAddress, + }, + ).Msg("Started the gRPC Server") + } +} + +// reportUsage reports usage statistics. +func (app *GatewayDApp) reportUsage(logger zerolog.Logger) { + if !app.EnableUsageReport { + logger.Info().Msg("Usage reporting is not enabled, skipping") + return + } + + // Report usage statistics. + go func() { + conn, err := grpc.NewClient( + UsageReportURL, + grpc.WithTransportCredentials( + credentials.NewTLS( + &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + ), + ), + ) + if err != nil { + logger.Trace().Err(err).Msg( + "Failed to dial to the gRPC server for usage reporting") + } + defer func(conn *grpc.ClientConn) { + err := conn.Close() + if err != nil { + logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service") + } + }(conn) + + client := usage.NewUsageReportServiceClient(conn) + report := usage.UsageReportRequest{ + Version: config.Version, + RuntimeVersion: runtime.Version(), + Goos: runtime.GOOS, + Goarch: runtime.GOARCH, + Service: "gatewayd", + DevMode: app.DevMode, + Plugins: []*usage.Plugin{}, + } + app.pluginRegistry.ForEach( + func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) { + report.Plugins = append(report.GetPlugins(), &usage.Plugin{ + Name: identifier.Name, + Version: identifier.Version, + Checksum: identifier.Checksum, + }) + }, + ) + _, err = client.Report(context.Background(), &report) + if err != nil { + logger.Trace().Err(err).Msg("Failed to report usage statistics") + } + }() +} + +// startServers starts the servers. +func (app *GatewayDApp) startServers( + runCtx context.Context, span trace.Span, +) { + // Start the server. + for name, server := range app.servers { + logger := app.loggers[name] + go func( + span trace.Span, + server *network.Server, + logger zerolog.Logger, + ) { + span.AddEvent("Start server") + if err := server.Run(); err != nil { //nolint:contextcheck + logger.Error().Err(err).Msg("Failed to start server") + span.RecordError(err) + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToStartServer) + } + }(span, server, logger) + } +} + +// stopGracefully stops the server gracefully. +func (app *GatewayDApp) stopGracefully(runCtx context.Context, sig os.Signal) { + // Only allow one call to this function. + if !app.ranStopGracefully.CompareAndSwap(false, true) { + return + } + + _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server") + currentSignal := "unknown" + if sig != nil { + currentSignal = sig.String() + } + + logger := app.loggers[config.Default] + + logger.Info().Msg("Notifying the plugins that the server is shutting down") + if app.pluginRegistry != nil { + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + //nolint:contextcheck + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, + map[string]any{"signal": currentSignal}, + v1.HookName_HOOK_NAME_ON_SIGNAL, + ) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnSignal hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + } + + logger.Info().Msg("GatewayD is shutting down") + span.AddEvent("GatewayD is shutting down", trace.WithAttributes( + attribute.String("signal", currentSignal), + )) + if app.healthCheckScheduler != nil { + app.healthCheckScheduler.Stop() + app.healthCheckScheduler.Clear() + logger.Info().Msg("Stopped health check scheduler") + span.AddEvent("Stopped health check scheduler") + } + if app.metricsMerger != nil { + app.metricsMerger.Stop() + logger.Info().Msg("Stopped metrics merger") + span.AddEvent("Stopped metrics merger") + } + if app.metricsServer != nil { + //nolint:contextcheck + if err := app.metricsServer.Shutdown(context.Background()); err != nil { + logger.Error().Err(err).Msg("Failed to stop metrics server") + span.RecordError(err) + } else { + logger.Info().Msg("Stopped metrics server") + span.AddEvent("Stopped metrics server") + } + } + + properlyInitialized := false + for name, server := range app.servers { + if server.IsRunning() { + logger.Info().Str("name", name).Msg("Stopping server") + server.Shutdown() + span.AddEvent("Stopped server") + properlyInitialized = true + } + } + logger.Info().Msg("Stopped all servers") + span.AddEvent("Stopped all servers") + + if app.pluginRegistry != nil { + app.pluginRegistry.Shutdown() + logger.Info().Msg("Stopped plugin registry") + span.AddEvent("Stopped plugin registry") + } + + if app.httpServer != nil { + app.httpServer.Shutdown(runCtx) + logger.Info().Msg("Stopped HTTP Server") + span.AddEvent("Stopped HTTP Server") + } + + if app.grpcServer != nil { + app.grpcServer.Shutdown(nil) //nolint:staticcheck,contextcheck + logger.Info().Msg("Stopped gRPC Server") + span.AddEvent("Stopped gRPC Server") + } + + logger.Info().Msg("GatewayD is shutdown") + span.AddEvent("GatewayD is shutdown") + span.End() + + // If the server was properly initialized, the stop channel would have been + // listened on by the run command. So, we must close the stop channel to stop + // the app. Otherwise, the app is shutting down abruptly, so we don't need to + // close the stop channel. + if properlyInitialized { + app.stopChan <- struct{}{} + close(app.stopChan) + } +} + +// handleSignals handles the signals and stops the server gracefully. +func (app *GatewayDApp) handleSignals(runCtx context.Context, signals []os.Signal) { + signalsCh := make(chan os.Signal, 1) + signal.Notify(signalsCh, signals...) + + go func() { + for sig := range signalsCh { + app.stopGracefully(runCtx, sig) + os.Exit(0) + } + }() +} diff --git a/cmd/plugin_init.go b/cmd/plugin_init.go index 54b829c5..ef841045 100644 --- a/cmd/plugin_init.go +++ b/cmd/plugin_init.go @@ -11,6 +11,10 @@ var pluginInitCmd = &cobra.Command{ Use: "init", Short: "Create or overwrite the GatewayD plugins config", Run: func(cmd *cobra.Command, _ []string) { + force, _ := cmd.Flags().GetBool("force") + enableSentry, _ := cmd.Flags().GetBool("sentry") + pluginConfigFile, _ := cmd.Flags().GetString("plugin-config") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -37,12 +41,10 @@ var pluginInitCmd = &cobra.Command{ func init() { pluginCmd.AddCommand(pluginInitCmd) - pluginInitCmd.Flags().BoolVarP( - &force, "force", "f", false, "Force overwrite of existing config file") - pluginInitCmd.Flags().StringVarP( - &pluginConfigFile, // Already exists in run.go + pluginInitCmd.Flags().BoolP( + "force", "f", false, "Force overwrite of existing config file") + pluginInitCmd.Flags().StringP( "plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename), "Plugin config file") - pluginInitCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + pluginInitCmd.Flags().Bool("sentry", true, "Enable Sentry") } diff --git a/cmd/plugin_init_test.go b/cmd/plugin_init_test.go index 2d9abb32..7b3217bd 100644 --- a/cmd/plugin_init_test.go +++ b/cmd/plugin_init_test.go @@ -18,7 +18,8 @@ func Test_pluginInitCmd(t *testing.T) { fmt.Sprintf("Config file '%s' was created successfully.", pluginTestConfigFile), output, "plugin init command should have returned the correct output") - assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file") + assert.FileExists( + t, pluginTestConfigFile, "plugin init command should have created a config file") // Clean up. err = os.Remove(pluginTestConfigFile) diff --git a/cmd/plugin_install.go b/cmd/plugin_install.go index e7015558..e6770601 100644 --- a/cmd/plugin_install.go +++ b/cmd/plugin_install.go @@ -58,24 +58,24 @@ const ( Plugins configFileType = "plugins" ) -var ( - pluginOutputDir string - pullOnly bool - cleanup bool - update bool - backupConfig bool - noPrompt bool - pluginName string - overwriteConfig bool - skipPathSlipVerification bool -) - // pluginInstallCmd represents the plugin install command. var pluginInstallCmd = &cobra.Command{ Use: "install", Short: "Install a plugin from a local archive or a GitHub repository", Example: " gatewayd plugin install ", //nolint:lll Run: func(cmd *cobra.Command, args []string) { + enableSentry, _ := cmd.Flags().GetBool("sentry") + pluginConfigFile, _ := cmd.Flags().GetString("plugin-config") + pluginOutputDir, _ := cmd.Flags().GetString("output-dir") + pullOnly, _ := cmd.Flags().GetBool("pull-only") + cleanup, _ := cmd.Flags().GetBool("cleanup") + noPrompt, _ := cmd.Flags().GetBool("no-prompt") + update, _ := cmd.Flags().GetBool("update") + backupConfig, _ := cmd.Flags().GetBool("backup") + overwriteConfig, _ := cmd.Flags().GetBool("overwrite-config") + pluginName, _ := cmd.Flags().GetString("name") + skipPathSlipVerification, _ := cmd.Flags().GetBool("skip-path-slip-verification") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -99,7 +99,20 @@ var pluginInstallCmd = &cobra.Command{ case LocationArgs: // Install the plugin from the CLI argument. cmd.Println("Installing plugin from CLI argument") - installPlugin(cmd, args[0]) + installPlugin( + cmd, + args[0], + pluginOutputDir, + pullOnly, + cleanup, + noPrompt, + update, + backupConfig, + overwriteConfig, + skipPathSlipVerification, + pluginConfigFile, + pluginName, + ) case LocationConfig: // Read the gatewayd_plugins.yaml file. pluginsConfig, err := os.ReadFile(pluginConfigFile) @@ -197,7 +210,20 @@ var pluginInstallCmd = &cobra.Command{ // Install all the plugins from the plugins configuration file. cmd.Println("Installing plugins from plugins configuration file") for _, pluginURL := range pluginURLs { - installPlugin(cmd, pluginURL) + installPlugin( + cmd, + pluginURL, + pluginOutputDir, + pullOnly, + cleanup, + noPrompt, + update, + backupConfig, + overwriteConfig, + skipPathSlipVerification, + pluginConfigFile, + pluginName, + ) } default: cmd.Println("Invalid plugin URL or file path") @@ -208,35 +234,36 @@ var pluginInstallCmd = &cobra.Command{ func init() { pluginCmd.AddCommand(pluginInstallCmd) - pluginInstallCmd.Flags().StringVarP( - &pluginConfigFile, // Already exists in run.go + pluginInstallCmd.Flags().StringP( "plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename), "Plugin config file") - pluginInstallCmd.Flags().StringVarP( - &pluginOutputDir, "output-dir", "o", "./plugins", "Output directory for the plugin") - pluginInstallCmd.Flags().BoolVar( - &pullOnly, "pull-only", false, "Only pull the plugin, don't install it") - pluginInstallCmd.Flags().BoolVar( - &cleanup, "cleanup", true, + pluginInstallCmd.Flags().StringP( + "output-dir", "o", "./plugins", "Output directory for the plugin") + pluginInstallCmd.Flags().Bool( + "pull-only", false, "Only pull the plugin, don't install it") + pluginInstallCmd.Flags().Bool( + "cleanup", true, "Delete downloaded and extracted files after installing the plugin (except the plugin binary)") - pluginInstallCmd.Flags().BoolVar( - &noPrompt, "no-prompt", true, "Do not prompt for user input") - pluginInstallCmd.Flags().BoolVar( - &update, "update", false, "Update the plugin if it already exists") - pluginInstallCmd.Flags().BoolVar( - &backupConfig, "backup", false, "Backup the plugins configuration file before installing the plugin") - pluginInstallCmd.Flags().StringVarP( - &pluginName, "name", "n", "", "Name of the plugin (only for installing from archive files)") - pluginInstallCmd.Flags().BoolVar( - &overwriteConfig, "overwrite-config", true, "Overwrite the existing plugins configuration file (overrides --update, only used for installing from the plugins configuration file)") //nolint:lll - pluginInstallCmd.Flags().BoolVar( - &skipPathSlipVerification, "skip-path-slip-verification", false, "Skip path slip verification when extracting the plugin archive from a TRUSTED source") //nolint:lll - pluginInstallCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + pluginInstallCmd.Flags().Bool( + "no-prompt", true, "Do not prompt for user input") + pluginInstallCmd.Flags().Bool( + "update", false, "Update the plugin if it already exists") + pluginInstallCmd.Flags().Bool( + "backup", false, "Backup the plugins configuration file before installing the plugin") + pluginInstallCmd.Flags().StringP( + "name", "n", "", "Name of the plugin (only for installing from archive files)") + pluginInstallCmd.Flags().Bool( + "overwrite-config", true, "Overwrite the existing plugins configuration file (overrides --update, only used for installing from the plugins configuration file)") //nolint:lll + pluginInstallCmd.Flags().Bool( + "skip-path-slip-verification", false, "Skip path slip verification when extracting the plugin archive from a TRUSTED source") //nolint:lll + pluginInstallCmd.Flags().Bool( + "sentry", true, "Enable Sentry") } // extractZip extracts the files from a zip archive. -func extractZip(filename, dest string) ([]string, *gerr.GatewayDError) { +func extractZip( + filename, dest string, skipPathSlipVerification bool, +) ([]string, *gerr.GatewayDError) { // Open and extract the zip file. zipRc, err := zip.OpenReader(filename) if err != nil { @@ -317,7 +344,9 @@ func extractZip(filename, dest string) ([]string, *gerr.GatewayDError) { } // extractTarGz extracts the files from a tar.gz archive. -func extractTarGz(filename, dest string) ([]string, *gerr.GatewayDError) { +func extractTarGz( + filename, dest string, skipPathSlipVerification bool, +) ([]string, *gerr.GatewayDError) { // Open and extract the tar.gz file. gzipStream, err := os.Open(filename) if err != nil { @@ -530,7 +559,20 @@ func getFileExtension() Extension { } // installPlugin installs a plugin from a given URL. -func installPlugin(cmd *cobra.Command, pluginURL string) { +func installPlugin( + cmd *cobra.Command, + pluginURL string, + pluginOutputDir string, + pullOnly bool, + cleanup bool, + noPrompt bool, + update bool, + backupConfig bool, + overwriteConfig bool, + skipPathSlipVerification bool, + pluginConfigFile string, + pluginName string, +) { var ( // This is a list of files that will be deleted after the plugin is installed. toBeDeleted []string @@ -725,8 +767,10 @@ func installPlugin(cmd *cobra.Command, pluginURL string) { return } case SourceUnknown: + fallthrough default: cmd.Println("Invalid URL or file path") + return } // NOTE: The rest of the code is executed regardless of the source, @@ -806,9 +850,9 @@ func installPlugin(cmd *cobra.Command, pluginURL string) { var gErr *gerr.GatewayDError switch archiveExt { case ExtensionZip: - filenames, gErr = extractZip(pluginFilename, pluginOutputDir) + filenames, gErr = extractZip(pluginFilename, pluginOutputDir, skipPathSlipVerification) case ExtensionTarGz: - filenames, gErr = extractTarGz(pluginFilename, pluginOutputDir) + filenames, gErr = extractTarGz(pluginFilename, pluginOutputDir, skipPathSlipVerification) default: cmd.Println("Invalid archive extension") return diff --git a/cmd/plugin_install_test.go b/cmd/plugin_install_test.go index e3364c49..cb6e7871 100644 --- a/cmd/plugin_install_test.go +++ b/cmd/plugin_install_test.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "os" + "path/filepath" "runtime" "testing" @@ -10,19 +11,22 @@ import ( "github.com/stretchr/testify/require" ) -func Test_pluginInstallCmd(t *testing.T) { - pluginTestConfigFile := "./test_plugins_pluginInstallCmd.yaml" +func Test_pluginInstallCmdWithFile(t *testing.T) { + pluginTestConfigFile := "./test_plugins_pluginInstallCmdWithFile.yaml" // Create a test plugin config file. - output, err := executeCommandC(rootCmd, "plugin", "init", "-p", pluginTestConfigFile) + output, err := executeCommandC( + rootCmd, "plugin", "init", "-p", pluginTestConfigFile, "--force") require.NoError(t, err, "plugin init should not return an error") assert.Equal(t, fmt.Sprintf("Config file '%s' was created successfully.", pluginTestConfigFile), output, "plugin init command should have returned the correct output") - assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file") + assert.FileExists( + t, pluginTestConfigFile, "plugin init command should have created a config file") // Pull the plugin archive and install it. + var pluginArchivePath string pluginArchivePath, err = mustPullPlugin() require.NoError(t, err, "mustPullPlugin should not return an error") assert.FileExists(t, pluginArchivePath, "mustPullPlugin should have downloaded the plugin archive") @@ -30,9 +34,12 @@ func Test_pluginInstallCmd(t *testing.T) { // Test plugin install command. output, err = executeCommandC( rootCmd, "plugin", "install", "-p", pluginTestConfigFile, - "--update", "--backup", "--name", "gatewayd-plugin-cache", pluginArchivePath) + "--backup", "--name=gatewayd-plugin-cache", pluginArchivePath) require.NoError(t, err, "plugin install should not return an error") - assert.Equal(t, output, "Installing plugin from CLI argument\nBackup completed successfully\nPlugin binary extracted to plugins/gatewayd-plugin-cache\nPlugin installed successfully\n") //nolint:lll + assert.Contains(t, output, "Installing plugin from CLI argument") + assert.Contains(t, output, "Backup completed successfully") + assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache") + assert.Contains(t, output, "Plugin installed successfully") // See if the plugin was actually installed. output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile) @@ -40,7 +47,7 @@ func Test_pluginInstallCmd(t *testing.T) { assert.Contains(t, output, "Name: gatewayd-plugin-cache") // Clean up. - assert.FileExists(t, "plugins/gatewayd-plugin-cache") + assert.FileExists(t, "./plugins/gatewayd-plugin-cache") assert.FileExists(t, pluginTestConfigFile+BackupFileExt) assert.NoFileExists(t, "gatewayd-plugin-cache-linux-amd64-v0.2.4.tar.gz") assert.NoFileExists(t, "checksums.txt") @@ -57,28 +64,26 @@ func Test_pluginInstallCmd(t *testing.T) { func Test_pluginInstallCmdAutomatedNoOverwrite(t *testing.T) { pluginTestConfigFile := "./testdata/gatewayd_plugins.yaml" - // Reset the global variable. - pullOnly = false - - // Test plugin install command. + // Test plugin install command with overwrite disabled output, err := executeCommandC( - rootCmd, "plugin", "install", - "-p", pluginTestConfigFile, "--update", "--backup", "--overwrite-config=false") + rootCmd, "plugin", "install", "-p", pluginTestConfigFile, + "--update", "--backup", "--overwrite-config=false", "--pull-only=false") require.NoError(t, err, "plugin install should not return an error") - assert.Contains(t, output, fmt.Sprintf("/gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH)) - assert.Contains(t, output, "/checksums.txt") + + // Verify expected output for no-overwrite case + assert.Contains(t, output, "Installing plugins from plugins configuration file") + assert.Contains( + t, + output, + fmt.Sprintf("gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH)) + assert.Contains(t, output, "checksums.txt") assert.Contains(t, output, "Download completed successfully") assert.Contains(t, output, "Checksum verification passed") assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache") assert.Contains(t, output, "Plugin installed successfully") - // See if the plugin was actually installed. - output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile) - require.NoError(t, err, "plugin list should not return an error") - assert.Contains(t, output, "Name: gatewayd-plugin-cache") - // Clean up. - assert.FileExists(t, "plugins/gatewayd-plugin-cache") + assert.FileExists(t, "./plugins/gatewayd-plugin-cache") assert.FileExists(t, pluginTestConfigFile+BackupFileExt) assert.NoFileExists(t, "plugins/LICENSE") assert.NoFileExists(t, "plugins/README.md") @@ -88,3 +93,38 @@ func Test_pluginInstallCmdAutomatedNoOverwrite(t *testing.T) { require.NoError(t, os.RemoveAll("plugins/")) require.NoError(t, os.Remove(pluginTestConfigFile+BackupFileExt)) } + +func Test_pluginInstallCmdPullOnly(t *testing.T) { + pwd, err := os.Getwd() + require.NoError(t, err, "os.Getwd should not return an error") + + pluginTestConfigFile := "./testdata/gatewayd_plugins.yaml" + + // Test plugin install command in pull-only mode + output, err := executeCommandC( + rootCmd, "plugin", "install", "-p", pluginTestConfigFile, + "--update", "--backup", "--pull-only", "--overwrite-config=false") + require.NoError(t, err, "plugin install should not return an error") + + // Verify pull-only behavior + assert.Contains(t, output, "Installing plugins from plugins configuration file") + assert.Contains(t, output, fmt.Sprintf("gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH)) + assert.Contains(t, output, "Download completed successfully") + + // Should not contain installation messages in pull-only mode + assert.NotContains(t, output, "Plugin binary extracted") + assert.NotContains(t, output, "Plugin installed successfully") + + // Cleanup downloaded files + pattern := filepath.Join(pwd, fmt.Sprintf("gatewayd-plugin-cache-%s-%s-*.tar.gz", runtime.GOOS, runtime.GOARCH)) + matches, err := filepath.Glob(pattern) + require.NoError(t, err, "failed to glob plugin archives") + for _, match := range matches { + require.NoError(t, os.Remove(match)) + } + + checksumFile := filepath.Join(pwd, "checksums.txt") + if _, err := os.Stat(checksumFile); err == nil { + require.NoError(t, os.Remove(checksumFile)) + } +} diff --git a/cmd/plugin_lint.go b/cmd/plugin_lint.go index 13508b3a..6e317068 100644 --- a/cmd/plugin_lint.go +++ b/cmd/plugin_lint.go @@ -14,6 +14,9 @@ var pluginLintCmd = &cobra.Command{ Use: "lint", Short: "Lint the GatewayD plugins config", Run: func(cmd *cobra.Command, _ []string) { + enableSentry, _ := cmd.Flags().GetBool("sentry") + pluginConfigFile, _ := cmd.Flags().GetString("plugin-config") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -44,10 +47,9 @@ var pluginLintCmd = &cobra.Command{ func init() { pluginCmd.AddCommand(pluginLintCmd) - pluginLintCmd.Flags().StringVarP( - &pluginConfigFile, // Already exists in run.go + pluginLintCmd.Flags().StringP( "plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename), "Plugin config file") - pluginLintCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + pluginLintCmd.Flags().BoolP( + "sentry", "s", true, "Enable Sentry") } diff --git a/cmd/plugin_list.go b/cmd/plugin_list.go index 000d9616..e64add2d 100644 --- a/cmd/plugin_list.go +++ b/cmd/plugin_list.go @@ -9,13 +9,15 @@ import ( "github.com/spf13/cobra" ) -var onlyEnabled bool - // pluginListCmd represents the plugin list command. var pluginListCmd = &cobra.Command{ Use: "list", Short: "List the GatewayD plugins", Run: func(cmd *cobra.Command, _ []string) { + pluginConfigFile, _ := cmd.Flags().GetString("plugin-config") + onlyEnabled, _ := cmd.Flags().GetBool("only-enabled") + enableSentry, _ := cmd.Flags().GetBool("sentry") + // Enable Sentry. if enableSentry { // Initialize Sentry. @@ -42,30 +44,27 @@ var pluginListCmd = &cobra.Command{ func init() { pluginCmd.AddCommand(pluginListCmd) - pluginListCmd.Flags().StringVarP( - &pluginConfigFile, // Already exists in run.go + pluginListCmd.Flags().StringP( "plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename), "Plugin config file") - pluginListCmd.Flags().BoolVarP( - &onlyEnabled, + pluginListCmd.Flags().BoolP( "only-enabled", "e", false, "Only list enabled plugins") - pluginListCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go + pluginListCmd.Flags().BoolP("sentry", "s", true, "Enable Sentry") } func listPlugins(cmd *cobra.Command, pluginConfigFile string, onlyEnabled bool) { // Load the plugin config file. - conf := config.NewConfig(context.TODO(), config.Config{PluginConfigFile: pluginConfigFile}) - if err := conf.LoadDefaults(context.TODO()); err != nil { + conf := config.NewConfig(context.Background(), config.Config{PluginConfigFile: pluginConfigFile}) + if err := conf.LoadDefaults(context.Background()); err != nil { cmd.PrintErr(err) return } - if err := conf.LoadPluginConfigFile(context.TODO()); err != nil { + if err := conf.LoadPluginConfigFile(context.Background()); err != nil { cmd.PrintErr(err) return } - if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil { + if err := conf.UnmarshalPluginConfig(context.Background()); err != nil { cmd.PrintErr(err) return } diff --git a/cmd/plugin_list_test.go b/cmd/plugin_list_test.go index c96342ba..2492858d 100644 --- a/cmd/plugin_list_test.go +++ b/cmd/plugin_list_test.go @@ -11,6 +11,7 @@ import ( func Test_pluginListCmd(t *testing.T) { pluginTestConfigFile := "./test_plugins_pluginListCmd.yaml" + // Test plugin list command. output, err := executeCommandC(rootCmd, "plugin", "init", "-p", pluginTestConfigFile) require.NoError(t, err, "plugin init command should not have returned an error") @@ -36,6 +37,7 @@ func Test_pluginListCmdWithPlugins(t *testing.T) { // Test plugin list command. // Read the plugin config file from the root directory. pluginTestConfigFile := "../gatewayd_plugins.yaml" + output, err := executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile) require.NoError(t, err, "plugin list command should not have returned an error") assert.Contains(t, diff --git a/cmd/plugin_scaffold.go b/cmd/plugin_scaffold.go index 4a4da902..4ddad37a 100644 --- a/cmd/plugin_scaffold.go +++ b/cmd/plugin_scaffold.go @@ -5,16 +5,14 @@ import ( "github.com/spf13/cobra" ) -var ( - pluginScaffoldInputFile string - pluginScaffoldOutputDir string -) - // pluginScaffoldCmd represents the scaffold command. var pluginScaffoldCmd = &cobra.Command{ Use: "scaffold", Short: "Scaffold a plugin and store the files into a directory", Run: func(cmd *cobra.Command, _ []string) { + pluginScaffoldInputFile := cmd.Flag("input-file").Value.String() + pluginScaffoldOutputDir := cmd.Flag("output-dir").Value.String() + createdFiles, err := plugin.Scaffold(pluginScaffoldInputFile, pluginScaffoldOutputDir) if err != nil { cmd.Println("Scaffold failed: ", err) @@ -31,12 +29,10 @@ var pluginScaffoldCmd = &cobra.Command{ func init() { pluginCmd.AddCommand(pluginScaffoldCmd) - pluginScaffoldCmd.Flags().StringVarP( - &pluginScaffoldInputFile, + pluginScaffoldCmd.Flags().StringP( "input-file", "i", "input.yaml", "Plugin scaffold input file") - pluginScaffoldCmd.Flags().StringVarP( - &pluginScaffoldOutputDir, + pluginScaffoldCmd.Flags().StringP( "output-dir", "o", "./plugins", "Output directory for the scaffold") } diff --git a/cmd/plugin_scaffold_test.go b/cmd/plugin_scaffold_test.go index 7eb6b246..a31b1ef3 100644 --- a/cmd/plugin_scaffold_test.go +++ b/cmd/plugin_scaffold_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/codingsince1985/checksum" - "github.com/gatewayd-io/gatewayd/config" "github.com/gatewayd-io/gatewayd/plugin" "github.com/gatewayd-io/gatewayd/testhelpers" "github.com/spf13/cast" @@ -19,6 +18,12 @@ import ( ) func Test_pluginScaffoldCmd(t *testing.T) { + previous := EnableTestMode() + defer func() { + testMode = previous + testApp = nil + }() + // Start the test containers. ctx := context.Background() postgresHostIP1, postgresMappedPort1 := testhelpers.SetupPostgreSQLTestContainer(ctx, t) @@ -87,11 +92,9 @@ func Test_pluginScaffoldCmd(t *testing.T) { pluginTestConfigFile := filepath.Join(pluginDir, "gatewayd_plugins.yaml") - stopChan = make(chan struct{}) - var waitGroup sync.WaitGroup - waitGroup.Add(1) + waitGroup.Add(2) go func(waitGroup *sync.WaitGroup) { // Test run command. output, err := executeCommandC( @@ -107,23 +110,9 @@ func Test_pluginScaffoldCmd(t *testing.T) { waitGroup.Done() }(&waitGroup) - waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { time.Sleep(waitBeforeStop) - - StopGracefully( - context.Background(), - nil, - nil, - metricsServer, - nil, - loggers[config.Default], - servers, - stopChan, - nil, - nil, - ) - + testApp.stopGracefully(context.Background(), os.Interrupt) waitGroup.Done() }(&waitGroup) diff --git a/cmd/run.go b/cmd/run.go index 51e61241..a9798415 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -2,181 +2,31 @@ package cmd import ( "context" - "crypto/tls" - "errors" - "fmt" - "io" "log" - "net/http" - "net/url" "os" - "os/signal" - "runtime" - "strconv" "syscall" - "time" - "github.com/NYTimes/gziphandler" - sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" - sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" - v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" - "github.com/gatewayd-io/gatewayd/act" - "github.com/gatewayd-io/gatewayd/api" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/metrics" - "github.com/gatewayd-io/gatewayd/network" - "github.com/gatewayd-io/gatewayd/plugin" - "github.com/gatewayd-io/gatewayd/pool" "github.com/gatewayd-io/gatewayd/raft" "github.com/gatewayd-io/gatewayd/tracing" - usage "github.com/gatewayd-io/gatewayd/usagereport/v1" "github.com/getsentry/sentry-go" - "github.com/go-co-op/gocron" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/redis/go-redis/v9" - "github.com/rs/zerolog" "github.com/spf13/cobra" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "golang.org/x/exp/maps" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) -var _ io.Writer = &cobraCmdWriter{} - -type cobraCmdWriter struct { - *cobra.Command -} - -func (c *cobraCmdWriter) Write(p []byte) (int, error) { - c.Print(string(p)) - return len(p), nil -} - -// TODO: Get rid of the global variables. -// https://github.com/gatewayd-io/gatewayd/issues/324 var ( - enableTracing bool - enableLinting bool - collectorURL string - enableSentry bool - devMode bool - enableUsageReport bool - pluginConfigFile string - globalConfigFile string - conf *config.Config - pluginRegistry *plugin.Registry - actRegistry *act.Registry - metricsServer *http.Server - - UsageReportURL = "localhost:59091" - - loggers = make(map[string]zerolog.Logger) - pools = make(map[string]map[string]*pool.Pool) - clients = make(map[string]map[string]*config.Client) - proxies = make(map[string]map[string]*network.Proxy) - servers = make(map[string]*network.Server) - healthCheckScheduler = gocron.NewScheduler(time.UTC) - - stopChan = make(chan struct{}) + testMode bool + testApp *GatewayDApp ) -func StopGracefully( - runCtx context.Context, - sig os.Signal, - metricsMerger *metrics.Merger, - metricsServer *http.Server, - pluginRegistry *plugin.Registry, - logger zerolog.Logger, - servers map[string]*network.Server, - stopChan chan struct{}, - httpServer *api.HTTPServer, - grpcServer *api.GRPCServer, -) { - _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server") - currentSignal := "unknown" - if sig != nil { - currentSignal = sig.String() - } - - logger.Info().Msg("Notifying the plugins that the server is shutting down") - if pluginRegistry != nil { - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout) - defer cancel() - - //nolint:contextcheck - result, err := pluginRegistry.Run( - pluginTimeoutCtx, - map[string]any{"signal": currentSignal}, - v1.HookName_HOOK_NAME_ON_SIGNAL, - ) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnSignal hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck - } - } - - logger.Info().Msg("GatewayD is shutting down") - span.AddEvent("GatewayD is shutting down", trace.WithAttributes( - attribute.String("signal", currentSignal), - )) - if healthCheckScheduler != nil { - healthCheckScheduler.Stop() - healthCheckScheduler.Clear() - logger.Info().Msg("Stopped health check scheduler") - span.AddEvent("Stopped health check scheduler") - } - if metricsMerger != nil { - metricsMerger.Stop() - logger.Info().Msg("Stopped metrics merger") - span.AddEvent("Stopped metrics merger") - } - if metricsServer != nil { - //nolint:contextcheck - if err := metricsServer.Shutdown(context.Background()); err != nil { - logger.Error().Err(err).Msg("Failed to stop metrics server") - span.RecordError(err) - } else { - logger.Info().Msg("Stopped metrics server") - span.AddEvent("Stopped metrics server") - } - } - for name, server := range servers { - logger.Info().Str("name", name).Msg("Stopping server") - server.Shutdown() - span.AddEvent("Stopped server") - } - logger.Info().Msg("Stopped all servers") - if pluginRegistry != nil { - pluginRegistry.Shutdown() - logger.Info().Msg("Stopped plugin registry") - span.AddEvent("Stopped plugin registry") - } - span.End() - - if httpServer != nil { - httpServer.Shutdown(runCtx) - logger.Info().Msg("Stopped HTTP Server") - span.AddEvent("Stopped HTTP Server") - } - - if grpcServer != nil { - grpcServer.Shutdown(runCtx) - logger.Info().Msg("Stopped gRPC Server") - span.AddEvent("Stopped gRPC Server") - } - - // Close the stop channel to notify the other goroutines to stop. - stopChan <- struct{}{} - close(stopChan) +// EnableTestMode enables test mode and returns the previous value. +// This should only be used in tests. +func EnableTestMode() bool { + previous := testMode + testMode = true + return previous } // runCmd represents the run command. @@ -184,23 +34,51 @@ var runCmd = &cobra.Command{ Use: "run", Short: "Run a GatewayD instance", Run: func(cmd *cobra.Command, _ []string) { + app := NewGatewayDApp(cmd) + + // If test mode is enabled, we need to access the app instance from the test, + // so we can stop the server gracefully. + if testMode { + testApp = app + } + + runCtx, span := otel.Tracer(config.TracerName).Start(context.Background(), "GatewayD") + span.End() + + // Handle signals from the user. + app.handleSignals( + runCtx, + []os.Signal{ + os.Interrupt, + os.Kill, + syscall.SIGTERM, + syscall.SIGABRT, + syscall.SIGQUIT, + syscall.SIGHUP, + syscall.SIGINT, + }, + ) + + // Stop the server gracefully when the program terminates cleanly. + defer app.stopGracefully(runCtx, nil) + // Enable tracing with OpenTelemetry. - if enableTracing { + if app.EnableTracing { // TODO: Make this configurable. - shutdown := tracing.OTLPTracer(true, collectorURL, config.TracerName) + shutdown := tracing.OTLPTracer(true, app.CollectorURL, config.TracerName) defer func() { if err := shutdown(context.Background()); err != nil { cmd.Println(err) + app.stopGracefully(runCtx, nil) os.Exit(gerr.FailedToStartTracer) } }() } - runCtx, span := otel.Tracer(config.TracerName).Start(context.Background(), "GatewayD") span.End() // Enable Sentry. - if enableSentry { + if app.EnableSentry { _, span := otel.Tracer(config.TracerName).Start(runCtx, "Sentry") defer span.End() @@ -223,968 +101,162 @@ var runCmd = &cobra.Command{ } // Lint the configuration files before loading them. - if enableLinting { + if app.EnableLinting { _, span := otel.Tracer(config.TracerName).Start(runCtx, "Lint configuration files") defer span.End() // Lint the global configuration file and fail if it's not valid. - if err := lintConfig(Global, globalConfigFile); err != nil { + if err := lintConfig(Global, app.GlobalConfigFile); err != nil { log.Fatal(err) } // Lint the plugin configuration file and fail if it's not valid. - if err := lintConfig(Plugins, pluginConfigFile); err != nil { + if err := lintConfig(Plugins, app.PluginConfigFile); err != nil { log.Fatal(err) } } - // Load global and plugin configuration. - conf = config.NewConfig(runCtx, config.Config{GlobalConfigFile: globalConfigFile, PluginConfigFile: pluginConfigFile}) - if err := conf.InitConfig(runCtx); err != nil { + // Load the configuration files. + if err := app.loadConfig(runCtx); err != nil { log.Fatal(err) } // Create and initialize loggers from the config. - // Use cobra command cmd instead of os.Stdout for the console output. - cmdLogger := &cobraCmdWriter{cmd} - for name, cfg := range conf.Global.Loggers { - loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{ - Output: cfg.GetOutput(), - ConsoleOut: cmdLogger, - Level: config.If( - config.Exists(config.LogLevels, cfg.Level), - config.LogLevels[cfg.Level], - config.LogLevels[config.DefaultLogLevel], - ), - TimeFormat: config.If( - config.Exists(config.TimeFormats, cfg.TimeFormat), - config.TimeFormats[cfg.TimeFormat], - config.TimeFormats[config.DefaultTimeFormat], - ), - ConsoleTimeFormat: config.If( - config.Exists( - config.ConsoleTimeFormats, cfg.ConsoleTimeFormat), - config.ConsoleTimeFormats[cfg.ConsoleTimeFormat], - config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat], - ), - NoColor: cfg.NoColor, - FileName: cfg.FileName, - MaxSize: cfg.MaxSize, - MaxBackups: cfg.MaxBackups, - MaxAge: cfg.MaxAge, - Compress: cfg.Compress, - LocalTime: cfg.LocalTime, - SyslogPriority: cfg.GetSyslogPriority(), - RSyslogNetwork: cfg.RSyslogNetwork, - RSyslogAddress: cfg.RSyslogAddress, - Name: name, - }) - } - - // Set the default logger. - logger := loggers[config.Default] + // And then set the default logger. + logger := app.createLoggers(runCtx, cmd) - if devMode { + if app.DevMode { logger.Warn().Msg( "Running GatewayD in development mode (not recommended for production)") } - // Create a new act registry given the built-in signals, policies, and actions. - var publisher *act.Publisher - if conf.Plugin.ActionRedis.Enabled { - rdb := redis.NewClient(&redis.Options{ - Addr: conf.Plugin.ActionRedis.Address, - }) - var err error - publisher, err = act.NewPublisher(act.Publisher{ - Logger: logger, - RedisDB: rdb, - ChannelName: conf.Plugin.ActionRedis.Channel, - }) - if err != nil { - logger.Error().Err(err).Msg("Failed to create publisher for act registry") - os.Exit(gerr.FailedToCreateActRegistry) - } - logger.Info().Msg("Created Redis publisher for Act registry") - } - - actRegistry = act.NewActRegistry( - act.Registry{ - Signals: act.BuiltinSignals(), - Policies: act.BuiltinPolicies(), - Actions: act.BuiltinActions(), - DefaultPolicyName: conf.Plugin.DefaultPolicy, - PolicyTimeout: conf.Plugin.PolicyTimeout, - DefaultActionTimeout: conf.Plugin.ActionTimeout, - TaskPublisher: publisher, - Logger: logger, - }) - - if actRegistry == nil { - logger.Error().Msg("Failed to create act registry") + // Create the Act registry. + if err := app.createActRegistry(logger); err != nil { + logger.Error().Err(err).Msg("Failed to create act registry") + app.stopGracefully(runCtx, nil) os.Exit(gerr.FailedToCreateActRegistry) } // Load policies from the configuration file and add them to the registry. - for _, plc := range conf.Plugin.Policies { - if policy, err := sdkAct.NewPolicy( - plc.Name, plc.Policy, plc.Metadata, - ); err != nil || policy == nil { - logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy") - } else { - actRegistry.Add(policy) - } + if err := app.loadPolicies(logger); err != nil { + logger.Error().Err(err).Msg("Failed to load policies") + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToLoadPolicies) } logger.Info().Fields(map[string]any{ - "policies": maps.Keys(actRegistry.Policies), + "policies": maps.Keys(app.actRegistry.Policies), }).Msg("Policies are loaded") - // Create a new plugin registry. - // The plugins are loaded and hooks registered before the configuration is loaded. - pluginRegistry = plugin.NewRegistry( - runCtx, - plugin.Registry{ - ActRegistry: actRegistry, - Compatibility: config.If( - config.Exists( - config.CompatibilityPolicies, conf.Plugin.CompatibilityPolicy, - ), - config.CompatibilityPolicies[conf.Plugin.CompatibilityPolicy], - config.DefaultCompatibilityPolicy), - Logger: logger, - DevMode: devMode, - }, - ) + // Create the plugin registry. + app.createPluginRegistry(runCtx, logger) // Load plugins and register their hooks. - pluginRegistry.LoadPlugins(runCtx, conf.Plugin.Plugins, conf.Plugin.StartTimeout) + app.pluginRegistry.LoadPlugins( + runCtx, + app.conf.Plugin.Plugins, + app.conf.Plugin.StartTimeout, + ) // Start the metrics merger if enabled. - var metricsMerger *metrics.Merger - if conf.Plugin.EnableMetricsMerger { - metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{ - MetricsMergerPeriod: conf.Plugin.MetricsMergerPeriod, - Logger: logger, - }) - pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) { - if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled { - metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"]) - logger.Debug().Str("plugin", plugin.ID.Name).Msg( - "Added plugin to metrics merger") - } - }) - metricsMerger.Start() - } + app.startMetricsMerger(runCtx, logger) // TODO: Move this to the plugin registry. ctx, span := otel.Tracer(config.TracerName).Start(runCtx, "Plugin health check") - // Ping the plugins to check if they are alive, and remove them if they are not. - startDelay := time.Now().Add(conf.Plugin.HealthCheckPeriod) - if _, err := healthCheckScheduler.Every( - conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(func() { - _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") - defer span.End() - - var plugins []string - pluginRegistry.ForEach(func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) { - if err := plugin.Ping(); err != nil { - span.RecordError(err) - logger.Error().Err(err).Msg("Failed to ping plugin") - if conf.Plugin.EnableMetricsMerger && metricsMerger != nil { - metricsMerger.Remove(pluginId.Name) - } - pluginRegistry.Remove(pluginId) - - if !conf.Plugin.ReloadOnCrash { - return // Do not reload the plugins. - } - - // Reload the plugins and register their hooks upon crash. - logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin") - pluginConfig := conf.Plugin.GetPlugins(pluginId.Name) - if pluginConfig != nil { - pluginRegistry.LoadPlugins(runCtx, pluginConfig, conf.Plugin.StartTimeout) - } - } else { - logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin") - plugins = append(plugins, pluginId.Name) - } - }) - span.SetAttributes(attribute.StringSlice("plugins", plugins)) - }); err != nil { - logger.Error().Err(err).Msg("Failed to start plugin health check scheduler") - span.RecordError(err) - } - if pluginRegistry.Size() > 0 { - logger.Info().Str( - "healthCheckPeriod", conf.Plugin.HealthCheckPeriod.String(), - ).Msg("Starting plugin health check scheduler") - healthCheckScheduler.StartAsync() - } + // Start the health check scheduler only if there are plugins. + app.startHealthCheckScheduler(runCtx, ctx, span, logger) span.End() - // Set the plugin timeout context. - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout) - defer cancel() - - // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin. - // The plugins can modify the config and return it. - updatedGlobalConfig, err := pluginRegistry.Run( - pluginTimeoutCtx, conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") - span.RecordError(err) - } - if updatedGlobalConfig != nil { - updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) - } - - // If the config was modified by the plugins, merge it with the one loaded from the file. - // Only global configuration is merged, which means that plugins cannot modify the plugin - // configurations. - if updatedGlobalConfig != nil { - // Merge the config with the one loaded from the file (in memory). - // The changes won't be persisted to disk. - if err := conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil { - log.Fatal(err) - } + // Merge the global config with the one from the plugins. + if err := app.onConfigLoaded(runCtx, span, logger); err != nil { + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToMergeGlobalConfig) } // Start the metrics server if enabled. - // TODO: Start multiple metrics servers. For now, only one default is supported. - // I should first find a use case for those multiple metrics servers. - go func(metricsConfig *config.Metrics, logger zerolog.Logger) { - _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server") - defer span.End() - - // TODO: refactor this to a separate function. - if !metricsConfig.Enabled { - logger.Info().Msg("Metrics server is disabled") - return - } - - scheme := "http://" - if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" { - scheme = "https://" - } - - fqdn, err := url.Parse(scheme + metricsConfig.Address) - if err != nil { - logger.Error().Err(err).Msg("Failed to parse metrics address") + go func(app *GatewayDApp) { + if err := app.startMetricsServer(runCtx, logger); err != nil { + logger.Error().Err(err).Msg("Failed to start metrics server") span.RecordError(err) - return } + }(app) - address, err := url.JoinPath(fqdn.String(), metricsConfig.Path) - if err != nil { - logger.Error().Err(err).Msg("Failed to parse metrics path") - span.RecordError(err) - return - } - - // Merge the metrics from the plugins with the ones from GatewayD. - mergedMetricsHandler := func(next http.Handler) http.Handler { - handler := func(responseWriter http.ResponseWriter, request *http.Request) { - if _, err := responseWriter.Write(metricsMerger.OutputMetrics); err != nil { - logger.Error().Err(err).Msg("Failed to write metrics") - span.RecordError(err) - sentry.CaptureException(err) - } - // The WriteHeader method intentionally does nothing, to prevent a bug - // in the merging metrics that causes the headers to be written twice, - // which results in an error: "http: superfluous response.WriteHeader call". - next.ServeHTTP( - &metrics.HeaderBypassResponseWriter{ - ResponseWriter: responseWriter, - }, - request) - } - return http.HandlerFunc(handler) - } - - handler := func() http.Handler { - return promhttp.InstrumentMetricHandler( - prometheus.DefaultRegisterer, - promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ - DisableCompression: true, - }), - ) - }() - - mux := http.NewServeMux() - mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { - // Serve a static page with a link to the metrics endpoint. - if _, err := responseWriter.Write([]byte(fmt.Sprintf( - `GatewayD Prometheus Metrics ServerMetrics`, - address, - ))); err != nil { - logger.Error().Err(err).Msg("Failed to write metrics") - span.RecordError(err) - sentry.CaptureException(err) - } - }) - - if conf.Plugin.EnableMetricsMerger && metricsMerger != nil { - handler = mergedMetricsHandler(handler) - } - - readHeaderTimeout := config.If( - metricsConfig.ReadHeaderTimeout > 0, - metricsConfig.ReadHeaderTimeout, - config.DefaultReadHeaderTimeout, - ) - - // Check if the metrics server is already running before registering the handler. - if _, err = http.Get(address); err != nil { //nolint:gosec - // The timeout handler limits the nested handlers from running for too long. - mux.Handle( - metricsConfig.Path, - http.TimeoutHandler( - gziphandler.GzipHandler(handler), - readHeaderTimeout, - "The request timed out while fetching the metrics", - ), - ) - } else { - logger.Warn().Msg("Metrics server is already running, consider changing the port") - span.RecordError(err) - } - - // Create a new metrics server. - timeout := config.If( - metricsConfig.Timeout > 0, - metricsConfig.Timeout, - config.DefaultMetricsServerTimeout, - ) - metricsServer = &http.Server{ - Addr: metricsConfig.Address, - Handler: mux, - ReadHeaderTimeout: readHeaderTimeout, - ReadTimeout: timeout, - WriteTimeout: timeout, - IdleTimeout: timeout, - } - - logger.Info().Fields(map[string]any{ - "address": address, - "timeout": timeout.String(), - "readHeaderTimeout": readHeaderTimeout.String(), - }).Msg("Metrics are exposed") - - if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" { - // Set up TLS. - metricsServer.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS13, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - }, - CipherSuites: []uint16{ - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - }, - } - metricsServer.TLSNextProto = make( - map[string]func(*http.Server, *tls.Conn, http.Handler)) - logger.Debug().Msg("Metrics server is running with TLS") - - // Start the metrics server with TLS. - if err = metricsServer.ListenAndServeTLS( - metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) { - logger.Error().Err(err).Msg("Failed to start metrics server") - span.RecordError(err) - } - } else { - // Start the metrics server without TLS. - if err = metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - logger.Error().Err(err).Msg("Failed to start metrics server") - span.RecordError(err) - } - } - }(conf.Global.Metrics[config.Default], logger) - - // This is a notification hook, so we don't care about the result. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), conf.Plugin.Timeout) - defer cancel() - - if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok { - result, err := pluginRegistry.Run( - pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) - } - } else { - logger.Error().Msg("Failed to get loggers from config") - } - - // Declare httpServer and grpcServer here as it is used in the StopGracefully function ahead of their definition. - var httpServer *api.HTTPServer - var grpcServer *api.GRPCServer + // Run the OnNewLogger hook. + app.onNewLogger(span, logger) _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create pools and clients") - // Create and initialize pools of connections. - for configGroupName, configGroup := range conf.Global.Pools { - for configBlockName, cfg := range configGroup { - logger := loggers[configGroupName] - // Check if the pool size is greater than zero. - currentPoolSize := config.If( - cfg.Size > 0, - // Check if the pool size is greater than the minimum pool size. - config.If( - cfg.Size > config.MinimumPoolSize, - cfg.Size, - config.MinimumPoolSize, - ), - config.DefaultPoolSize, - ) - - if _, ok := pools[configGroupName]; !ok { - pools[configGroupName] = make(map[string]*pool.Pool) - } - pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize) - - span.AddEvent("Create pool", trace.WithAttributes( - attribute.String("name", configBlockName), - attribute.Int("size", currentPoolSize), - )) - - if _, ok := clients[configGroupName]; !ok { - clients[configGroupName] = make(map[string]*config.Client) - } - - // Get client config from the config file. - if clientConfig, ok := conf.Global.Clients[configGroupName][configBlockName]; !ok { - // This ensures that the default client config is used if the pool name is not - // found in the clients section. - clients[configGroupName][configBlockName] = conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] - } else { - // Merge the default client config with the one from the pool. - clients[configGroupName][configBlockName] = clientConfig - } - - // Fill the missing and zero values with the default ones. - clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If( - clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0, - clients[configGroupName][configBlockName].TCPKeepAlivePeriod, - config.DefaultTCPKeepAlivePeriod, - ) - clients[configGroupName][configBlockName].ReceiveDeadline = config.If( - clients[configGroupName][configBlockName].ReceiveDeadline > 0, - clients[configGroupName][configBlockName].ReceiveDeadline, - config.DefaultReceiveDeadline, - ) - clients[configGroupName][configBlockName].ReceiveTimeout = config.If( - clients[configGroupName][configBlockName].ReceiveTimeout > 0, - clients[configGroupName][configBlockName].ReceiveTimeout, - config.DefaultReceiveTimeout, - ) - clients[configGroupName][configBlockName].SendDeadline = config.If( - clients[configGroupName][configBlockName].SendDeadline > 0, - clients[configGroupName][configBlockName].SendDeadline, - config.DefaultSendDeadline, - ) - clients[configGroupName][configBlockName].ReceiveChunkSize = config.If( - clients[configGroupName][configBlockName].ReceiveChunkSize > 0, - clients[configGroupName][configBlockName].ReceiveChunkSize, - config.DefaultChunkSize, - ) - clients[configGroupName][configBlockName].DialTimeout = config.If( - clients[configGroupName][configBlockName].DialTimeout > 0, - clients[configGroupName][configBlockName].DialTimeout, - config.DefaultDialTimeout, - ) - - // Add clients to the pool. - for range currentPoolSize { - clientConfig := clients[configGroupName][configBlockName] - clientConfig.GroupName = configGroupName - clientConfig.BlockName = configBlockName - client := network.NewClient( - runCtx, clientConfig, logger, - network.NewRetry( - network.Retry{ - Retries: clientConfig.Retries, - Backoff: config.If( - clientConfig.Backoff > 0, - clientConfig.Backoff, - config.DefaultBackoff, - ), - BackoffMultiplier: clientConfig.BackoffMultiplier, - DisableBackoffCaps: clientConfig.DisableBackoffCaps, - Logger: loggers[configBlockName], - }, - ), - ) - - if client != nil { - eventOptions := trace.WithAttributes( - attribute.String("name", configBlockName), - attribute.String("group", configGroupName), - attribute.String("network", client.Network), - attribute.String("address", client.Address), - attribute.Int("receiveChunkSize", client.ReceiveChunkSize), - attribute.String("receiveDeadline", client.ReceiveDeadline.String()), - attribute.String("receiveTimeout", client.ReceiveTimeout.String()), - attribute.String("sendDeadline", client.SendDeadline.String()), - attribute.String("dialTimeout", client.DialTimeout.String()), - attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), - attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), - attribute.String("localAddress", client.LocalAddr()), - attribute.String("remoteAddress", client.RemoteAddr()), - attribute.Int("retries", clientConfig.Retries), - attribute.String("backoff", client.Retry().Backoff.String()), - attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), - attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), - ) - if client.ID != "" { - eventOptions = trace.WithAttributes( - attribute.String("id", client.ID), - ) - } - - span.AddEvent("Create client", eventOptions) - - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() - - clientCfg := map[string]any{ - "id": client.ID, - "name": configBlockName, - "group": configGroupName, - "network": client.Network, - "address": client.Address, - "receiveChunkSize": client.ReceiveChunkSize, - "receiveDeadline": client.ReceiveDeadline.String(), - "receiveTimeout": client.ReceiveTimeout.String(), - "sendDeadline": client.SendDeadline.String(), - "dialTimeout": client.DialTimeout.String(), - "tcpKeepAlive": client.TCPKeepAlive, - "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), - "localAddress": client.LocalAddr(), - "remoteAddress": client.RemoteAddr(), - "retries": clientConfig.Retries, - "backoff": client.Retry().Backoff.String(), - "backoffMultiplier": clientConfig.BackoffMultiplier, - "disableBackoffCaps": clientConfig.DisableBackoffCaps, - } - result, err := pluginRegistry.Run( - pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) - } - - err = pools[configGroupName][configBlockName].Put(client.ID, client) - if err != nil { - logger.Error().Err(err).Msg("Failed to add client to the pool") - span.RecordError(err) - } - } else { - logger.Error().Msg("Failed to create client, please check the configuration") - go func() { - // Wait for the stop signal to exit gracefully. - // This prevents the program from waiting indefinitely - // after the StopGracefully function is called. - <-stopChan - os.Exit(gerr.FailedToCreateClient) - }() - StopGracefully( - runCtx, - nil, - metricsMerger, - metricsServer, - pluginRegistry, - logger, - servers, - stopChan, - httpServer, - grpcServer, - ) - } - } - - // Verify that the pool is properly populated. - logger.Info().Fields(map[string]any{ - "name": configBlockName, - "count": strconv.Itoa(pools[configGroupName][configBlockName].Size()), - }).Msg("There are clients available in the pool") - - if pools[configGroupName][configBlockName].Size() != currentPoolSize { - logger.Error().Msg( - "The pool size is incorrect, either because " + - "the clients cannot connect due to no network connectivity " + - "or the server is not running. exiting...") - pluginRegistry.Shutdown() - os.Exit(gerr.FailedToInitializePool) - } - - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() - result, err := pluginRegistry.Run( - pluginTimeoutCtx, - map[string]any{"name": configBlockName, "size": currentPoolSize}, - v1.HookName_HOOK_NAME_ON_NEW_POOL) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) - } - } + // Create pools and clients. + if err := app.createPoolAndClients(runCtx, span); err != nil { + logger.Error().Err(err).Msg("Failed to create pools and clients") + span.RecordError(err) + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToCreatePoolAndClients) } span.End() _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create proxies") - // Create and initialize prefork proxies with each pool of clients. - for configGroupName, configGroup := range conf.Global.Proxies { - for configBlockName, cfg := range configGroup { - logger := loggers[configGroupName] - clientConfig := clients[configGroupName][configBlockName] - - // Fill the missing and zero value with the default one. - cfg.HealthCheckPeriod = config.If( - cfg.HealthCheckPeriod > 0, - cfg.HealthCheckPeriod, - config.DefaultHealthCheckPeriod, - ) - - if _, ok := proxies[configGroupName]; !ok { - proxies[configGroupName] = make(map[string]*network.Proxy) - } - - proxies[configGroupName][configBlockName] = network.NewProxy( - runCtx, - network.Proxy{ - GroupName: configGroupName, - BlockName: configBlockName, - AvailableConnections: pools[configGroupName][configBlockName], - PluginRegistry: pluginRegistry, - HealthCheckPeriod: cfg.HealthCheckPeriod, - ClientConfig: clientConfig, - Logger: logger, - PluginTimeout: conf.Plugin.Timeout, - }, - ) - - span.AddEvent("Create proxy", trace.WithAttributes( - attribute.String("name", configBlockName), - attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()), - )) - - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() - - if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok { - result, err := pluginRegistry.Run( - pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) - } - } else { - logger.Error().Msg("Failed to get proxy from config") - } - } - } + // Create proxies. + app.createProxies(runCtx, span) span.End() _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create Raft Node") defer span.End() - raftNode, originalErr := raft.NewRaftNode(logger, conf.Global.Raft) + // Create the Raft node. + raftNode, originalErr := raft.NewRaftNode(logger, app.conf.Global.Raft) if originalErr != nil { logger.Error().Err(originalErr).Msg("Failed to start raft node") span.RecordError(originalErr) - pluginRegistry.Shutdown() + app.stopGracefully(runCtx, nil) os.Exit(gerr.FailedToStartRaftNode) } _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create servers") - // Create and initialize servers. - for name, cfg := range conf.Global.Servers { - logger := loggers[name] - - var serverProxies []network.IProxy - for _, proxy := range proxies[name] { - serverProxies = append(serverProxies, proxy) - } - - servers[name] = network.NewServer( - runCtx, - network.Server{ - GroupName: name, - Network: cfg.Network, - Address: cfg.Address, - TickInterval: config.If( - cfg.TickInterval > 0, - cfg.TickInterval, - config.DefaultTickInterval, - ), - Options: network.Option{ - // Can be used to send keepalive messages to the client. - EnableTicker: cfg.EnableTicker, - }, - Proxies: serverProxies, - Logger: logger, - PluginRegistry: pluginRegistry, - PluginTimeout: conf.Plugin.Timeout, - EnableTLS: cfg.EnableTLS, - CertFile: cfg.CertFile, - KeyFile: cfg.KeyFile, - HandshakeTimeout: cfg.HandshakeTimeout, - LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, - LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules, - LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash, - RaftNode: raftNode, - }, - ) - span.AddEvent("Create server", trace.WithAttributes( - attribute.String("name", name), - attribute.String("network", cfg.Network), - attribute.String("address", cfg.Address), - attribute.String("tickInterval", cfg.TickInterval.String()), - attribute.String("pluginTimeout", conf.Plugin.Timeout.String()), - attribute.Bool("enableTLS", cfg.EnableTLS), - attribute.String("certFile", cfg.CertFile), - attribute.String("keyFile", cfg.KeyFile), - attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()), - )) - - pluginTimeoutCtx, cancel = context.WithTimeout( - context.Background(), conf.Plugin.Timeout) - defer cancel() - - if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok { - result, err := pluginRegistry.Run( - pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") - span.RecordError(err) - } - if result != nil { - _ = pluginRegistry.ActRegistry.RunAll(result) - } - } else { - logger.Error().Msg("Failed to get the servers configuration") - } - } + // Create servers. + app.createServers(runCtx, span, raftNode) span.End() - // Start the HTTP and gRPC APIs. - if conf.Global.API.Enabled { - apiOptions := api.Options{ - Logger: logger, - GRPCNetwork: conf.Global.API.GRPCNetwork, - GRPCAddress: conf.Global.API.GRPCAddress, - HTTPAddress: conf.Global.API.HTTPAddress, - Servers: servers, - RaftNode: raftNode, - } - - apiObj := &api.API{ - Options: &apiOptions, - Config: conf, - PluginRegistry: pluginRegistry, - Pools: pools, - Proxies: proxies, - Servers: servers, - } - grpcServer = api.NewGRPCServer( - runCtx, - api.GRPCServer{ - API: apiObj, - HealthChecker: &api.HealthChecker{Servers: servers}, - }, - ) - if grpcServer != nil { - go grpcServer.Start() - logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API") + // Start the API servers. + app.startAPIServers(runCtx, logger, raftNode) - httpServer = api.NewHTTPServer(&apiOptions) - go httpServer.Start() - - logger.Info().Fields( - map[string]any{ - "network": apiOptions.GRPCNetwork, - "address": apiOptions.GRPCAddress, - }, - ).Msg("Started the gRPC Server") - } - } - - // Report usage statistics. - if enableUsageReport { - go func() { - conn, err := grpc.NewClient( - UsageReportURL, - grpc.WithTransportCredentials( - credentials.NewTLS( - &tls.Config{ - MinVersion: tls.VersionTLS12, - }, - ), - ), - ) - if err != nil { - logger.Trace().Err(err).Msg( - "Failed to dial to the gRPC server for usage reporting") - } - defer func(conn *grpc.ClientConn) { - err := conn.Close() - if err != nil { - logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service") - } - }(conn) - - client := usage.NewUsageReportServiceClient(conn) - report := usage.UsageReportRequest{ - Version: config.Version, - RuntimeVersion: runtime.Version(), - Goos: runtime.GOOS, - Goarch: runtime.GOARCH, - Service: "gatewayd", - DevMode: devMode, - Plugins: []*usage.Plugin{}, - } - pluginRegistry.ForEach( - func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) { - report.Plugins = append(report.GetPlugins(), &usage.Plugin{ - Name: identifier.Name, - Version: identifier.Version, - Checksum: identifier.Checksum, - }) - }, - ) - _, err = client.Report(context.Background(), &report) - if err != nil { - logger.Trace().Err(err).Msg("Failed to report usage statistics") - } - }() - } - - // Shutdown the server gracefully. - signals := []os.Signal{ - os.Interrupt, - os.Kill, - syscall.SIGTERM, - syscall.SIGABRT, - syscall.SIGQUIT, - syscall.SIGHUP, - syscall.SIGINT, - } - signalsCh := make(chan os.Signal, 1) - signal.Notify(signalsCh, signals...) - go func(pluginRegistry *plugin.Registry, - logger zerolog.Logger, - servers map[string]*network.Server, - metricsMerger *metrics.Merger, - metricsServer *http.Server, - stopChan chan struct{}, - httpServer *api.HTTPServer, - grpcServer *api.GRPCServer, - ) { - for sig := range signalsCh { - for _, s := range signals { - if sig != s { - StopGracefully( - runCtx, - sig, - metricsMerger, - metricsServer, - pluginRegistry, - logger, - servers, - stopChan, - httpServer, - grpcServer, - ) - os.Exit(0) - } - } - } - }(pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan, httpServer, grpcServer) + // Report usage. + app.reportUsage(logger) _, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers") - // Start the server. - for name, server := range servers { - logger := loggers[name] - go func( - span trace.Span, - server *network.Server, - logger zerolog.Logger, - healthCheckScheduler *gocron.Scheduler, - metricsMerger *metrics.Merger, - pluginRegistry *plugin.Registry, - ) { - span.AddEvent("Start server") - if err := server.Run(); err != nil { - logger.Error().Err(err).Msg("Failed to start server") - span.RecordError(err) - healthCheckScheduler.Clear() - if metricsMerger != nil { - metricsMerger.Stop() - } - server.Shutdown() - pluginRegistry.Shutdown() - os.Exit(gerr.FailedToStartServer) - } - }(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry) - } + // Start the servers. + app.startServers(runCtx, span) + span.End() // Wait for the server to shut down. - <-stopChan + <-app.stopChan }, } func init() { rootCmd.AddCommand(runCmd) - runCmd.Flags().StringVarP( - &globalConfigFile, + runCmd.Flags().StringP( "config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename), "Global config file") - runCmd.Flags().StringVarP( - &pluginConfigFile, + runCmd.Flags().StringP( "plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename), "Plugin config file") - runCmd.Flags().BoolVar( - &devMode, "dev", false, "Enable development mode for plugin development") - runCmd.Flags().BoolVar( - &enableTracing, "tracing", false, "Enable tracing with OpenTelemetry via gRPC") - runCmd.Flags().StringVar( - &collectorURL, "collector-url", "localhost:4317", - "Collector URL of OpenTelemetry gRPC endpoint") - runCmd.Flags().BoolVar( - &enableSentry, "sentry", true, "Enable Sentry") - runCmd.Flags().BoolVar( - &enableUsageReport, "usage-report", true, "Enable usage report") - runCmd.Flags().BoolVar( - &enableLinting, "lint", true, "Enable linting of configuration files") + runCmd.Flags().Bool("dev", false, "Enable development mode for plugin development") + runCmd.Flags().Bool("tracing", false, "Enable tracing with OpenTelemetry via gRPC") + runCmd.Flags().String( + "collector-url", "localhost:4317", "Collector URL of OpenTelemetry gRPC endpoint") + runCmd.Flags().Bool("sentry", true, "Enable Sentry") + runCmd.Flags().Bool("usage-report", true, "Enable usage report") + runCmd.Flags().Bool("lint", true, "Enable linting of configuration files") + runCmd.Flags().Bool("metrics-merger", true, "Enable metrics merger") } diff --git a/cmd/run_test.go b/cmd/run_test.go index 37fd54c5..61f0b269 100644 --- a/cmd/run_test.go +++ b/cmd/run_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/gatewayd-io/gatewayd/config" "github.com/gatewayd-io/gatewayd/testhelpers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,6 +18,12 @@ var ( ) func Test_runCmd(t *testing.T) { + previous := EnableTestMode() + defer func() { + testMode = previous + testApp = nil + }() + postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t) postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port() t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress) @@ -41,14 +46,13 @@ func Test_runCmd(t *testing.T) { // Check that the config file was created. assert.FileExists(t, globalTestConfigFile, "configInitCmd should create a config file") - stopChan = make(chan struct{}) - var waitGroup sync.WaitGroup waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { // Test run command. - output, err := executeCommandC(rootCmd, "run", "-c", globalTestConfigFile, "-p", pluginTestConfigFile) + output, err := executeCommandC( + rootCmd, "run", "-c", globalTestConfigFile, "-p", pluginTestConfigFile) require.NoError(t, err, "run command should not have returned an error") // Print the output for debugging purposes. @@ -60,23 +64,11 @@ func Test_runCmd(t *testing.T) { waitGroup.Done() }(&waitGroup) + // Stop the server after a delay waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { time.Sleep(waitBeforeStop) - - StopGracefully( - context.Background(), - nil, - nil, - metricsServer, - nil, - loggers[config.Default], - servers, - stopChan, - nil, - nil, - ) - + testApp.stopGracefully(context.Background(), os.Interrupt) waitGroup.Done() }(&waitGroup) @@ -88,6 +80,12 @@ func Test_runCmd(t *testing.T) { // Test_runCmdWithTLS tests the run command with TLS enabled on the server. func Test_runCmdWithTLS(t *testing.T) { + previous := EnableTestMode() + defer func() { + testMode = previous + testApp = nil + }() + postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t) postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port() t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress) @@ -104,8 +102,6 @@ func Test_runCmdWithTLS(t *testing.T) { require.NoError(t, err, "plugin init command should not have returned an error") assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file") - stopChan = make(chan struct{}) - var waitGroup sync.WaitGroup // TODO: Test client certificate authentication. @@ -126,23 +122,11 @@ func Test_runCmdWithTLS(t *testing.T) { waitGroup.Done() }(&waitGroup) + // Stop the server after a delay waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { time.Sleep(waitBeforeStop) - - StopGracefully( - context.Background(), - nil, - nil, - metricsServer, - nil, - loggers[config.Default], - servers, - stopChan, - nil, - nil, - ) - + testApp.stopGracefully(context.Background(), os.Interrupt) waitGroup.Done() }(&waitGroup) @@ -153,6 +137,12 @@ func Test_runCmdWithTLS(t *testing.T) { // Test_runCmdWithMultiTenancy tests the run command with multi-tenancy enabled. func Test_runCmdWithMultiTenancy(t *testing.T) { + previous := EnableTestMode() + defer func() { + testMode = previous + testApp = nil + }() + postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t) postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port() t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress) @@ -172,8 +162,6 @@ func Test_runCmdWithMultiTenancy(t *testing.T) { require.NoError(t, err, "plugin init command should not have returned an error") assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file") - stopChan = make(chan struct{}) - var waitGroup sync.WaitGroup waitGroup.Add(1) @@ -196,23 +184,11 @@ func Test_runCmdWithMultiTenancy(t *testing.T) { waitGroup.Done() }(&waitGroup) + // Stop the server after a delay waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { time.Sleep(waitBeforeStop) - - StopGracefully( - context.Background(), - nil, - nil, - metricsServer, - nil, - loggers[config.Default], - servers, - stopChan, - nil, - nil, - ) - + testApp.stopGracefully(context.Background(), os.Interrupt) waitGroup.Done() }(&waitGroup) @@ -222,6 +198,12 @@ func Test_runCmdWithMultiTenancy(t *testing.T) { } func Test_runCmdWithCachePlugin(t *testing.T) { + previous := EnableTestMode() + defer func() { + testMode = previous + testApp = nil + }() + postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t) postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port() t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress) @@ -233,9 +215,6 @@ func Test_runCmdWithCachePlugin(t *testing.T) { globalTestConfigFile := "./test_global_runCmdWithCachePlugin.yaml" pluginTestConfigFile := "./test_plugins_runCmdWithCachePlugin.yaml" - // TODO: Remove this once these global variables are removed from cmd/run.go. - // https://github.com/gatewayd-io/gatewayd/issues/324 - stopChan = make(chan struct{}) // Create a test plugins config file. _, err := executeCommandC(rootCmd, "plugin", "init", "--force", "-p", pluginTestConfigFile) @@ -258,15 +237,16 @@ func Test_runCmdWithCachePlugin(t *testing.T) { rootCmd, "plugin", "install", "-p", pluginTestConfigFile, "--update", "--backup", "--overwrite-config=true", "--name", "gatewayd-plugin-cache", pluginArchivePath) require.NoError(t, err, "plugin install should not return an error") - assert.Equal(t, output, "Installing plugin from CLI argument\nBackup completed successfully\nPlugin binary extracted to plugins/gatewayd-plugin-cache\nPlugin installed successfully\n") //nolint:lll + assert.Contains(t, output, "Installing plugin from CLI argument") + assert.Contains(t, output, "Backup completed successfully") + assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache") + assert.Contains(t, output, "Plugin installed successfully") // See if the plugin was actually installed. output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile) require.NoError(t, err, "plugin list should not return an error") assert.Contains(t, output, "Name: gatewayd-plugin-cache") - stopChan = make(chan struct{}) - var waitGroup sync.WaitGroup waitGroup.Add(1) @@ -284,23 +264,11 @@ func Test_runCmdWithCachePlugin(t *testing.T) { waitGroup.Done() }(&waitGroup) + // Stop the server after a delay waitGroup.Add(1) go func(waitGroup *sync.WaitGroup) { - time.Sleep(waitBeforeStop * 2) - - StopGracefully( - context.Background(), - nil, - nil, - metricsServer, - nil, - loggers[config.Default], - servers, - stopChan, - nil, - nil, - ) - + time.Sleep(waitBeforeStop) + testApp.stopGracefully(context.Background(), os.Interrupt) waitGroup.Done() }(&waitGroup) diff --git a/config/getters_test.go b/config/getters_test.go index 3f71eb37..a9b5e3c7 100644 --- a/config/getters_test.go +++ b/config/getters_test.go @@ -35,9 +35,9 @@ func TestGetDefaultConfigFilePath(t *testing.T) { // TestFilter tests the Filter function. func TestFilter(t *testing.T) { // Load config from the default config file. - conf := NewConfig(context.TODO(), + conf := NewConfig(context.Background(), Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - err := conf.InitConfig(context.TODO()) + err := conf.InitConfig(context.Background()) require.Nil(t, err) assert.NotEmpty(t, conf.Global) @@ -55,9 +55,9 @@ func TestFilter(t *testing.T) { // TestFilterWithMissingGroupName tests the Filter function with a missing group name. func TestFilterWithMissingGroupName(t *testing.T) { // Load config from the default config file. - conf := NewConfig(context.TODO(), + conf := NewConfig(context.Background(), Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"}) - err := conf.InitConfig(context.TODO()) + err := conf.InitConfig(context.Background()) require.Nil(t, err) assert.NotEmpty(t, conf.Global) diff --git a/errors/errors.go b/errors/errors.go index 61bb0e32..9cce47db 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -212,12 +212,3 @@ var ( // Unwrapped errors. ErrLoggerRequired = errors.New("terminate action requires a logger parameter") ) - -const ( - FailedToCreateClient = 1 - FailedToInitializePool = 2 - FailedToStartServer = 3 - FailedToStartTracer = 4 - FailedToCreateActRegistry = 5 - FailedToStartRaftNode = 6 -) diff --git a/errors/exit_codes.go b/errors/exit_codes.go new file mode 100644 index 00000000..14ed89f3 --- /dev/null +++ b/errors/exit_codes.go @@ -0,0 +1,11 @@ +package errors + +const ( + FailedToStartServer = 1 + FailedToStartTracer = 2 + FailedToCreateActRegistry = 3 + FailedToLoadPolicies = 4 + FailedToStartRaftNode = 5 + FailedToMergeGlobalConfig = 6 + FailedToCreatePoolAndClients = 7 +) diff --git a/network/server_test.go b/network/server_test.go index 0ea89696..fcb65de4 100644 --- a/network/server_test.go +++ b/network/server_test.go @@ -109,11 +109,10 @@ func TestRunServer(t *testing.T) { }, ) assert.NotNil(t, server) - assert.Zero(t, server.connections) assert.Zero(t, server.CountConnections()) assert.Empty(t, server.host) assert.Empty(t, server.port) - assert.False(t, server.running.Load()) + assert.False(t, server.IsRunning()) var waitGroup sync.WaitGroup waitGroup.Add(2) @@ -215,8 +214,8 @@ func TestRunServer(t *testing.T) { <-time.After(100 * time.Millisecond) // check server status and connections - assert.False(t, server.running.Load()) - assert.Zero(t, server.connections) + assert.False(t, server.IsRunning()) + assert.Zero(t, server.CountConnections()) // Read the log file and check if the log file contains the expected log messages. require.FileExists(t, "server_test.log")