diff --git a/.gitignore b/.gitignore
index afd8e42b..6c193a0b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,8 @@ gatewayd-files/
cmd/gatewayd-plugin-cache-linux-amd64-*
tempo-data
+# Raft files
+raft/node*/
+
+# Accidentally installed plugin
+plugins/gatewayd-plugin-cache
diff --git a/.golangci.yaml b/.golangci.yaml
index 60a34508..b3c826c8 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -4,6 +4,7 @@ linters:
enable-all: true
disable:
- cyclop
+ - dupl
- wsl
- godox
- gochecknoglobals
diff --git a/api/api_helpers_test.go b/api/api_helpers_test.go
index 7e5bba9b..381cb55c 100644
--- a/api/api_helpers_test.go
+++ b/api/api_helpers_test.go
@@ -12,7 +12,7 @@ import (
)
// getAPIConfig returns a new API configuration with all the necessary components.
-func getAPIConfig() *API {
+func getAPIConfig(httpAddr, grpcAddr string) *API {
logger := zerolog.New(nil)
defaultPool := pool.NewPool(context.Background(), config.DefaultPoolSize)
pluginReg := plugin.NewRegistry(
@@ -60,8 +60,8 @@ func getAPIConfig() *API {
return &API{
Options: &Options{
GRPCNetwork: "tcp",
- GRPCAddress: "localhost:19090",
- HTTPAddress: "localhost:18080",
+ GRPCAddress: grpcAddr,
+ HTTPAddress: httpAddr,
Logger: logger,
Servers: servers,
},
diff --git a/api/api_test.go b/api/api_test.go
index a1f0287e..378030da 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -30,9 +30,9 @@ func TestGetVersion(t *testing.T) {
func TestGetGlobalConfig(t *testing.T) {
// Load config from the default config file.
- conf := config.NewConfig(context.TODO(),
+ conf := config.NewConfig(context.Background(),
config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- gerr := conf.InitConfig(context.TODO())
+ gerr := conf.InitConfig(context.Background())
require.Nil(t, gerr)
assert.NotEmpty(t, conf.Global)
@@ -55,9 +55,9 @@ func TestGetGlobalConfig(t *testing.T) {
func TestGetGlobalConfigWithGroupName(t *testing.T) {
// Load config from the default config file.
- conf := config.NewConfig(context.TODO(),
+ conf := config.NewConfig(context.Background(),
config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- gerr := conf.InitConfig(context.TODO())
+ gerr := conf.InitConfig(context.Background())
require.Nil(t, gerr)
assert.NotEmpty(t, conf.Global)
@@ -88,9 +88,9 @@ func TestGetGlobalConfigWithGroupName(t *testing.T) {
func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) {
// Load config from the default config file.
- conf := config.NewConfig(context.TODO(),
+ conf := config.NewConfig(context.Background(),
config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- gerr := conf.InitConfig(context.TODO())
+ gerr := conf.InitConfig(context.Background())
require.Nil(t, gerr)
assert.NotEmpty(t, conf.Global)
@@ -106,9 +106,9 @@ func TestGetGlobalConfigWithNonExistingGroupName(t *testing.T) {
func TestGetPluginConfig(t *testing.T) {
// Load config from the default config file.
- conf := config.NewConfig(context.TODO(),
+ conf := config.NewConfig(context.Background(),
config.Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- gerr := conf.InitConfig(context.TODO())
+ gerr := conf.InitConfig(context.Background())
require.Nil(t, gerr)
assert.NotEmpty(t, conf.Global)
@@ -135,7 +135,7 @@ func TestGetPlugins(t *testing.T) {
Logger: zerolog.Logger{},
})
pluginRegistry := plugin.NewRegistry(
- context.TODO(),
+ context.Background(),
plugin.Registry{
ActRegistry: actRegistry,
Compatibility: config.Loose,
@@ -190,7 +190,7 @@ func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) {
Logger: zerolog.Logger{},
})
pluginRegistry := plugin.NewRegistry(
- context.TODO(),
+ context.Background(),
plugin.Registry{
ActRegistry: actRegistry,
Compatibility: config.Loose,
@@ -212,7 +212,7 @@ func TestGetPluginsWithEmptyPluginRegistry(t *testing.T) {
func TestPools(t *testing.T) {
api := API{
Pools: map[string]map[string]*pool.Pool{
- config.Default: {config.DefaultConfigurationBlock: pool.NewPool(context.TODO(), config.EmptyPoolCapacity)},
+ config.Default: {config.DefaultConfigurationBlock: pool.NewPool(context.Background(), config.EmptyPoolCapacity)},
},
ctx: context.Background(),
}
@@ -247,13 +247,13 @@ func TestGetProxies(t *testing.T) {
Network: config.DefaultNetwork,
Address: postgresAddress,
}
- client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil)
+ client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil)
require.NotNil(t, client)
- newPool := pool.NewPool(context.TODO(), 1)
+ newPool := pool.NewPool(context.Background(), 1)
assert.Nil(t, newPool.Put(client.ID, client))
proxy := network.NewProxy(
- context.TODO(),
+ context.Background(),
network.Proxy{
AvailableConnections: newPool,
HealthCheckPeriod: config.DefaultHealthCheckPeriod,
@@ -299,13 +299,13 @@ func TestGetServers(t *testing.T) {
Network: config.DefaultNetwork,
Address: postgresAddress,
}
- client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil)
- newPool := pool.NewPool(context.TODO(), 1)
+ client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil)
+ newPool := pool.NewPool(context.Background(), 1)
require.NotNil(t, newPool)
assert.Nil(t, newPool.Put(client.ID, client))
proxy := network.NewProxy(
- context.TODO(),
+ context.Background(),
network.Proxy{
AvailableConnections: newPool,
HealthCheckPeriod: config.DefaultHealthCheckPeriod,
@@ -330,7 +330,7 @@ func TestGetServers(t *testing.T) {
})
pluginRegistry := plugin.NewRegistry(
- context.TODO(),
+ context.Background(),
plugin.Registry{
ActRegistry: actRegistry,
Compatibility: config.Loose,
@@ -340,7 +340,7 @@ func TestGetServers(t *testing.T) {
)
server := network.NewServer(
- context.TODO(),
+ context.Background(),
network.Server{
Network: config.DefaultNetwork,
Address: postgresAddress,
diff --git a/api/grpc_server.go b/api/grpc_server.go
index a12744db..9a2eb824 100644
--- a/api/grpc_server.go
+++ b/api/grpc_server.go
@@ -2,6 +2,7 @@ package api
import (
"context"
+ "errors"
"net"
v1 "github.com/gatewayd-io/gatewayd/api/v1"
@@ -41,18 +42,23 @@ func NewGRPCServer(ctx context.Context, server GRPCServer) *GRPCServer {
// Start starts the gRPC server.
func (s *GRPCServer) Start() {
- s.start(s.API, s.grpcServer, s.listener)
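+	// Serve blocks until the listener is closed; net.ErrClosed during shutdown is expected.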
+ if err := s.grpcServer.Serve(s.listener); err != nil && !errors.Is(err, net.ErrClosed) {
+ s.API.Options.Logger.Err(err).Msg("failed to start gRPC API")
+ }
}
// Shutdown shuts down the gRPC server.
-func (s *GRPCServer) Shutdown(_ context.Context) {
- s.shutdown(s.grpcServer)
+func (s *GRPCServer) Shutdown(context.Context) {
+ if err := s.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+ s.API.Options.Logger.Err(err).Msg("failed to close listener")
+ }
+ s.grpcServer.GracefulStop()
}
// createGRPCAPI creates a new gRPC API server and listener.
func createGRPCAPI(api *API, healthchecker *HealthChecker) (*grpc.Server, net.Listener) {
listener, err := net.Listen(api.Options.GRPCNetwork, api.Options.GRPCAddress)
- if err != nil {
+ if err != nil && !errors.Is(err, net.ErrClosed) {
api.Options.Logger.Err(err).Msg("failed to start gRPC API")
return nil, nil
}
@@ -64,15 +70,3 @@ func createGRPCAPI(api *API, healthchecker *HealthChecker) (*grpc.Server, net.Li
return grpcServer, listener
}
-
-// start starts the gRPC API.
-func (s *GRPCServer) start(api *API, grpcServer *grpc.Server, listener net.Listener) {
- if err := grpcServer.Serve(listener); err != nil {
- api.Options.Logger.Err(err).Msg("failed to start gRPC API")
- }
-}
-
-// shutdown shuts down the gRPC API.
-func (s *GRPCServer) shutdown(grpcServer *grpc.Server) {
- grpcServer.GracefulStop()
-}
diff --git a/api/grpc_server_test.go b/api/grpc_server_test.go
index 16f2c449..f7c0c95e 100644
--- a/api/grpc_server_test.go
+++ b/api/grpc_server_test.go
@@ -14,7 +14,10 @@ import (
// Test_GRPC_Server tests the gRPC server.
func Test_GRPC_Server(t *testing.T) {
- api := getAPIConfig()
+ api := getAPIConfig(
+ "localhost:18081",
+ "localhost:19091",
+ )
healthchecker := &HealthChecker{Servers: api.Servers}
grpcServer := NewGRPCServer(
context.Background(), GRPCServer{API: api, HealthChecker: healthchecker})
@@ -25,7 +28,7 @@ func Test_GRPC_Server(t *testing.T) {
}(grpcServer)
grpcClient, err := grpc.NewClient(
- "localhost:19090", grpc.WithTransportCredentials(insecure.NewCredentials()))
+ "localhost:19091", grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err)
defer grpcClient.Close()
@@ -36,5 +39,5 @@ func Test_GRPC_Server(t *testing.T) {
assert.Equal(t, config.Version, resp.GetVersion())
assert.Equal(t, config.VersionInfo(), resp.GetVersionInfo())
- grpcServer.Shutdown(context.Background())
+ grpcServer.Shutdown(nil) //nolint:staticcheck
}
diff --git a/api/healthcheck_test.go b/api/healthcheck_test.go
index 74b0988f..d304b14f 100644
--- a/api/healthcheck_test.go
+++ b/api/healthcheck_test.go
@@ -24,13 +24,13 @@ func Test_Healthchecker(t *testing.T) {
Network: config.DefaultNetwork,
Address: postgresAddress,
}
- client := network.NewClient(context.TODO(), clientConfig, zerolog.Logger{}, nil)
- newPool := pool.NewPool(context.TODO(), 1)
+ client := network.NewClient(context.Background(), clientConfig, zerolog.Logger{}, nil)
+ newPool := pool.NewPool(context.Background(), 1)
require.NotNil(t, newPool)
assert.Nil(t, newPool.Put(client.ID, client))
proxy := network.NewProxy(
- context.TODO(),
+ context.Background(),
network.Proxy{
AvailableConnections: newPool,
HealthCheckPeriod: config.DefaultHealthCheckPeriod,
@@ -55,7 +55,7 @@ func Test_Healthchecker(t *testing.T) {
})
pluginRegistry := plugin.NewRegistry(
- context.TODO(),
+ context.Background(),
plugin.Registry{
ActRegistry: actRegistry,
Compatibility: config.Loose,
@@ -75,10 +75,10 @@ func Test_Healthchecker(t *testing.T) {
}()
server := network.NewServer(
- context.TODO(),
+ context.Background(),
network.Server{
Network: config.DefaultNetwork,
- Address: postgresAddress,
+ Address: "127.0.0.1:15432",
TickInterval: config.DefaultTickInterval,
Options: network.Option{
EnableTicker: false,
@@ -99,13 +99,14 @@ func Test_Healthchecker(t *testing.T) {
raftNode: raftHelper.Node,
}
assert.NotNil(t, healthchecker)
- hcr, err := healthchecker.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{})
+ hcr, err := healthchecker.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{})
assert.NoError(t, err)
assert.NotNil(t, hcr)
assert.Equal(t, grpc_health_v1.HealthCheckResponse_NOT_SERVING, hcr.GetStatus())
go func(t *testing.T, server *network.Server) {
t.Helper()
+
if err := server.Run(); err != nil {
t.Errorf("server.Run() error = %v", err)
}
@@ -113,7 +114,7 @@ func Test_Healthchecker(t *testing.T) {
time.Sleep(1 * time.Second)
// Test for SERVING status
- hcr, err = healthchecker.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{})
+ hcr, err = healthchecker.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{})
assert.NoError(t, err)
assert.NotNil(t, hcr)
assert.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, hcr.GetStatus())
@@ -121,4 +122,14 @@ func Test_Healthchecker(t *testing.T) {
err = healthchecker.Watch(&grpc_health_v1.HealthCheckRequest{}, nil)
assert.Error(t, err)
assert.Equal(t, "rpc error: code = Unimplemented desc = not implemented", err.Error())
+
+ server.Shutdown()
+ pluginRegistry.Shutdown()
+
+ // Wait for the server to stop.
+ <-time.After(100 * time.Millisecond)
+
+	// Check the server status and connections.
+ assert.False(t, server.IsRunning())
+ assert.Zero(t, server.CountConnections())
}
diff --git a/api/http_server.go b/api/http_server.go
index 596214b3..c516e00d 100644
--- a/api/http_server.go
+++ b/api/http_server.go
@@ -40,12 +40,17 @@ func NewHTTPServer(options *Options) *HTTPServer {
// Start starts the HTTP server.
func (s *HTTPServer) Start() {
- s.start(s.options, s.httpServer)
+ // Start HTTP server (and proxy calls to gRPC server endpoint)
+ if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ s.options.Logger.Err(err).Msg("failed to start HTTP API")
+ }
}
// Shutdown shuts down the HTTP server.
func (s *HTTPServer) Shutdown(ctx context.Context) {
- s.shutdown(ctx, s.httpServer, s.logger)
+ if err := s.httpServer.Shutdown(ctx); err != nil {
+ s.logger.Err(err).Msg("failed to shutdown HTTP API")
+ }
}
// CreateHTTPAPI creates a new HTTP API.
@@ -113,18 +118,3 @@ func createHTTPAPI(options *Options) *http.Server {
return server
}
-
-// start starts the HTTP API.
-func (s *HTTPServer) start(options *Options, server *http.Server) {
- // Start HTTP server (and proxy calls to gRPC server endpoint)
- if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
- options.Logger.Err(err).Msg("failed to start HTTP API")
- }
-}
-
-// shutdown shuts down the HTTP API.
-func (s *HTTPServer) shutdown(ctx context.Context, server *http.Server, logger zerolog.Logger) {
- if err := server.Shutdown(ctx); err != nil {
- logger.Err(err).Msg("failed to shutdown HTTP API")
- }
-}
diff --git a/api/http_server_test.go b/api/http_server_test.go
index 77a6e36c..bc6092fc 100644
--- a/api/http_server_test.go
+++ b/api/http_server_test.go
@@ -15,29 +15,27 @@ import (
// Test_HTTP_Server tests the HTTP to gRPC gateway.
func Test_HTTP_Server(t *testing.T) {
- api := getAPIConfig()
+ api := getAPIConfig(
+ "localhost:18082",
+ "localhost:19092",
+ )
healthchecker := &HealthChecker{Servers: api.Servers}
grpcServer := NewGRPCServer(
context.Background(), GRPCServer{API: api, HealthChecker: healthchecker})
- assert.NotNil(t, grpcServer)
- httpServer := NewHTTPServer(api.Options)
- assert.NotNil(t, httpServer)
+	require.NotNil(t, grpcServer)
+	go grpcServer.Start()
- go func(grpcServer *GRPCServer) {
- grpcServer.Start()
- }(grpcServer)
-
- go func(httpServer *HTTPServer) {
- httpServer.Start()
- }(httpServer)
+	httpServer := NewHTTPServer(api.Options)
+	require.NotNil(t, httpServer)
+	go httpServer.Start()
- time.Sleep(1 * time.Second) // Wait for the servers to start.
+ time.Sleep(time.Second)
// Check version via the gRPC server.
req, err := http.NewRequestWithContext(
context.Background(),
http.MethodGet,
- "http://localhost:18080/v1/GatewayDPluginService/Version",
+ "http://localhost:18082/v1/GatewayDPluginService/Version",
nil,
)
require.NoError(t, err)
@@ -56,7 +54,7 @@ func Test_HTTP_Server(t *testing.T) {
req, err = http.NewRequestWithContext(
context.Background(),
http.MethodGet,
- "http://localhost:18080/healthz",
+ "http://localhost:18082/healthz",
nil,
)
require.NoError(t, err)
@@ -73,7 +71,7 @@ func Test_HTTP_Server(t *testing.T) {
req, err = http.NewRequestWithContext(
context.Background(),
http.MethodGet,
- "http://localhost:18080/version",
+ "http://localhost:18082/version",
nil,
)
require.NoError(t, err)
@@ -87,6 +85,6 @@ func Test_HTTP_Server(t *testing.T) {
assert.Equal(t, len(config.Version), len(respBodyBytes))
assert.Equal(t, config.Version, string(respBodyBytes))
- grpcServer.Shutdown(context.Background())
+ grpcServer.Shutdown(nil) //nolint:staticcheck
httpServer.Shutdown(context.Background())
}
diff --git a/cmd/config_init.go b/cmd/config_init.go
index 57c33dd9..e2064ac2 100644
--- a/cmd/config_init.go
+++ b/cmd/config_init.go
@@ -6,13 +6,15 @@ import (
"github.com/spf13/cobra"
)
-var force bool
-
// configInitCmd represents the plugin init command.
var configInitCmd = &cobra.Command{
Use: "init",
Short: "Create or overwrite the GatewayD global config",
Run: func(cmd *cobra.Command, _ []string) {
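+		// Read flag values from the command itself instead of package-level
+		// variables shared with run.go.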
+ force, _ := cmd.Flags().GetBool("force")
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+ globalConfigFile, _ := cmd.Flags().GetString("config")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -39,12 +41,10 @@ var configInitCmd = &cobra.Command{
func init() {
configCmd.AddCommand(configInitCmd)
- configInitCmd.Flags().BoolVarP(
- &force, "force", "f", false, "Force overwrite of existing config file")
- configInitCmd.Flags().StringVarP(
- &globalConfigFile, // Already exists in run.go
+ configInitCmd.Flags().BoolP(
+ "force", "f", false, "Force overwrite of existing config file")
+ configInitCmd.Flags().StringP(
"config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename),
"Global config file")
- configInitCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ configInitCmd.Flags().Bool("sentry", true, "Enable Sentry")
}
diff --git a/cmd/config_lint.go b/cmd/config_lint.go
index fa52e314..fb7e0de5 100644
--- a/cmd/config_lint.go
+++ b/cmd/config_lint.go
@@ -14,6 +14,9 @@ var configLintCmd = &cobra.Command{
Use: "lint",
Short: "Lint the GatewayD global config",
Run: func(cmd *cobra.Command, _ []string) {
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+ globalConfigFile, _ := cmd.Flags().GetString("config")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -44,10 +47,8 @@ var configLintCmd = &cobra.Command{
func init() {
configCmd.AddCommand(configLintCmd)
- configLintCmd.Flags().StringVarP(
- &globalConfigFile, // Already exists in run.go
+ configLintCmd.Flags().StringP(
"config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename),
"Global config file")
- configLintCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ configLintCmd.Flags().Bool("sentry", true, "Enable Sentry")
}
diff --git a/cmd/configs.go b/cmd/configs.go
index 5baf9185..7152ca2f 100644
--- a/cmd/configs.go
+++ b/cmd/configs.go
@@ -27,7 +27,7 @@ func generateConfig(
GlobalKoanf: koanf.New("."),
PluginKoanf: koanf.New("."),
}
- if err := conf.LoadDefaults(context.TODO()); err != nil {
+ if err := conf.LoadDefaults(context.Background()); err != nil {
logger.Fatal(err)
}
@@ -73,28 +73,28 @@ func lintConfig(fileType configFileType, configFile string) *gerr.GatewayDError
var conf *config.Config
switch fileType {
case Global:
- conf = config.NewConfig(context.TODO(), config.Config{GlobalConfigFile: configFile})
- if err := conf.LoadDefaults(context.TODO()); err != nil {
+ conf = config.NewConfig(context.Background(), config.Config{GlobalConfigFile: configFile})
+ if err := conf.LoadDefaults(context.Background()); err != nil {
return err
}
- if err := conf.LoadGlobalConfigFile(context.TODO()); err != nil {
+ if err := conf.LoadGlobalConfigFile(context.Background()); err != nil {
return err
}
- if err := conf.ConvertKeysToLowercase(context.TODO()); err != nil {
+ if err := conf.ConvertKeysToLowercase(context.Background()); err != nil {
return err
}
- if err := conf.UnmarshalGlobalConfig(context.TODO()); err != nil {
+ if err := conf.UnmarshalGlobalConfig(context.Background()); err != nil {
return err
}
case Plugins:
- conf = config.NewConfig(context.TODO(), config.Config{PluginConfigFile: configFile})
- if err := conf.LoadDefaults(context.TODO()); err != nil {
+ conf = config.NewConfig(context.Background(), config.Config{PluginConfigFile: configFile})
+ if err := conf.LoadDefaults(context.Background()); err != nil {
return err
}
- if err := conf.LoadPluginConfigFile(context.TODO()); err != nil {
+ if err := conf.LoadPluginConfigFile(context.Background()); err != nil {
return err
}
- if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil {
+ if err := conf.UnmarshalPluginConfig(context.Background()); err != nil {
return err
}
default:
diff --git a/cmd/gatewayd_app.go b/cmd/gatewayd_app.go
new file mode 100644
index 00000000..fdd4c49d
--- /dev/null
+++ b/cmd/gatewayd_app.go
@@ -0,0 +1,1143 @@
+package cmd
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "os/signal"
+ "runtime"
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/NYTimes/gziphandler"
+ sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
+ sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
+ v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
+ "github.com/gatewayd-io/gatewayd/act"
+ "github.com/gatewayd-io/gatewayd/api"
+ "github.com/gatewayd-io/gatewayd/config"
+ gerr "github.com/gatewayd-io/gatewayd/errors"
+ "github.com/gatewayd-io/gatewayd/logging"
+ "github.com/gatewayd-io/gatewayd/metrics"
+ "github.com/gatewayd-io/gatewayd/network"
+ "github.com/gatewayd-io/gatewayd/plugin"
+ "github.com/gatewayd-io/gatewayd/pool"
+ "github.com/gatewayd-io/gatewayd/raft"
+ usage "github.com/gatewayd-io/gatewayd/usagereport/v1"
+ "github.com/getsentry/sentry-go"
+ "github.com/go-co-op/gocron"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+ "github.com/redis/go-redis/v9"
+ "github.com/rs/zerolog"
+ "github.com/spf13/cobra"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+)
+
+var _ io.Writer = &cobraCmdWriter{}
+
+type cobraCmdWriter struct {
+ *cobra.Command
+}
+
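+// Write implements io.Writer by forwarding the bytes to the cobra command's
+// Print method, so console log output is written through the command.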
+func (c *cobraCmdWriter) Write(p []byte) (int, error) {
+ c.Print(string(p))
+ return len(p), nil
+}
+
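+// UsageReportURL is the address of the gRPC service that receives usage reports.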
+var UsageReportURL = "localhost:59091"
+
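+// DefaultMetricsServerProbeTimeout bounds the probe request that checks whether
+// a metrics server is already listening on the configured address.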
+const (
+ DefaultMetricsServerProbeTimeout = 5 * time.Second
+)
+
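+// GatewayDApp bundles the CLI flags, configuration, registries, servers, and
+// runtime state required to run a GatewayD instance.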
+type GatewayDApp struct {
+ EnableTracing bool
+ EnableSentry bool
+ EnableLinting bool
+ EnableUsageReport bool
+ DevMode bool
+ CollectorURL string
+ PluginConfigFile string
+ GlobalConfigFile string
+
+ conf *config.Config
+ pluginRegistry *plugin.Registry
+ actRegistry *act.Registry
+ metricsServer *http.Server
+ metricsMerger *metrics.Merger
+ httpServer *api.HTTPServer
+ grpcServer *api.GRPCServer
+
+ loggers map[string]zerolog.Logger
+ pools map[string]map[string]*pool.Pool
+ clients map[string]map[string]*config.Client
+ proxies map[string]map[string]*network.Proxy
+ servers map[string]*network.Server
+ healthCheckScheduler *gocron.Scheduler
+	stopChan             chan struct{} // closed after a graceful shutdown completes
+	ranStopGracefully    *atomic.Bool  // ensures stopGracefully only runs once
+}
+
+// NewGatewayDApp creates a new GatewayDApp instance.
+func NewGatewayDApp(cmd *cobra.Command) *GatewayDApp {
+ app := GatewayDApp{
+ loggers: make(map[string]zerolog.Logger),
+ pools: make(map[string]map[string]*pool.Pool),
+ clients: make(map[string]map[string]*config.Client),
+ proxies: make(map[string]map[string]*network.Proxy),
+ servers: make(map[string]*network.Server),
+ healthCheckScheduler: gocron.NewScheduler(time.UTC),
+ stopChan: make(chan struct{}),
+ ranStopGracefully: &atomic.Bool{},
+ }
+ app.EnableTracing, _ = cmd.Flags().GetBool("enable-tracing")
+ app.EnableSentry, _ = cmd.Flags().GetBool("enable-sentry")
+ app.EnableUsageReport, _ = cmd.Flags().GetBool("enable-usage-report")
+ app.EnableLinting, _ = cmd.Flags().GetBool("enable-linting")
+ app.DevMode, _ = cmd.Flags().GetBool("dev")
+ app.CollectorURL, _ = cmd.Flags().GetString("collector-url")
+ app.GlobalConfigFile, _ = cmd.Flags().GetString("config")
+ app.PluginConfigFile, _ = cmd.Flags().GetString("plugin-config")
+ return &app
+}
+
+// loadConfig loads global and plugin configuration.
+func (app *GatewayDApp) loadConfig(runCtx context.Context) error {
+ app.conf = config.NewConfig(runCtx,
+ config.Config{
+ GlobalConfigFile: app.GlobalConfigFile,
+ PluginConfigFile: app.PluginConfigFile,
+ },
+ )
+ if err := app.conf.InitConfig(runCtx); err != nil {
+ return err
+ }
+ return nil
+}
+
+// createLoggers creates loggers from the config.
+func (app *GatewayDApp) createLoggers(
+ runCtx context.Context, cmd *cobra.Command,
+) zerolog.Logger {
+ // Use cobra command cmd instead of os.Stdout for the console output.
+ cmdLogger := &cobraCmdWriter{cmd}
+
+ // Create a logger for each tenant.
+ for name, cfg := range app.conf.Global.Loggers {
+ app.loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{
+ Output: cfg.GetOutput(),
+ ConsoleOut: cmdLogger,
+ Level: config.If(
+ config.Exists(config.LogLevels, cfg.Level),
+ config.LogLevels[cfg.Level],
+ config.LogLevels[config.DefaultLogLevel],
+ ),
+ TimeFormat: config.If(
+ config.Exists(config.TimeFormats, cfg.TimeFormat),
+ config.TimeFormats[cfg.TimeFormat],
+ config.TimeFormats[config.DefaultTimeFormat],
+ ),
+ ConsoleTimeFormat: config.If(
+ config.Exists(
+ config.ConsoleTimeFormats, cfg.ConsoleTimeFormat),
+ config.ConsoleTimeFormats[cfg.ConsoleTimeFormat],
+ config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat],
+ ),
+ NoColor: cfg.NoColor,
+ FileName: cfg.FileName,
+ MaxSize: cfg.MaxSize,
+ MaxBackups: cfg.MaxBackups,
+ MaxAge: cfg.MaxAge,
+ Compress: cfg.Compress,
+ LocalTime: cfg.LocalTime,
+ SyslogPriority: cfg.GetSyslogPriority(),
+ RSyslogNetwork: cfg.RSyslogNetwork,
+ RSyslogAddress: cfg.RSyslogAddress,
+ Name: name,
+ })
+ }
+ return app.loggers[config.Default]
+}
+
+// createActRegistry creates a new act registry given
+// the built-in signals, policies, and actions.
+func (app *GatewayDApp) createActRegistry(logger zerolog.Logger) error {
+ // Create a new act registry given the built-in signals, policies, and actions.
+ var publisher *act.Publisher
+ if app.conf.Plugin.ActionRedis.Enabled {
+ rdb := redis.NewClient(&redis.Options{
+ Addr: app.conf.Plugin.ActionRedis.Address,
+ })
+ var err error
+ publisher, err = act.NewPublisher(act.Publisher{
+ Logger: logger,
+ RedisDB: rdb,
+ ChannelName: app.conf.Plugin.ActionRedis.Channel,
+ })
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to create publisher for act registry")
+ return err //nolint:wrapcheck
+ }
+ logger.Info().Msg("Created Redis publisher for Act registry")
+ }
+
+ app.actRegistry = act.NewActRegistry(
+ act.Registry{
+ Signals: act.BuiltinSignals(),
+ Policies: act.BuiltinPolicies(),
+ Actions: act.BuiltinActions(),
+ DefaultPolicyName: app.conf.Plugin.DefaultPolicy,
+ PolicyTimeout: app.conf.Plugin.PolicyTimeout,
+ DefaultActionTimeout: app.conf.Plugin.ActionTimeout,
+ TaskPublisher: publisher,
+ Logger: logger,
+ },
+ )
+
+ return nil
+}
+
+// loadPolicies loads policies from the configuration file and
+// adds them to the registry.
+func (app *GatewayDApp) loadPolicies(logger zerolog.Logger) error {
+ // Load policies from the configuration file and add them to the registry.
+ for _, plc := range app.conf.Plugin.Policies {
+ policy, err := sdkAct.NewPolicy(
+ plc.Name, plc.Policy, plc.Metadata,
+ )
+ if err != nil || policy == nil {
+ logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy")
+ return err //nolint:wrapcheck
+ }
+ app.actRegistry.Add(policy)
+ }
+
+ return nil
+}
+
+// createPluginRegistry creates a new plugin registry.
+func (app *GatewayDApp) createPluginRegistry(runCtx context.Context, logger zerolog.Logger) {
+ // Create a new plugin registry.
+ // The plugins are loaded and hooks registered before the configuration is loaded.
+ app.pluginRegistry = plugin.NewRegistry(
+ runCtx,
+ plugin.Registry{
+ ActRegistry: app.actRegistry,
+ Compatibility: config.If(
+ config.Exists(
+ config.CompatibilityPolicies, app.conf.Plugin.CompatibilityPolicy,
+ ),
+ config.CompatibilityPolicies[app.conf.Plugin.CompatibilityPolicy],
+ config.DefaultCompatibilityPolicy),
+ Logger: logger,
+ DevMode: app.DevMode,
+ },
+ )
+}
+
+// startMetricsMerger starts the metrics merger if enabled.
+func (app *GatewayDApp) startMetricsMerger(runCtx context.Context, logger zerolog.Logger) {
+ // Start the metrics merger if enabled.
+ if app.conf.Plugin.EnableMetricsMerger {
+ app.metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{
+ MetricsMergerPeriod: app.conf.Plugin.MetricsMergerPeriod,
+ Logger: logger,
+ })
+ app.pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) {
+ if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled {
+ app.metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"])
+ logger.Debug().Str("plugin", plugin.ID.Name).Msg(
+ "Added plugin to metrics merger")
+ }
+ })
+ app.metricsMerger.Start() //nolint:contextcheck
+ }
+}
+
+// startHealthCheckScheduler starts the health check scheduler if enabled.
+func (app *GatewayDApp) startHealthCheckScheduler(
+ runCtx, ctx context.Context, span trace.Span, logger zerolog.Logger,
+) {
+ // Ping the plugins to check if they are alive, and remove them if they are not.
+ startDelay := time.Now().Add(app.conf.Plugin.HealthCheckPeriod)
+ if _, err := app.healthCheckScheduler.Every(
+ app.conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(
+ func() {
+ _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check")
+ defer span.End()
+
+ var plugins []string
+ app.pluginRegistry.ForEach(
+ func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) {
+ if err := plugin.Ping(); err != nil {
+ span.RecordError(err)
+ logger.Error().Err(err).Msg("Failed to ping plugin")
+ if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil {
+ app.metricsMerger.Remove(pluginId.Name)
+ }
+ app.pluginRegistry.Remove(pluginId)
+
+ if !app.conf.Plugin.ReloadOnCrash {
+ return // Do not reload the plugins.
+ }
+
+ // Reload the plugins and register their hooks upon crash.
+ logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin")
+ pluginConfig := app.conf.Plugin.GetPlugins(pluginId.Name)
+ if pluginConfig != nil {
+ app.pluginRegistry.LoadPlugins(runCtx, pluginConfig, app.conf.Plugin.StartTimeout)
+ }
+ } else {
+ logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin")
+ plugins = append(plugins, pluginId.Name)
+ }
+ })
+ span.SetAttributes(attribute.StringSlice("plugins", plugins))
+ }); err != nil {
+ logger.Error().Err(err).Msg("Failed to start plugin health check scheduler")
+ span.RecordError(err)
+ }
+
+ // Start the health check scheduler only if there are plugins.
+ if app.pluginRegistry.Size() > 0 {
+ logger.Info().Str(
+ "healthCheckPeriod", app.conf.Plugin.HealthCheckPeriod.String(),
+ ).Msg("Starting plugin health check scheduler")
+ app.healthCheckScheduler.StartAsync()
+ }
+}
+
+// onConfigLoaded runs the OnConfigLoaded hook and
+// merges the global config with the one from the plugins.
+func (app *GatewayDApp) onConfigLoaded(
+ runCtx context.Context, span trace.Span, logger zerolog.Logger,
+) error {
+ // Set the plugin timeout context.
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin.
+ // The plugins can modify the config and return it.
+ updatedGlobalConfig, err := app.pluginRegistry.Run( //nolint:contextcheck
+ pluginTimeoutCtx, app.conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks")
+ span.RecordError(err)
+ }
+ if updatedGlobalConfig != nil {
+ updatedGlobalConfig = app.pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) //nolint:contextcheck
+ }
+
+ // If the config was modified by the plugins, merge it with the one loaded from the file.
+ // Only global configuration is merged, which means that plugins cannot modify the plugin
+ // configurations.
+ if updatedGlobalConfig != nil {
+ // Merge the config with the one loaded from the file (in memory).
+ // The changes won't be persisted to disk.
+ if err := app.conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil {
+ logger.Error().Err(err).Msg("Failed to merge global config")
+ span.RecordError(err)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// startMetricsServer starts the metrics server if enabled.
+func (app *GatewayDApp) startMetricsServer(
+ runCtx context.Context, logger zerolog.Logger,
+) error {
+ // Start the metrics server if enabled.
+	// TODO: Start multiple metrics servers. For now, only one default server is
+	// supported; a concrete use case for multiple metrics servers is needed first.
+ _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server")
+ defer span.End()
+
+ metricsConfig := app.conf.Global.Metrics[config.Default]
+
+ // TODO: refactor this to a separate function.
+ if !metricsConfig.Enabled {
+ logger.Info().Msg("Metrics server is disabled")
+ return nil
+ }
+
+ scheme := "http://"
+ if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" {
+ scheme = "https://"
+ }
+
+ fqdn, err := url.Parse(scheme + metricsConfig.Address)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to parse metrics address")
+ span.RecordError(err)
+ return err //nolint:wrapcheck
+ }
+
+ address, err := url.JoinPath(fqdn.String(), metricsConfig.Path)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to parse metrics path")
+ span.RecordError(err)
+ return err //nolint:wrapcheck
+ }
+
+ // Merge the metrics from the plugins with the ones from GatewayD.
+ mergedMetricsHandler := func(next http.Handler) http.Handler {
+ handler := func(responseWriter http.ResponseWriter, request *http.Request) {
+ if _, err := responseWriter.Write(app.metricsMerger.OutputMetrics); err != nil {
+ logger.Error().Err(err).Msg("Failed to write metrics")
+ span.RecordError(err)
+ sentry.CaptureException(err)
+ }
+			// The WriteHeader method intentionally does nothing, to prevent a bug
+			// in metrics merging that writes the headers twice, which results in
+			// the error "http: superfluous response.WriteHeader call".
+ next.ServeHTTP(
+ &metrics.HeaderBypassResponseWriter{
+ ResponseWriter: responseWriter,
+ },
+ request)
+ }
+ return http.HandlerFunc(handler)
+ }
+
+ handler := func() http.Handler {
+ return promhttp.InstrumentMetricHandler(
+ prometheus.DefaultRegisterer,
+ promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
+ DisableCompression: true,
+ }),
+ )
+ }()
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) {
+ // Serve a static page with a link to the metrics endpoint.
+ if _, err := responseWriter.Write([]byte(fmt.Sprintf(
+		`<html><head><title>GatewayD Prometheus Metrics Server</title></head><body><a href="%s">Metrics</a></body></html>`,
+ address,
+ ))); err != nil {
+ logger.Error().Err(err).Msg("Failed to write metrics")
+ span.RecordError(err)
+ sentry.CaptureException(err)
+ }
+ })
+
+ if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil {
+ handler = mergedMetricsHandler(handler)
+ }
+
+ readHeaderTimeout := config.If(
+ metricsConfig.ReadHeaderTimeout > 0,
+ metricsConfig.ReadHeaderTimeout,
+ config.DefaultReadHeaderTimeout,
+ )
+
+ // Check if the metrics server is already running before registering the handler
+ ctx, cancel := context.WithTimeout(context.Background(), DefaultMetricsServerProbeTimeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, address, nil) //nolint:contextcheck
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to create request to check metrics server")
+ span.RecordError(err)
+ }
+
+ if resp, err := http.DefaultClient.Do(req); err != nil {
+ // The timeout handler limits the nested handlers from running for too long.
+ mux.Handle(
+ metricsConfig.Path,
+ http.TimeoutHandler(
+ gziphandler.GzipHandler(handler),
+ readHeaderTimeout,
+ "The request timed out while fetching the metrics",
+ ),
+ )
+ } else {
+ if resp != nil && resp.Body != nil {
+ defer resp.Body.Close()
+ }
+ logger.Warn().Msg("Metrics server is already running, consider changing the port")
+ span.RecordError(err)
+ }
+
+ // Create a new metrics server.
+ timeout := config.If(
+ metricsConfig.Timeout > 0,
+ metricsConfig.Timeout,
+ config.DefaultMetricsServerTimeout,
+ )
+ app.metricsServer = &http.Server{
+ Addr: metricsConfig.Address,
+ Handler: mux,
+ ReadHeaderTimeout: readHeaderTimeout,
+ ReadTimeout: timeout,
+ WriteTimeout: timeout,
+ IdleTimeout: timeout,
+ }
+
+ logger.Info().Fields(map[string]any{
+ "address": address,
+ "timeout": timeout.String(),
+ "readHeaderTimeout": readHeaderTimeout.String(),
+ }).Msg("Metrics are exposed")
+
+ if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" {
+ // Set up TLS.
+ app.metricsServer.TLSConfig = &tls.Config{
+ MinVersion: tls.VersionTLS13,
+ CurvePreferences: []tls.CurveID{
+ tls.CurveP521,
+ tls.CurveP384,
+ tls.CurveP256,
+ },
+ CipherSuites: []uint16{
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+ },
+ }
+ app.metricsServer.TLSNextProto = make(
+ map[string]func(*http.Server, *tls.Conn, http.Handler))
+ logger.Debug().Msg("Metrics server is running with TLS")
+
+ // Start the metrics server with TLS.
+ if err = app.metricsServer.ListenAndServeTLS(
+ metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) {
+ logger.Error().Err(err).Msg("Failed to start metrics server")
+ span.RecordError(err)
+ }
+ } else {
+ // Start the metrics server without TLS.
+ if err = app.metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
+ logger.Error().Err(err).Msg("Failed to start metrics server")
+ span.RecordError(err)
+ }
+ }
+
+ return nil
+}
+
+// onNewLogger runs the OnNewLogger hook.
+func (app *GatewayDApp) onNewLogger(
+ span trace.Span, logger zerolog.Logger,
+) {
+ // This is a notification hook, so we don't care about the result.
+ pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ if data, ok := app.conf.GlobalKoanf.Get("loggers").(map[string]any); ok {
+ result, err := app.pluginRegistry.Run(
+ pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result)
+ }
+ } else {
+ logger.Error().Msg("Failed to get loggers from config")
+ }
+}
+
+// createPoolAndClients creates pools of connections and clients.
+func (app *GatewayDApp) createPoolAndClients(
+ runCtx context.Context, span trace.Span,
+) error {
+ // Create and initialize pools of connections.
+ for configGroupName, configGroup := range app.conf.Global.Pools {
+ for configBlockName, cfg := range configGroup {
+ logger := app.loggers[configGroupName]
+ // Check if the pool size is greater than zero.
+ currentPoolSize := config.If(
+ cfg.Size > 0,
+ // Check if the pool size is greater than the minimum pool size.
+ config.If(
+ cfg.Size > config.MinimumPoolSize,
+ cfg.Size,
+ config.MinimumPoolSize,
+ ),
+ config.DefaultPoolSize,
+ )
+
+ if _, ok := app.pools[configGroupName]; !ok {
+ app.pools[configGroupName] = make(map[string]*pool.Pool)
+ }
+ app.pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize)
+
+ span.AddEvent("Create pool", trace.WithAttributes(
+ attribute.String("name", configBlockName),
+ attribute.Int("size", currentPoolSize),
+ ))
+
+ if _, ok := app.clients[configGroupName]; !ok {
+ app.clients[configGroupName] = make(map[string]*config.Client)
+ }
+
+ // Get client config from the config file.
+ if clientConfig, ok := app.conf.Global.Clients[configGroupName][configBlockName]; !ok {
+ // This ensures that the default client config is used if the pool name is not
+ // found in the clients section.
+ app.clients[configGroupName][configBlockName] = app.conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] //nolint:lll
+ } else {
+ // Merge the default client config with the one from the pool.
+ app.clients[configGroupName][configBlockName] = clientConfig
+ }
+
+ // Fill the missing and zero values with the default ones.
+ app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If(
+ app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0,
+ app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod,
+ config.DefaultTCPKeepAlivePeriod,
+ )
+ app.clients[configGroupName][configBlockName].ReceiveDeadline = config.If(
+ app.clients[configGroupName][configBlockName].ReceiveDeadline > 0,
+ app.clients[configGroupName][configBlockName].ReceiveDeadline,
+ config.DefaultReceiveDeadline,
+ )
+ app.clients[configGroupName][configBlockName].ReceiveTimeout = config.If(
+ app.clients[configGroupName][configBlockName].ReceiveTimeout > 0,
+ app.clients[configGroupName][configBlockName].ReceiveTimeout,
+ config.DefaultReceiveTimeout,
+ )
+ app.clients[configGroupName][configBlockName].SendDeadline = config.If(
+ app.clients[configGroupName][configBlockName].SendDeadline > 0,
+ app.clients[configGroupName][configBlockName].SendDeadline,
+ config.DefaultSendDeadline,
+ )
+ app.clients[configGroupName][configBlockName].ReceiveChunkSize = config.If(
+ app.clients[configGroupName][configBlockName].ReceiveChunkSize > 0,
+ app.clients[configGroupName][configBlockName].ReceiveChunkSize,
+ config.DefaultChunkSize,
+ )
+ app.clients[configGroupName][configBlockName].DialTimeout = config.If(
+ app.clients[configGroupName][configBlockName].DialTimeout > 0,
+ app.clients[configGroupName][configBlockName].DialTimeout,
+ config.DefaultDialTimeout,
+ )
+
+ // Add clients to the pool.
+ for range currentPoolSize {
+ clientConfig := app.clients[configGroupName][configBlockName]
+ clientConfig.GroupName = configGroupName
+ clientConfig.BlockName = configBlockName
+ client := network.NewClient(
+ runCtx, clientConfig, logger,
+ network.NewRetry(
+ network.Retry{
+ Retries: clientConfig.Retries,
+ Backoff: config.If(
+ clientConfig.Backoff > 0,
+ clientConfig.Backoff,
+ config.DefaultBackoff,
+ ),
+ BackoffMultiplier: clientConfig.BackoffMultiplier,
+ DisableBackoffCaps: clientConfig.DisableBackoffCaps,
+						Logger:             logger,
+ },
+ ),
+ )
+
+ if client == nil {
+ return errors.New("failed to create client, please check the configuration")
+ }
+
+ eventOptions := trace.WithAttributes(
+ attribute.String("name", configBlockName),
+ attribute.String("group", configGroupName),
+ attribute.String("network", client.Network),
+ attribute.String("address", client.Address),
+ attribute.Int("receiveChunkSize", client.ReceiveChunkSize),
+ attribute.String("receiveDeadline", client.ReceiveDeadline.String()),
+ attribute.String("receiveTimeout", client.ReceiveTimeout.String()),
+ attribute.String("sendDeadline", client.SendDeadline.String()),
+ attribute.String("dialTimeout", client.DialTimeout.String()),
+ attribute.Bool("tcpKeepAlive", client.TCPKeepAlive),
+ attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()),
+ attribute.String("localAddress", client.LocalAddr()),
+ attribute.String("remoteAddress", client.RemoteAddr()),
+ attribute.Int("retries", clientConfig.Retries),
+ attribute.String("backoff", client.Retry().Backoff.String()),
+ attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier),
+ attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps),
+ )
+ if client.ID != "" {
+ eventOptions = trace.WithAttributes(
+ attribute.String("id", client.ID),
+ )
+ }
+
+ span.AddEvent("Create client", eventOptions)
+
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ clientCfg := map[string]any{
+ "id": client.ID,
+ "name": configBlockName,
+ "group": configGroupName,
+ "network": client.Network,
+ "address": client.Address,
+ "receiveChunkSize": client.ReceiveChunkSize,
+ "receiveDeadline": client.ReceiveDeadline.String(),
+ "receiveTimeout": client.ReceiveTimeout.String(),
+ "sendDeadline": client.SendDeadline.String(),
+ "dialTimeout": client.DialTimeout.String(),
+ "tcpKeepAlive": client.TCPKeepAlive,
+ "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(),
+ "localAddress": client.LocalAddr(),
+ "remoteAddress": client.RemoteAddr(),
+ "retries": clientConfig.Retries,
+ "backoff": client.Retry().Backoff.String(),
+ "backoffMultiplier": clientConfig.BackoffMultiplier,
+ "disableBackoffCaps": clientConfig.DisableBackoffCaps,
+ }
+ result, err := app.pluginRegistry.Run( //nolint:contextcheck
+ pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
+ }
+
+ err = app.pools[configGroupName][configBlockName].Put(client.ID, client)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to add client to the pool")
+ span.RecordError(err)
+ }
+ }
+
+ // Verify that the pool is properly populated.
+ logger.Info().Fields(map[string]any{
+ "name": configBlockName,
+ "count": strconv.Itoa(app.pools[configGroupName][configBlockName].Size()),
+ }).Msg("There are clients available in the pool")
+
+ if app.pools[configGroupName][configBlockName].Size() != currentPoolSize {
+ logger.Error().Msg(
+ "The pool size is incorrect, either because " +
+ "the clients cannot connect due to no network connectivity " +
+						"or the server is not running. Exiting...")
+ app.pluginRegistry.Shutdown()
+ return errors.New("failed to initialize pool, please check the configuration")
+ }
+
+ // Run the OnNewPool hook.
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ result, err := app.pluginRegistry.Run( //nolint:contextcheck
+ pluginTimeoutCtx,
+ map[string]any{"name": configBlockName, "size": currentPoolSize},
+ v1.HookName_HOOK_NAME_ON_NEW_POOL)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
+ }
+ }
+ }
+
+ return nil
+}
+
+// createProxies creates proxies.
+func (app *GatewayDApp) createProxies(runCtx context.Context, span trace.Span) {
+ // Create and initialize prefork proxies with each pool of clients.
+ for configGroupName, configGroup := range app.conf.Global.Proxies {
+ for configBlockName, cfg := range configGroup {
+ logger := app.loggers[configGroupName]
+ clientConfig := app.clients[configGroupName][configBlockName]
+
+ // Fill the missing and zero value with the default one.
+ cfg.HealthCheckPeriod = config.If(
+ cfg.HealthCheckPeriod > 0,
+ cfg.HealthCheckPeriod,
+ config.DefaultHealthCheckPeriod,
+ )
+
+ if _, ok := app.proxies[configGroupName]; !ok {
+ app.proxies[configGroupName] = make(map[string]*network.Proxy)
+ }
+
+ app.proxies[configGroupName][configBlockName] = network.NewProxy(
+ runCtx,
+ network.Proxy{
+ GroupName: configGroupName,
+ BlockName: configBlockName,
+ AvailableConnections: app.pools[configGroupName][configBlockName],
+ PluginRegistry: app.pluginRegistry,
+ HealthCheckPeriod: cfg.HealthCheckPeriod,
+ ClientConfig: clientConfig,
+ Logger: logger,
+ PluginTimeout: app.conf.Plugin.Timeout,
+ },
+ )
+
+ span.AddEvent("Create proxy", trace.WithAttributes(
+ attribute.String("name", configBlockName),
+ attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()),
+ ))
+
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ if data, ok := app.conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
+ result, err := app.pluginRegistry.Run( //nolint:contextcheck
+ pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
+ }
+ } else {
+ logger.Error().Msg("Failed to get proxy from config")
+ }
+ }
+ }
+}
+
+// createServers creates servers.
+func (app *GatewayDApp) createServers(
+ runCtx context.Context, span trace.Span, raftNode *raft.Node,
+) {
+ // Create and initialize servers.
+ for name, cfg := range app.conf.Global.Servers {
+ logger := app.loggers[name]
+
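+		// Collect the proxies that belong to this server's configuration group.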
+ var serverProxies []network.IProxy
+ for _, proxy := range app.proxies[name] {
+ serverProxies = append(serverProxies, proxy)
+ }
+
+ app.servers[name] = network.NewServer(
+ runCtx,
+ network.Server{
+ GroupName: name,
+ Network: cfg.Network,
+ Address: cfg.Address,
+ TickInterval: config.If(
+ cfg.TickInterval > 0,
+ cfg.TickInterval,
+ config.DefaultTickInterval,
+ ),
+ Options: network.Option{
+ // Can be used to send keepalive messages to the client.
+ EnableTicker: cfg.EnableTicker,
+ },
+ Proxies: serverProxies,
+ Logger: logger,
+ PluginRegistry: app.pluginRegistry,
+ PluginTimeout: app.conf.Plugin.Timeout,
+ EnableTLS: cfg.EnableTLS,
+ CertFile: cfg.CertFile,
+ KeyFile: cfg.KeyFile,
+ HandshakeTimeout: cfg.HandshakeTimeout,
+ LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
+ LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
+ LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash,
+ RaftNode: raftNode,
+ },
+ )
+
+ span.AddEvent("Create server", trace.WithAttributes(
+ attribute.String("name", name),
+ attribute.String("network", cfg.Network),
+ attribute.String("address", cfg.Address),
+ attribute.String("tickInterval", cfg.TickInterval.String()),
+ attribute.String("pluginTimeout", app.conf.Plugin.Timeout.String()),
+ attribute.Bool("enableTLS", cfg.EnableTLS),
+ attribute.String("certFile", cfg.CertFile),
+ attribute.String("keyFile", cfg.KeyFile),
+ attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()),
+ ))
+
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ if data, ok := app.conf.GlobalKoanf.Get("servers").(map[string]any); ok {
+ result, err := app.pluginRegistry.Run( //nolint:contextcheck
+ pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
+ }
+ } else {
+ logger.Error().Msg("Failed to get the servers configuration")
+ }
+ }
+}
+
+// startAPIServers starts the API servers.
+func (app *GatewayDApp) startAPIServers(
+ runCtx context.Context, logger zerolog.Logger, raftNode *raft.Node,
+) {
+ // Start the HTTP and gRPC APIs.
+ if !app.conf.Global.API.Enabled {
+ logger.Info().Msg("API is not enabled, skipping")
+ return
+ }
+
+ apiOptions := api.Options{
+ Logger: logger,
+ GRPCNetwork: app.conf.Global.API.GRPCNetwork,
+ GRPCAddress: app.conf.Global.API.GRPCAddress,
+ HTTPAddress: app.conf.Global.API.HTTPAddress,
+ Servers: app.servers,
+ RaftNode: raftNode,
+ }
+
+ apiObj := &api.API{
+ Options: &apiOptions,
+ Config: app.conf,
+ PluginRegistry: app.pluginRegistry,
+ Pools: app.pools,
+ Proxies: app.proxies,
+ Servers: app.servers,
+ }
+ app.grpcServer = api.NewGRPCServer(
+ runCtx,
+ api.GRPCServer{
+ API: apiObj,
+ HealthChecker: &api.HealthChecker{Servers: app.servers},
+ },
+ )
+	if app.grpcServer != nil {
+		go app.grpcServer.Start()
+		logger.Info().Fields(
+			map[string]any{
+				"network": apiOptions.GRPCNetwork,
+				"address": apiOptions.GRPCAddress,
+			},
+		).Msg("Started the gRPC API")
+
+		app.httpServer = api.NewHTTPServer(&apiOptions) //nolint:contextcheck
+		go app.httpServer.Start()
+		logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API")
+	}
+}
+
+// reportUsage reports usage statistics.
+func (app *GatewayDApp) reportUsage(logger zerolog.Logger) {
+ if !app.EnableUsageReport {
+ logger.Info().Msg("Usage reporting is not enabled, skipping")
+ return
+ }
+
+ // Report usage statistics.
+ go func() {
+ conn, err := grpc.NewClient(
+ UsageReportURL,
+ grpc.WithTransportCredentials(
+ credentials.NewTLS(
+ &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ },
+ ),
+ ),
+ )
+		if err != nil {
+			logger.Trace().Err(err).Msg(
+				"Failed to dial the gRPC server for usage reporting")
+			return
+		}
+ defer func(conn *grpc.ClientConn) {
+ err := conn.Close()
+ if err != nil {
+ logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service")
+ }
+ }(conn)
+
+ client := usage.NewUsageReportServiceClient(conn)
+ report := usage.UsageReportRequest{
+ Version: config.Version,
+ RuntimeVersion: runtime.Version(),
+ Goos: runtime.GOOS,
+ Goarch: runtime.GOARCH,
+ Service: "gatewayd",
+ DevMode: app.DevMode,
+ Plugins: []*usage.Plugin{},
+ }
+ app.pluginRegistry.ForEach(
+ func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) {
+ report.Plugins = append(report.GetPlugins(), &usage.Plugin{
+ Name: identifier.Name,
+ Version: identifier.Version,
+ Checksum: identifier.Checksum,
+ })
+ },
+ )
+ _, err = client.Report(context.Background(), &report)
+ if err != nil {
+ logger.Trace().Err(err).Msg("Failed to report usage statistics")
+ }
+ }()
+}
+
+// startServers starts the servers.
+func (app *GatewayDApp) startServers(
+ runCtx context.Context, span trace.Span,
+) {
+ // Start the server.
+ for name, server := range app.servers {
+ logger := app.loggers[name]
+ go func(
+ span trace.Span,
+ server *network.Server,
+ logger zerolog.Logger,
+ ) {
+ span.AddEvent("Start server")
+ if err := server.Run(); err != nil { //nolint:contextcheck
+ logger.Error().Err(err).Msg("Failed to start server")
+ span.RecordError(err)
+ app.stopGracefully(runCtx, nil)
+ os.Exit(gerr.FailedToStartServer)
+ }
+ }(span, server, logger)
+ }
+}
+
+// stopGracefully gracefully stops the app and all of its running components.
+func (app *GatewayDApp) stopGracefully(runCtx context.Context, sig os.Signal) {
+ // Only allow one call to this function.
+ if !app.ranStopGracefully.CompareAndSwap(false, true) {
+ return
+ }
+
+ _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server")
+ currentSignal := "unknown"
+ if sig != nil {
+ currentSignal = sig.String()
+ }
+
+ logger := app.loggers[config.Default]
+
+ logger.Info().Msg("Notifying the plugins that the server is shutting down")
+ if app.pluginRegistry != nil {
+ pluginTimeoutCtx, cancel := context.WithTimeout(
+ context.Background(), app.conf.Plugin.Timeout)
+ defer cancel()
+
+ //nolint:contextcheck
+ result, err := app.pluginRegistry.Run(
+ pluginTimeoutCtx,
+ map[string]any{"signal": currentSignal},
+ v1.HookName_HOOK_NAME_ON_SIGNAL,
+ )
+ if err != nil {
+ logger.Error().Err(err).Msg("Failed to run OnSignal hooks")
+ span.RecordError(err)
+ }
+ if result != nil {
+ _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
+ }
+ }
+
+ logger.Info().Msg("GatewayD is shutting down")
+ span.AddEvent("GatewayD is shutting down", trace.WithAttributes(
+ attribute.String("signal", currentSignal),
+ ))
+ if app.healthCheckScheduler != nil {
+ app.healthCheckScheduler.Stop()
+ app.healthCheckScheduler.Clear()
+ logger.Info().Msg("Stopped health check scheduler")
+ span.AddEvent("Stopped health check scheduler")
+ }
+ if app.metricsMerger != nil {
+ app.metricsMerger.Stop()
+ logger.Info().Msg("Stopped metrics merger")
+ span.AddEvent("Stopped metrics merger")
+ }
+ if app.metricsServer != nil {
+ //nolint:contextcheck
+ if err := app.metricsServer.Shutdown(context.Background()); err != nil {
+ logger.Error().Err(err).Msg("Failed to stop metrics server")
+ span.RecordError(err)
+ } else {
+ logger.Info().Msg("Stopped metrics server")
+ span.AddEvent("Stopped metrics server")
+ }
+ }
+
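+	// Shut down all network servers, remembering whether any had actually started.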
+ properlyInitialized := false
+ for name, server := range app.servers {
+ if server.IsRunning() {
+ logger.Info().Str("name", name).Msg("Stopping server")
+ server.Shutdown()
+ span.AddEvent("Stopped server")
+ properlyInitialized = true
+ }
+ }
+ logger.Info().Msg("Stopped all servers")
+ span.AddEvent("Stopped all servers")
+
+ if app.pluginRegistry != nil {
+ app.pluginRegistry.Shutdown()
+ logger.Info().Msg("Stopped plugin registry")
+ span.AddEvent("Stopped plugin registry")
+ }
+
+ if app.httpServer != nil {
+ app.httpServer.Shutdown(runCtx)
+ logger.Info().Msg("Stopped HTTP Server")
+ span.AddEvent("Stopped HTTP Server")
+ }
+
+ if app.grpcServer != nil {
+ app.grpcServer.Shutdown(nil) //nolint:staticcheck,contextcheck
+ logger.Info().Msg("Stopped gRPC Server")
+ span.AddEvent("Stopped gRPC Server")
+ }
+
+ logger.Info().Msg("GatewayD is shutdown")
+ span.AddEvent("GatewayD is shutdown")
+ span.End()
+
+	// If the servers were properly initialized, the run command is listening on
+	// the stop channel, so we must signal and close it to stop the app. Otherwise,
+	// the app is shutting down abruptly and nothing is listening on the channel.
+ if properlyInitialized {
+ app.stopChan <- struct{}{}
+ close(app.stopChan)
+ }
+}
+
+// handleSignals handles the signals and stops the server gracefully.
+func (app *GatewayDApp) handleSignals(runCtx context.Context, signals []os.Signal) {
+ signalsCh := make(chan os.Signal, 1)
+ signal.Notify(signalsCh, signals...)
+
+ go func() {
+ for sig := range signalsCh {
+ app.stopGracefully(runCtx, sig)
+ os.Exit(0)
+ }
+ }()
+}
diff --git a/cmd/plugin_init.go b/cmd/plugin_init.go
index 54b829c5..ef841045 100644
--- a/cmd/plugin_init.go
+++ b/cmd/plugin_init.go
@@ -11,6 +11,10 @@ var pluginInitCmd = &cobra.Command{
Use: "init",
Short: "Create or overwrite the GatewayD plugins config",
Run: func(cmd *cobra.Command, _ []string) {
+ force, _ := cmd.Flags().GetBool("force")
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+ pluginConfigFile, _ := cmd.Flags().GetString("plugin-config")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -37,12 +41,10 @@ var pluginInitCmd = &cobra.Command{
func init() {
pluginCmd.AddCommand(pluginInitCmd)
- pluginInitCmd.Flags().BoolVarP(
- &force, "force", "f", false, "Force overwrite of existing config file")
- pluginInitCmd.Flags().StringVarP(
- &pluginConfigFile, // Already exists in run.go
+ pluginInitCmd.Flags().BoolP(
+ "force", "f", false, "Force overwrite of existing config file")
+ pluginInitCmd.Flags().StringP(
"plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename),
"Plugin config file")
- pluginInitCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ pluginInitCmd.Flags().Bool("sentry", true, "Enable Sentry")
}
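+
+// The flags above are now read inside Run via the command's own flag set
+// instead of package-level variables. A minimal sketch of the idiom
+// (illustrative, using the flags registered above):
+//
+//	cmd.Flags().BoolP("force", "f", false, "Force overwrite of existing config file")
+//	// ...and later, inside Run:
+//	force, _ := cmd.Flags().GetBool("force")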
diff --git a/cmd/plugin_init_test.go b/cmd/plugin_init_test.go
index 2d9abb32..7b3217bd 100644
--- a/cmd/plugin_init_test.go
+++ b/cmd/plugin_init_test.go
@@ -18,7 +18,8 @@ func Test_pluginInitCmd(t *testing.T) {
fmt.Sprintf("Config file '%s' was created successfully.", pluginTestConfigFile),
output,
"plugin init command should have returned the correct output")
- assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file")
+ assert.FileExists(
+ t, pluginTestConfigFile, "plugin init command should have created a config file")
// Clean up.
err = os.Remove(pluginTestConfigFile)
diff --git a/cmd/plugin_install.go b/cmd/plugin_install.go
index e7015558..e6770601 100644
--- a/cmd/plugin_install.go
+++ b/cmd/plugin_install.go
@@ -58,24 +58,24 @@ const (
Plugins configFileType = "plugins"
)
-var (
- pluginOutputDir string
- pullOnly bool
- cleanup bool
- update bool
- backupConfig bool
- noPrompt bool
- pluginName string
- overwriteConfig bool
- skipPathSlipVerification bool
-)
-
// pluginInstallCmd represents the plugin install command.
var pluginInstallCmd = &cobra.Command{
Use: "install",
Short: "Install a plugin from a local archive or a GitHub repository",
Example: " gatewayd plugin install ", //nolint:lll
Run: func(cmd *cobra.Command, args []string) {
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+ pluginConfigFile, _ := cmd.Flags().GetString("plugin-config")
+ pluginOutputDir, _ := cmd.Flags().GetString("output-dir")
+ pullOnly, _ := cmd.Flags().GetBool("pull-only")
+ cleanup, _ := cmd.Flags().GetBool("cleanup")
+ noPrompt, _ := cmd.Flags().GetBool("no-prompt")
+ update, _ := cmd.Flags().GetBool("update")
+ backupConfig, _ := cmd.Flags().GetBool("backup")
+ overwriteConfig, _ := cmd.Flags().GetBool("overwrite-config")
+ pluginName, _ := cmd.Flags().GetString("name")
+ skipPathSlipVerification, _ := cmd.Flags().GetBool("skip-path-slip-verification")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -99,7 +99,20 @@ var pluginInstallCmd = &cobra.Command{
case LocationArgs:
// Install the plugin from the CLI argument.
cmd.Println("Installing plugin from CLI argument")
- installPlugin(cmd, args[0])
+ installPlugin(
+ cmd,
+ args[0],
+ pluginOutputDir,
+ pullOnly,
+ cleanup,
+ noPrompt,
+ update,
+ backupConfig,
+ overwriteConfig,
+ skipPathSlipVerification,
+ pluginConfigFile,
+ pluginName,
+ )
case LocationConfig:
// Read the gatewayd_plugins.yaml file.
pluginsConfig, err := os.ReadFile(pluginConfigFile)
@@ -197,7 +210,20 @@ var pluginInstallCmd = &cobra.Command{
// Install all the plugins from the plugins configuration file.
cmd.Println("Installing plugins from plugins configuration file")
for _, pluginURL := range pluginURLs {
- installPlugin(cmd, pluginURL)
+ installPlugin(
+ cmd,
+ pluginURL,
+ pluginOutputDir,
+ pullOnly,
+ cleanup,
+ noPrompt,
+ update,
+ backupConfig,
+ overwriteConfig,
+ skipPathSlipVerification,
+ pluginConfigFile,
+ pluginName,
+ )
}
default:
cmd.Println("Invalid plugin URL or file path")
@@ -208,35 +234,36 @@ var pluginInstallCmd = &cobra.Command{
func init() {
pluginCmd.AddCommand(pluginInstallCmd)
- pluginInstallCmd.Flags().StringVarP(
- &pluginConfigFile, // Already exists in run.go
+ pluginInstallCmd.Flags().StringP(
"plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename),
"Plugin config file")
- pluginInstallCmd.Flags().StringVarP(
- &pluginOutputDir, "output-dir", "o", "./plugins", "Output directory for the plugin")
- pluginInstallCmd.Flags().BoolVar(
- &pullOnly, "pull-only", false, "Only pull the plugin, don't install it")
- pluginInstallCmd.Flags().BoolVar(
- &cleanup, "cleanup", true,
+ pluginInstallCmd.Flags().StringP(
+ "output-dir", "o", "./plugins", "Output directory for the plugin")
+ pluginInstallCmd.Flags().Bool(
+ "pull-only", false, "Only pull the plugin, don't install it")
+ pluginInstallCmd.Flags().Bool(
+ "cleanup", true,
"Delete downloaded and extracted files after installing the plugin (except the plugin binary)")
- pluginInstallCmd.Flags().BoolVar(
- &noPrompt, "no-prompt", true, "Do not prompt for user input")
- pluginInstallCmd.Flags().BoolVar(
- &update, "update", false, "Update the plugin if it already exists")
- pluginInstallCmd.Flags().BoolVar(
- &backupConfig, "backup", false, "Backup the plugins configuration file before installing the plugin")
- pluginInstallCmd.Flags().StringVarP(
- &pluginName, "name", "n", "", "Name of the plugin (only for installing from archive files)")
- pluginInstallCmd.Flags().BoolVar(
- &overwriteConfig, "overwrite-config", true, "Overwrite the existing plugins configuration file (overrides --update, only used for installing from the plugins configuration file)") //nolint:lll
- pluginInstallCmd.Flags().BoolVar(
- &skipPathSlipVerification, "skip-path-slip-verification", false, "Skip path slip verification when extracting the plugin archive from a TRUSTED source") //nolint:lll
- pluginInstallCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ pluginInstallCmd.Flags().Bool(
+ "no-prompt", true, "Do not prompt for user input")
+ pluginInstallCmd.Flags().Bool(
+ "update", false, "Update the plugin if it already exists")
+ pluginInstallCmd.Flags().Bool(
+ "backup", false, "Backup the plugins configuration file before installing the plugin")
+ pluginInstallCmd.Flags().StringP(
+ "name", "n", "", "Name of the plugin (only for installing from archive files)")
+ pluginInstallCmd.Flags().Bool(
+ "overwrite-config", true, "Overwrite the existing plugins configuration file (overrides --update, only used for installing from the plugins configuration file)") //nolint:lll
+ pluginInstallCmd.Flags().Bool(
+ "skip-path-slip-verification", false, "Skip path slip verification when extracting the plugin archive from a TRUSTED source") //nolint:lll
+ pluginInstallCmd.Flags().Bool(
+ "sentry", true, "Enable Sentry")
}
// extractZip extracts the files from a zip archive.
-func extractZip(filename, dest string) ([]string, *gerr.GatewayDError) {
+func extractZip(
+ filename, dest string, skipPathSlipVerification bool,
+) ([]string, *gerr.GatewayDError) {
// Open and extract the zip file.
zipRc, err := zip.OpenReader(filename)
if err != nil {
@@ -317,7 +344,9 @@ func extractZip(filename, dest string) ([]string, *gerr.GatewayDError) {
}
// extractTarGz extracts the files from a tar.gz archive.
-func extractTarGz(filename, dest string) ([]string, *gerr.GatewayDError) {
+func extractTarGz(
+ filename, dest string, skipPathSlipVerification bool,
+) ([]string, *gerr.GatewayDError) {
// Open and extract the tar.gz file.
gzipStream, err := os.Open(filename)
if err != nil {
@@ -530,7 +559,20 @@ func getFileExtension() Extension {
}
// installPlugin installs a plugin from a given URL.
-func installPlugin(cmd *cobra.Command, pluginURL string) {
+func installPlugin(
+ cmd *cobra.Command,
+ pluginURL string,
+ pluginOutputDir string,
+ pullOnly bool,
+ cleanup bool,
+ noPrompt bool,
+ update bool,
+ backupConfig bool,
+ overwriteConfig bool,
+ skipPathSlipVerification bool,
+ pluginConfigFile string,
+ pluginName string,
+) {
var (
// This is a list of files that will be deleted after the plugin is installed.
toBeDeleted []string
@@ -725,8 +767,10 @@ func installPlugin(cmd *cobra.Command, pluginURL string) {
return
}
case SourceUnknown:
+ fallthrough
default:
cmd.Println("Invalid URL or file path")
+ return
}
// NOTE: The rest of the code is executed regardless of the source,
@@ -806,9 +850,9 @@ func installPlugin(cmd *cobra.Command, pluginURL string) {
var gErr *gerr.GatewayDError
switch archiveExt {
case ExtensionZip:
- filenames, gErr = extractZip(pluginFilename, pluginOutputDir)
+ filenames, gErr = extractZip(pluginFilename, pluginOutputDir, skipPathSlipVerification)
case ExtensionTarGz:
- filenames, gErr = extractTarGz(pluginFilename, pluginOutputDir)
+ filenames, gErr = extractTarGz(pluginFilename, pluginOutputDir, skipPathSlipVerification)
default:
cmd.Println("Invalid archive extension")
return
diff --git a/cmd/plugin_install_test.go b/cmd/plugin_install_test.go
index e3364c49..cb6e7871 100644
--- a/cmd/plugin_install_test.go
+++ b/cmd/plugin_install_test.go
@@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"os"
+ "path/filepath"
"runtime"
"testing"
@@ -10,19 +11,22 @@ import (
"github.com/stretchr/testify/require"
)
-func Test_pluginInstallCmd(t *testing.T) {
- pluginTestConfigFile := "./test_plugins_pluginInstallCmd.yaml"
+func Test_pluginInstallCmdWithFile(t *testing.T) {
+ pluginTestConfigFile := "./test_plugins_pluginInstallCmdWithFile.yaml"
// Create a test plugin config file.
- output, err := executeCommandC(rootCmd, "plugin", "init", "-p", pluginTestConfigFile)
+ output, err := executeCommandC(
+ rootCmd, "plugin", "init", "-p", pluginTestConfigFile, "--force")
require.NoError(t, err, "plugin init should not return an error")
assert.Equal(t,
fmt.Sprintf("Config file '%s' was created successfully.", pluginTestConfigFile),
output,
"plugin init command should have returned the correct output")
- assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file")
+ assert.FileExists(
+ t, pluginTestConfigFile, "plugin init command should have created a config file")
// Pull the plugin archive and install it.
+ var pluginArchivePath string
pluginArchivePath, err = mustPullPlugin()
require.NoError(t, err, "mustPullPlugin should not return an error")
assert.FileExists(t, pluginArchivePath, "mustPullPlugin should have downloaded the plugin archive")
@@ -30,9 +34,12 @@ func Test_pluginInstallCmd(t *testing.T) {
// Test plugin install command.
output, err = executeCommandC(
rootCmd, "plugin", "install", "-p", pluginTestConfigFile,
- "--update", "--backup", "--name", "gatewayd-plugin-cache", pluginArchivePath)
+ "--backup", "--name=gatewayd-plugin-cache", pluginArchivePath)
require.NoError(t, err, "plugin install should not return an error")
- assert.Equal(t, output, "Installing plugin from CLI argument\nBackup completed successfully\nPlugin binary extracted to plugins/gatewayd-plugin-cache\nPlugin installed successfully\n") //nolint:lll
+ assert.Contains(t, output, "Installing plugin from CLI argument")
+ assert.Contains(t, output, "Backup completed successfully")
+ assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache")
+ assert.Contains(t, output, "Plugin installed successfully")
// See if the plugin was actually installed.
output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile)
@@ -40,7 +47,7 @@ func Test_pluginInstallCmd(t *testing.T) {
assert.Contains(t, output, "Name: gatewayd-plugin-cache")
// Clean up.
- assert.FileExists(t, "plugins/gatewayd-plugin-cache")
+ assert.FileExists(t, "./plugins/gatewayd-plugin-cache")
assert.FileExists(t, pluginTestConfigFile+BackupFileExt)
assert.NoFileExists(t, "gatewayd-plugin-cache-linux-amd64-v0.2.4.tar.gz")
assert.NoFileExists(t, "checksums.txt")
@@ -57,28 +64,26 @@ func Test_pluginInstallCmd(t *testing.T) {
func Test_pluginInstallCmdAutomatedNoOverwrite(t *testing.T) {
pluginTestConfigFile := "./testdata/gatewayd_plugins.yaml"
- // Reset the global variable.
- pullOnly = false
-
- // Test plugin install command.
+ // Test plugin install command with overwrite disabled.
output, err := executeCommandC(
- rootCmd, "plugin", "install",
- "-p", pluginTestConfigFile, "--update", "--backup", "--overwrite-config=false")
+ rootCmd, "plugin", "install", "-p", pluginTestConfigFile,
+ "--update", "--backup", "--overwrite-config=false", "--pull-only=false")
require.NoError(t, err, "plugin install should not return an error")
- assert.Contains(t, output, fmt.Sprintf("/gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH))
- assert.Contains(t, output, "/checksums.txt")
+
+ // Verify the expected output for the no-overwrite case.
+ assert.Contains(t, output, "Installing plugins from plugins configuration file")
+ assert.Contains(
+ t,
+ output,
+ fmt.Sprintf("gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH))
+ assert.Contains(t, output, "checksums.txt")
assert.Contains(t, output, "Download completed successfully")
assert.Contains(t, output, "Checksum verification passed")
assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache")
assert.Contains(t, output, "Plugin installed successfully")
- // See if the plugin was actually installed.
- output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile)
- require.NoError(t, err, "plugin list should not return an error")
- assert.Contains(t, output, "Name: gatewayd-plugin-cache")
-
// Clean up.
- assert.FileExists(t, "plugins/gatewayd-plugin-cache")
+ assert.FileExists(t, "./plugins/gatewayd-plugin-cache")
assert.FileExists(t, pluginTestConfigFile+BackupFileExt)
assert.NoFileExists(t, "plugins/LICENSE")
assert.NoFileExists(t, "plugins/README.md")
@@ -88,3 +93,38 @@ func Test_pluginInstallCmdAutomatedNoOverwrite(t *testing.T) {
require.NoError(t, os.RemoveAll("plugins/"))
require.NoError(t, os.Remove(pluginTestConfigFile+BackupFileExt))
}
+
+func Test_pluginInstallCmdPullOnly(t *testing.T) {
+ pwd, err := os.Getwd()
+ require.NoError(t, err, "os.Getwd should not return an error")
+
+ pluginTestConfigFile := "./testdata/gatewayd_plugins.yaml"
+
+ // Test plugin install command in pull-only mode.
+ output, err := executeCommandC(
+ rootCmd, "plugin", "install", "-p", pluginTestConfigFile,
+ "--update", "--backup", "--pull-only", "--overwrite-config=false")
+ require.NoError(t, err, "plugin install should not return an error")
+
+ // Verify pull-only behavior
+ assert.Contains(t, output, "Installing plugins from plugins configuration file")
+ assert.Contains(t, output, fmt.Sprintf("gatewayd-plugin-cache-%s-%s-", runtime.GOOS, runtime.GOARCH))
+ assert.Contains(t, output, "Download completed successfully")
+
+ // The output should not contain installation messages in pull-only mode.
+ assert.NotContains(t, output, "Plugin binary extracted")
+ assert.NotContains(t, output, "Plugin installed successfully")
+
+ // Clean up the downloaded files.
+ pattern := filepath.Join(pwd, fmt.Sprintf("gatewayd-plugin-cache-%s-%s-*.tar.gz", runtime.GOOS, runtime.GOARCH))
+ matches, err := filepath.Glob(pattern)
+ require.NoError(t, err, "failed to glob plugin archives")
+ for _, match := range matches {
+ require.NoError(t, os.Remove(match))
+ }
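+
+ // Note: the glob above is needed because the downloaded archive name embeds
+ // a version suffix that is not known ahead of time.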
+
+ checksumFile := filepath.Join(pwd, "checksums.txt")
+ if _, err := os.Stat(checksumFile); err == nil {
+ require.NoError(t, os.Remove(checksumFile))
+ }
+}
diff --git a/cmd/plugin_lint.go b/cmd/plugin_lint.go
index 13508b3a..6e317068 100644
--- a/cmd/plugin_lint.go
+++ b/cmd/plugin_lint.go
@@ -14,6 +14,9 @@ var pluginLintCmd = &cobra.Command{
Use: "lint",
Short: "Lint the GatewayD plugins config",
Run: func(cmd *cobra.Command, _ []string) {
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+ pluginConfigFile, _ := cmd.Flags().GetString("plugin-config")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -44,10 +47,9 @@ var pluginLintCmd = &cobra.Command{
func init() {
pluginCmd.AddCommand(pluginLintCmd)
- pluginLintCmd.Flags().StringVarP(
- &pluginConfigFile, // Already exists in run.go
+ pluginLintCmd.Flags().StringP(
"plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename),
"Plugin config file")
- pluginLintCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ pluginLintCmd.Flags().BoolP(
+ "sentry", "s", true, "Enable Sentry")
}
diff --git a/cmd/plugin_list.go b/cmd/plugin_list.go
index 000d9616..e64add2d 100644
--- a/cmd/plugin_list.go
+++ b/cmd/plugin_list.go
@@ -9,13 +9,15 @@ import (
"github.com/spf13/cobra"
)
-var onlyEnabled bool
-
// pluginListCmd represents the plugin list command.
var pluginListCmd = &cobra.Command{
Use: "list",
Short: "List the GatewayD plugins",
Run: func(cmd *cobra.Command, _ []string) {
+ pluginConfigFile, _ := cmd.Flags().GetString("plugin-config")
+ onlyEnabled, _ := cmd.Flags().GetBool("only-enabled")
+ enableSentry, _ := cmd.Flags().GetBool("sentry")
+
// Enable Sentry.
if enableSentry {
// Initialize Sentry.
@@ -42,30 +44,27 @@ var pluginListCmd = &cobra.Command{
func init() {
pluginCmd.AddCommand(pluginListCmd)
- pluginListCmd.Flags().StringVarP(
- &pluginConfigFile, // Already exists in run.go
+ pluginListCmd.Flags().StringP(
"plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename),
"Plugin config file")
- pluginListCmd.Flags().BoolVarP(
- &onlyEnabled,
+ pluginListCmd.Flags().BoolP(
"only-enabled", "e",
false, "Only list enabled plugins")
- pluginListCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry") // Already exists in run.go
+ pluginListCmd.Flags().BoolP("sentry", "s", true, "Enable Sentry")
}
func listPlugins(cmd *cobra.Command, pluginConfigFile string, onlyEnabled bool) {
// Load the plugin config file.
- conf := config.NewConfig(context.TODO(), config.Config{PluginConfigFile: pluginConfigFile})
- if err := conf.LoadDefaults(context.TODO()); err != nil {
+ conf := config.NewConfig(context.Background(), config.Config{PluginConfigFile: pluginConfigFile})
+ if err := conf.LoadDefaults(context.Background()); err != nil {
cmd.PrintErr(err)
return
}
- if err := conf.LoadPluginConfigFile(context.TODO()); err != nil {
+ if err := conf.LoadPluginConfigFile(context.Background()); err != nil {
cmd.PrintErr(err)
return
}
- if err := conf.UnmarshalPluginConfig(context.TODO()); err != nil {
+ if err := conf.UnmarshalPluginConfig(context.Background()); err != nil {
cmd.PrintErr(err)
return
}
diff --git a/cmd/plugin_list_test.go b/cmd/plugin_list_test.go
index c96342ba..2492858d 100644
--- a/cmd/plugin_list_test.go
+++ b/cmd/plugin_list_test.go
@@ -11,6 +11,7 @@ import (
func Test_pluginListCmd(t *testing.T) {
pluginTestConfigFile := "./test_plugins_pluginListCmd.yaml"
+
// Test plugin list command.
output, err := executeCommandC(rootCmd, "plugin", "init", "-p", pluginTestConfigFile)
require.NoError(t, err, "plugin init command should not have returned an error")
@@ -36,6 +37,7 @@ func Test_pluginListCmdWithPlugins(t *testing.T) {
// Test plugin list command.
// Read the plugin config file from the root directory.
pluginTestConfigFile := "../gatewayd_plugins.yaml"
+
output, err := executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile)
require.NoError(t, err, "plugin list command should not have returned an error")
assert.Contains(t,
diff --git a/cmd/plugin_scaffold.go b/cmd/plugin_scaffold.go
index 4a4da902..4ddad37a 100644
--- a/cmd/plugin_scaffold.go
+++ b/cmd/plugin_scaffold.go
@@ -5,16 +5,14 @@ import (
"github.com/spf13/cobra"
)
-var (
- pluginScaffoldInputFile string
- pluginScaffoldOutputDir string
-)
-
// pluginScaffoldCmd represents the scaffold command.
var pluginScaffoldCmd = &cobra.Command{
Use: "scaffold",
Short: "Scaffold a plugin and store the files into a directory",
Run: func(cmd *cobra.Command, _ []string) {
+ pluginScaffoldInputFile := cmd.Flag("input-file").Value.String()
+ pluginScaffoldOutputDir := cmd.Flag("output-dir").Value.String()
+
createdFiles, err := plugin.Scaffold(pluginScaffoldInputFile, pluginScaffoldOutputDir)
if err != nil {
cmd.Println("Scaffold failed: ", err)
@@ -31,12 +29,10 @@ var pluginScaffoldCmd = &cobra.Command{
func init() {
pluginCmd.AddCommand(pluginScaffoldCmd)
- pluginScaffoldCmd.Flags().StringVarP(
- &pluginScaffoldInputFile,
+ pluginScaffoldCmd.Flags().StringP(
"input-file", "i", "input.yaml",
"Plugin scaffold input file")
- pluginScaffoldCmd.Flags().StringVarP(
- &pluginScaffoldOutputDir,
+ pluginScaffoldCmd.Flags().StringP(
"output-dir", "o", "./plugins",
"Output directory for the scaffold")
}
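+
+// Note: for string flags, cmd.Flag("input-file").Value.String() and
+// cmd.Flags().GetString("input-file") yield the same value; the latter also
+// returns an error for an unknown or mistyped flag name.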
diff --git a/cmd/plugin_scaffold_test.go b/cmd/plugin_scaffold_test.go
index 7eb6b246..a31b1ef3 100644
--- a/cmd/plugin_scaffold_test.go
+++ b/cmd/plugin_scaffold_test.go
@@ -9,7 +9,6 @@ import (
"time"
"github.com/codingsince1985/checksum"
- "github.com/gatewayd-io/gatewayd/config"
"github.com/gatewayd-io/gatewayd/plugin"
"github.com/gatewayd-io/gatewayd/testhelpers"
"github.com/spf13/cast"
@@ -19,6 +18,12 @@ import (
)
func Test_pluginScaffoldCmd(t *testing.T) {
+ previous := EnableTestMode()
+ defer func() {
+ testMode = previous
+ testApp = nil
+ }()
+
// Start the test containers.
ctx := context.Background()
postgresHostIP1, postgresMappedPort1 := testhelpers.SetupPostgreSQLTestContainer(ctx, t)
@@ -87,11 +92,9 @@ func Test_pluginScaffoldCmd(t *testing.T) {
pluginTestConfigFile := filepath.Join(pluginDir, "gatewayd_plugins.yaml")
- stopChan = make(chan struct{})
-
var waitGroup sync.WaitGroup
- waitGroup.Add(1)
+ waitGroup.Add(2)
go func(waitGroup *sync.WaitGroup) {
// Test run command.
output, err := executeCommandC(
@@ -107,23 +110,9 @@ func Test_pluginScaffoldCmd(t *testing.T) {
waitGroup.Done()
}(&waitGroup)
- waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
time.Sleep(waitBeforeStop)
-
- StopGracefully(
- context.Background(),
- nil,
- nil,
- metricsServer,
- nil,
- loggers[config.Default],
- servers,
- stopChan,
- nil,
- nil,
- )
-
+ testApp.stopGracefully(context.Background(), os.Interrupt)
waitGroup.Done()
}(&waitGroup)
diff --git a/cmd/run.go b/cmd/run.go
index 51e61241..a9798415 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -2,181 +2,31 @@ package cmd
import (
"context"
- "crypto/tls"
- "errors"
- "fmt"
- "io"
"log"
- "net/http"
- "net/url"
"os"
- "os/signal"
- "runtime"
- "strconv"
"syscall"
- "time"
- "github.com/NYTimes/gziphandler"
- sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
- sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
- v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
- "github.com/gatewayd-io/gatewayd/act"
- "github.com/gatewayd-io/gatewayd/api"
"github.com/gatewayd-io/gatewayd/config"
gerr "github.com/gatewayd-io/gatewayd/errors"
- "github.com/gatewayd-io/gatewayd/logging"
- "github.com/gatewayd-io/gatewayd/metrics"
- "github.com/gatewayd-io/gatewayd/network"
- "github.com/gatewayd-io/gatewayd/plugin"
- "github.com/gatewayd-io/gatewayd/pool"
"github.com/gatewayd-io/gatewayd/raft"
"github.com/gatewayd-io/gatewayd/tracing"
- usage "github.com/gatewayd-io/gatewayd/usagereport/v1"
"github.com/getsentry/sentry-go"
- "github.com/go-co-op/gocron"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promhttp"
- "github.com/redis/go-redis/v9"
- "github.com/rs/zerolog"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/attribute"
- "go.opentelemetry.io/otel/trace"
"golang.org/x/exp/maps"
- "google.golang.org/grpc"
- "google.golang.org/grpc/credentials"
)
-var _ io.Writer = &cobraCmdWriter{}
-
-type cobraCmdWriter struct {
- *cobra.Command
-}
-
-func (c *cobraCmdWriter) Write(p []byte) (int, error) {
- c.Print(string(p))
- return len(p), nil
-}
-
-// TODO: Get rid of the global variables.
-// https://github.com/gatewayd-io/gatewayd/issues/324
var (
- enableTracing bool
- enableLinting bool
- collectorURL string
- enableSentry bool
- devMode bool
- enableUsageReport bool
- pluginConfigFile string
- globalConfigFile string
- conf *config.Config
- pluginRegistry *plugin.Registry
- actRegistry *act.Registry
- metricsServer *http.Server
-
- UsageReportURL = "localhost:59091"
-
- loggers = make(map[string]zerolog.Logger)
- pools = make(map[string]map[string]*pool.Pool)
- clients = make(map[string]map[string]*config.Client)
- proxies = make(map[string]map[string]*network.Proxy)
- servers = make(map[string]*network.Server)
- healthCheckScheduler = gocron.NewScheduler(time.UTC)
-
- stopChan = make(chan struct{})
+ testMode bool
+ testApp *GatewayDApp
)
-func StopGracefully(
- runCtx context.Context,
- sig os.Signal,
- metricsMerger *metrics.Merger,
- metricsServer *http.Server,
- pluginRegistry *plugin.Registry,
- logger zerolog.Logger,
- servers map[string]*network.Server,
- stopChan chan struct{},
- httpServer *api.HTTPServer,
- grpcServer *api.GRPCServer,
-) {
- _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server")
- currentSignal := "unknown"
- if sig != nil {
- currentSignal = sig.String()
- }
-
- logger.Info().Msg("Notifying the plugins that the server is shutting down")
- if pluginRegistry != nil {
- pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- //nolint:contextcheck
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx,
- map[string]any{"signal": currentSignal},
- v1.HookName_HOOK_NAME_ON_SIGNAL,
- )
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnSignal hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
- }
- }
-
- logger.Info().Msg("GatewayD is shutting down")
- span.AddEvent("GatewayD is shutting down", trace.WithAttributes(
- attribute.String("signal", currentSignal),
- ))
- if healthCheckScheduler != nil {
- healthCheckScheduler.Stop()
- healthCheckScheduler.Clear()
- logger.Info().Msg("Stopped health check scheduler")
- span.AddEvent("Stopped health check scheduler")
- }
- if metricsMerger != nil {
- metricsMerger.Stop()
- logger.Info().Msg("Stopped metrics merger")
- span.AddEvent("Stopped metrics merger")
- }
- if metricsServer != nil {
- //nolint:contextcheck
- if err := metricsServer.Shutdown(context.Background()); err != nil {
- logger.Error().Err(err).Msg("Failed to stop metrics server")
- span.RecordError(err)
- } else {
- logger.Info().Msg("Stopped metrics server")
- span.AddEvent("Stopped metrics server")
- }
- }
- for name, server := range servers {
- logger.Info().Str("name", name).Msg("Stopping server")
- server.Shutdown()
- span.AddEvent("Stopped server")
- }
- logger.Info().Msg("Stopped all servers")
- if pluginRegistry != nil {
- pluginRegistry.Shutdown()
- logger.Info().Msg("Stopped plugin registry")
- span.AddEvent("Stopped plugin registry")
- }
- span.End()
-
- if httpServer != nil {
- httpServer.Shutdown(runCtx)
- logger.Info().Msg("Stopped HTTP Server")
- span.AddEvent("Stopped HTTP Server")
- }
-
- if grpcServer != nil {
- grpcServer.Shutdown(runCtx)
- logger.Info().Msg("Stopped gRPC Server")
- span.AddEvent("Stopped gRPC Server")
- }
-
- // Close the stop channel to notify the other goroutines to stop.
- stopChan <- struct{}{}
- close(stopChan)
+// EnableTestMode enables test mode and returns the previous value.
+// This should only be used in tests.
+func EnableTestMode() bool {
+ previous := testMode
+ testMode = true
+ return previous
}
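+
+// Typical usage from a test, mirroring Test_pluginScaffoldCmd (a sketch):
+//
+//	previous := EnableTestMode()
+//	defer func() { testMode = previous; testApp = nil }()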
// runCmd represents the run command.
@@ -184,23 +34,51 @@ var runCmd = &cobra.Command{
Use: "run",
Short: "Run a GatewayD instance",
Run: func(cmd *cobra.Command, _ []string) {
+ app := NewGatewayDApp(cmd)
+
+ // If test mode is enabled, we need to access the app instance from the test,
+ // so we can stop the server gracefully.
+ if testMode {
+ testApp = app
+ }
+
+ runCtx, span := otel.Tracer(config.TracerName).Start(context.Background(), "GatewayD")
+ span.End()
+
+ // Handle signals from the user.
+ app.handleSignals(
+ runCtx,
+ []os.Signal{
+ os.Interrupt,
+ os.Kill,
+ syscall.SIGTERM,
+ syscall.SIGABRT,
+ syscall.SIGQUIT,
+ syscall.SIGHUP,
+ syscall.SIGINT,
+ },
+ )
+
+ // Stop the server gracefully when the program terminates cleanly.
+ defer app.stopGracefully(runCtx, nil)
+
// Enable tracing with OpenTelemetry.
- if enableTracing {
+ if app.EnableTracing {
// TODO: Make this configurable.
- shutdown := tracing.OTLPTracer(true, collectorURL, config.TracerName)
+ shutdown := tracing.OTLPTracer(true, app.CollectorURL, config.TracerName)
defer func() {
if err := shutdown(context.Background()); err != nil {
cmd.Println(err)
+ app.stopGracefully(runCtx, nil)
os.Exit(gerr.FailedToStartTracer)
}
}()
}
- runCtx, span := otel.Tracer(config.TracerName).Start(context.Background(), "GatewayD")
span.End()
// Enable Sentry.
- if enableSentry {
+ if app.EnableSentry {
_, span := otel.Tracer(config.TracerName).Start(runCtx, "Sentry")
defer span.End()
@@ -223,968 +101,162 @@ var runCmd = &cobra.Command{
}
// Lint the configuration files before loading them.
- if enableLinting {
+ if app.EnableLinting {
_, span := otel.Tracer(config.TracerName).Start(runCtx, "Lint configuration files")
defer span.End()
// Lint the global configuration file and fail if it's not valid.
- if err := lintConfig(Global, globalConfigFile); err != nil {
+ if err := lintConfig(Global, app.GlobalConfigFile); err != nil {
log.Fatal(err)
}
// Lint the plugin configuration file and fail if it's not valid.
- if err := lintConfig(Plugins, pluginConfigFile); err != nil {
+ if err := lintConfig(Plugins, app.PluginConfigFile); err != nil {
log.Fatal(err)
}
}
- // Load global and plugin configuration.
- conf = config.NewConfig(runCtx, config.Config{GlobalConfigFile: globalConfigFile, PluginConfigFile: pluginConfigFile})
- if err := conf.InitConfig(runCtx); err != nil {
+ // Load the configuration files.
+ if err := app.loadConfig(runCtx); err != nil {
log.Fatal(err)
}
// Create and initialize loggers from the config.
- // Use cobra command cmd instead of os.Stdout for the console output.
- cmdLogger := &cobraCmdWriter{cmd}
- for name, cfg := range conf.Global.Loggers {
- loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{
- Output: cfg.GetOutput(),
- ConsoleOut: cmdLogger,
- Level: config.If(
- config.Exists(config.LogLevels, cfg.Level),
- config.LogLevels[cfg.Level],
- config.LogLevels[config.DefaultLogLevel],
- ),
- TimeFormat: config.If(
- config.Exists(config.TimeFormats, cfg.TimeFormat),
- config.TimeFormats[cfg.TimeFormat],
- config.TimeFormats[config.DefaultTimeFormat],
- ),
- ConsoleTimeFormat: config.If(
- config.Exists(
- config.ConsoleTimeFormats, cfg.ConsoleTimeFormat),
- config.ConsoleTimeFormats[cfg.ConsoleTimeFormat],
- config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat],
- ),
- NoColor: cfg.NoColor,
- FileName: cfg.FileName,
- MaxSize: cfg.MaxSize,
- MaxBackups: cfg.MaxBackups,
- MaxAge: cfg.MaxAge,
- Compress: cfg.Compress,
- LocalTime: cfg.LocalTime,
- SyslogPriority: cfg.GetSyslogPriority(),
- RSyslogNetwork: cfg.RSyslogNetwork,
- RSyslogAddress: cfg.RSyslogAddress,
- Name: name,
- })
- }
-
- // Set the default logger.
- logger := loggers[config.Default]
+ // And then set the default logger.
+ logger := app.createLoggers(runCtx, cmd)
- if devMode {
+ if app.DevMode {
logger.Warn().Msg(
"Running GatewayD in development mode (not recommended for production)")
}
- // Create a new act registry given the built-in signals, policies, and actions.
- var publisher *act.Publisher
- if conf.Plugin.ActionRedis.Enabled {
- rdb := redis.NewClient(&redis.Options{
- Addr: conf.Plugin.ActionRedis.Address,
- })
- var err error
- publisher, err = act.NewPublisher(act.Publisher{
- Logger: logger,
- RedisDB: rdb,
- ChannelName: conf.Plugin.ActionRedis.Channel,
- })
- if err != nil {
- logger.Error().Err(err).Msg("Failed to create publisher for act registry")
- os.Exit(gerr.FailedToCreateActRegistry)
- }
- logger.Info().Msg("Created Redis publisher for Act registry")
- }
-
- actRegistry = act.NewActRegistry(
- act.Registry{
- Signals: act.BuiltinSignals(),
- Policies: act.BuiltinPolicies(),
- Actions: act.BuiltinActions(),
- DefaultPolicyName: conf.Plugin.DefaultPolicy,
- PolicyTimeout: conf.Plugin.PolicyTimeout,
- DefaultActionTimeout: conf.Plugin.ActionTimeout,
- TaskPublisher: publisher,
- Logger: logger,
- })
-
- if actRegistry == nil {
- logger.Error().Msg("Failed to create act registry")
+ // Create the Act registry.
+ if err := app.createActRegistry(logger); err != nil {
+ logger.Error().Err(err).Msg("Failed to create act registry")
+ app.stopGracefully(runCtx, nil)
os.Exit(gerr.FailedToCreateActRegistry)
}
// Load policies from the configuration file and add them to the registry.
- for _, plc := range conf.Plugin.Policies {
- if policy, err := sdkAct.NewPolicy(
- plc.Name, plc.Policy, plc.Metadata,
- ); err != nil || policy == nil {
- logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy")
- } else {
- actRegistry.Add(policy)
- }
+ if err := app.loadPolicies(logger); err != nil {
+ logger.Error().Err(err).Msg("Failed to load policies")
+ app.stopGracefully(runCtx, nil)
+ os.Exit(gerr.FailedToLoadPolicies)
}
logger.Info().Fields(map[string]any{
- "policies": maps.Keys(actRegistry.Policies),
+ "policies": maps.Keys(app.actRegistry.Policies),
}).Msg("Policies are loaded")
- // Create a new plugin registry.
- // The plugins are loaded and hooks registered before the configuration is loaded.
- pluginRegistry = plugin.NewRegistry(
- runCtx,
- plugin.Registry{
- ActRegistry: actRegistry,
- Compatibility: config.If(
- config.Exists(
- config.CompatibilityPolicies, conf.Plugin.CompatibilityPolicy,
- ),
- config.CompatibilityPolicies[conf.Plugin.CompatibilityPolicy],
- config.DefaultCompatibilityPolicy),
- Logger: logger,
- DevMode: devMode,
- },
- )
+ // Create the plugin registry.
+ app.createPluginRegistry(runCtx, logger)
// Load plugins and register their hooks.
- pluginRegistry.LoadPlugins(runCtx, conf.Plugin.Plugins, conf.Plugin.StartTimeout)
+ app.pluginRegistry.LoadPlugins(
+ runCtx,
+ app.conf.Plugin.Plugins,
+ app.conf.Plugin.StartTimeout,
+ )
// Start the metrics merger if enabled.
- var metricsMerger *metrics.Merger
- if conf.Plugin.EnableMetricsMerger {
- metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{
- MetricsMergerPeriod: conf.Plugin.MetricsMergerPeriod,
- Logger: logger,
- })
- pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) {
- if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled {
- metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"])
- logger.Debug().Str("plugin", plugin.ID.Name).Msg(
- "Added plugin to metrics merger")
- }
- })
- metricsMerger.Start()
- }
+ app.startMetricsMerger(runCtx, logger)
// TODO: Move this to the plugin registry.
ctx, span := otel.Tracer(config.TracerName).Start(runCtx, "Plugin health check")
- // Ping the plugins to check if they are alive, and remove them if they are not.
- startDelay := time.Now().Add(conf.Plugin.HealthCheckPeriod)
- if _, err := healthCheckScheduler.Every(
- conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(func() {
- _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check")
- defer span.End()
-
- var plugins []string
- pluginRegistry.ForEach(func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) {
- if err := plugin.Ping(); err != nil {
- span.RecordError(err)
- logger.Error().Err(err).Msg("Failed to ping plugin")
- if conf.Plugin.EnableMetricsMerger && metricsMerger != nil {
- metricsMerger.Remove(pluginId.Name)
- }
- pluginRegistry.Remove(pluginId)
-
- if !conf.Plugin.ReloadOnCrash {
- return // Do not reload the plugins.
- }
-
- // Reload the plugins and register their hooks upon crash.
- logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin")
- pluginConfig := conf.Plugin.GetPlugins(pluginId.Name)
- if pluginConfig != nil {
- pluginRegistry.LoadPlugins(runCtx, pluginConfig, conf.Plugin.StartTimeout)
- }
- } else {
- logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin")
- plugins = append(plugins, pluginId.Name)
- }
- })
- span.SetAttributes(attribute.StringSlice("plugins", plugins))
- }); err != nil {
- logger.Error().Err(err).Msg("Failed to start plugin health check scheduler")
- span.RecordError(err)
- }
- if pluginRegistry.Size() > 0 {
- logger.Info().Str(
- "healthCheckPeriod", conf.Plugin.HealthCheckPeriod.String(),
- ).Msg("Starting plugin health check scheduler")
- healthCheckScheduler.StartAsync()
- }
+ // Start the health check scheduler only if there are plugins.
+ app.startHealthCheckScheduler(runCtx, ctx, span, logger)
span.End()
- // Set the plugin timeout context.
- pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin.
- // The plugins can modify the config and return it.
- updatedGlobalConfig, err := pluginRegistry.Run(
- pluginTimeoutCtx, conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks")
- span.RecordError(err)
- }
- if updatedGlobalConfig != nil {
- updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig)
- }
-
- // If the config was modified by the plugins, merge it with the one loaded from the file.
- // Only global configuration is merged, which means that plugins cannot modify the plugin
- // configurations.
- if updatedGlobalConfig != nil {
- // Merge the config with the one loaded from the file (in memory).
- // The changes won't be persisted to disk.
- if err := conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil {
- log.Fatal(err)
- }
+ // Merge the global config with the one from the plugins.
+ if err := app.onConfigLoaded(runCtx, span, logger); err != nil {
+ app.stopGracefully(runCtx, nil)
+ os.Exit(gerr.FailedToMergeGlobalConfig)
}
// Start the metrics server if enabled.
- // TODO: Start multiple metrics servers. For now, only one default is supported.
- // I should first find a use case for those multiple metrics servers.
- go func(metricsConfig *config.Metrics, logger zerolog.Logger) {
- _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server")
- defer span.End()
-
- // TODO: refactor this to a separate function.
- if !metricsConfig.Enabled {
- logger.Info().Msg("Metrics server is disabled")
- return
- }
-
- scheme := "http://"
- if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" {
- scheme = "https://"
- }
-
- fqdn, err := url.Parse(scheme + metricsConfig.Address)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to parse metrics address")
+ go func(app *GatewayDApp) {
+ if err := app.startMetricsServer(runCtx, logger); err != nil {
+ logger.Error().Err(err).Msg("Failed to start metrics server")
span.RecordError(err)
- return
}
+ }(app)
- address, err := url.JoinPath(fqdn.String(), metricsConfig.Path)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to parse metrics path")
- span.RecordError(err)
- return
- }
-
- // Merge the metrics from the plugins with the ones from GatewayD.
- mergedMetricsHandler := func(next http.Handler) http.Handler {
- handler := func(responseWriter http.ResponseWriter, request *http.Request) {
- if _, err := responseWriter.Write(metricsMerger.OutputMetrics); err != nil {
- logger.Error().Err(err).Msg("Failed to write metrics")
- span.RecordError(err)
- sentry.CaptureException(err)
- }
- // The WriteHeader method intentionally does nothing, to prevent a bug
- // in the merging metrics that causes the headers to be written twice,
- // which results in an error: "http: superfluous response.WriteHeader call".
- next.ServeHTTP(
- &metrics.HeaderBypassResponseWriter{
- ResponseWriter: responseWriter,
- },
- request)
- }
- return http.HandlerFunc(handler)
- }
-
- handler := func() http.Handler {
- return promhttp.InstrumentMetricHandler(
- prometheus.DefaultRegisterer,
- promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
- DisableCompression: true,
- }),
- )
- }()
-
- mux := http.NewServeMux()
- mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) {
- // Serve a static page with a link to the metrics endpoint.
- if _, err := responseWriter.Write([]byte(fmt.Sprintf(
- `<html><head><title>GatewayD Prometheus Metrics Server</title></head><body><a href="%s">Metrics</a></body></html>`,
- address,
- ))); err != nil {
- logger.Error().Err(err).Msg("Failed to write metrics")
- span.RecordError(err)
- sentry.CaptureException(err)
- }
- })
-
- if conf.Plugin.EnableMetricsMerger && metricsMerger != nil {
- handler = mergedMetricsHandler(handler)
- }
-
- readHeaderTimeout := config.If(
- metricsConfig.ReadHeaderTimeout > 0,
- metricsConfig.ReadHeaderTimeout,
- config.DefaultReadHeaderTimeout,
- )
-
- // Check if the metrics server is already running before registering the handler.
- if _, err = http.Get(address); err != nil { //nolint:gosec
- // The timeout handler limits the nested handlers from running for too long.
- mux.Handle(
- metricsConfig.Path,
- http.TimeoutHandler(
- gziphandler.GzipHandler(handler),
- readHeaderTimeout,
- "The request timed out while fetching the metrics",
- ),
- )
- } else {
- logger.Warn().Msg("Metrics server is already running, consider changing the port")
- span.RecordError(err)
- }
-
- // Create a new metrics server.
- timeout := config.If(
- metricsConfig.Timeout > 0,
- metricsConfig.Timeout,
- config.DefaultMetricsServerTimeout,
- )
- metricsServer = &http.Server{
- Addr: metricsConfig.Address,
- Handler: mux,
- ReadHeaderTimeout: readHeaderTimeout,
- ReadTimeout: timeout,
- WriteTimeout: timeout,
- IdleTimeout: timeout,
- }
-
- logger.Info().Fields(map[string]any{
- "address": address,
- "timeout": timeout.String(),
- "readHeaderTimeout": readHeaderTimeout.String(),
- }).Msg("Metrics are exposed")
-
- if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" {
- // Set up TLS.
- metricsServer.TLSConfig = &tls.Config{
- MinVersion: tls.VersionTLS13,
- CurvePreferences: []tls.CurveID{
- tls.CurveP521,
- tls.CurveP384,
- tls.CurveP256,
- },
- CipherSuites: []uint16{
- tls.TLS_AES_128_GCM_SHA256,
- tls.TLS_AES_256_GCM_SHA384,
- tls.TLS_CHACHA20_POLY1305_SHA256,
- },
- }
- metricsServer.TLSNextProto = make(
- map[string]func(*http.Server, *tls.Conn, http.Handler))
- logger.Debug().Msg("Metrics server is running with TLS")
-
- // Start the metrics server with TLS.
- if err = metricsServer.ListenAndServeTLS(
- metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) {
- logger.Error().Err(err).Msg("Failed to start metrics server")
- span.RecordError(err)
- }
- } else {
- // Start the metrics server without TLS.
- if err = metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
- logger.Error().Err(err).Msg("Failed to start metrics server")
- span.RecordError(err)
- }
- }
- }(conf.Global.Metrics[config.Default], logger)
-
- // This is a notification hook, so we don't care about the result.
- pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok {
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result)
- }
- } else {
- logger.Error().Msg("Failed to get loggers from config")
- }
-
- // Declare httpServer and grpcServer here as it is used in the StopGracefully function ahead of their definition.
- var httpServer *api.HTTPServer
- var grpcServer *api.GRPCServer
+ // Run the OnNewLogger hook.
+ app.onNewLogger(span, logger)
_, span = otel.Tracer(config.TracerName).Start(runCtx, "Create pools and clients")
- // Create and initialize pools of connections.
- for configGroupName, configGroup := range conf.Global.Pools {
- for configBlockName, cfg := range configGroup {
- logger := loggers[configGroupName]
- // Check if the pool size is greater than zero.
- currentPoolSize := config.If(
- cfg.Size > 0,
- // Check if the pool size is greater than the minimum pool size.
- config.If(
- cfg.Size > config.MinimumPoolSize,
- cfg.Size,
- config.MinimumPoolSize,
- ),
- config.DefaultPoolSize,
- )
-
- if _, ok := pools[configGroupName]; !ok {
- pools[configGroupName] = make(map[string]*pool.Pool)
- }
- pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize)
-
- span.AddEvent("Create pool", trace.WithAttributes(
- attribute.String("name", configBlockName),
- attribute.Int("size", currentPoolSize),
- ))
-
- if _, ok := clients[configGroupName]; !ok {
- clients[configGroupName] = make(map[string]*config.Client)
- }
-
- // Get client config from the config file.
- if clientConfig, ok := conf.Global.Clients[configGroupName][configBlockName]; !ok {
- // This ensures that the default client config is used if the pool name is not
- // found in the clients section.
- clients[configGroupName][configBlockName] = conf.Global.Clients[config.Default][config.DefaultConfigurationBlock]
- } else {
- // Merge the default client config with the one from the pool.
- clients[configGroupName][configBlockName] = clientConfig
- }
-
- // Fill the missing and zero values with the default ones.
- clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If(
- clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0,
- clients[configGroupName][configBlockName].TCPKeepAlivePeriod,
- config.DefaultTCPKeepAlivePeriod,
- )
- clients[configGroupName][configBlockName].ReceiveDeadline = config.If(
- clients[configGroupName][configBlockName].ReceiveDeadline > 0,
- clients[configGroupName][configBlockName].ReceiveDeadline,
- config.DefaultReceiveDeadline,
- )
- clients[configGroupName][configBlockName].ReceiveTimeout = config.If(
- clients[configGroupName][configBlockName].ReceiveTimeout > 0,
- clients[configGroupName][configBlockName].ReceiveTimeout,
- config.DefaultReceiveTimeout,
- )
- clients[configGroupName][configBlockName].SendDeadline = config.If(
- clients[configGroupName][configBlockName].SendDeadline > 0,
- clients[configGroupName][configBlockName].SendDeadline,
- config.DefaultSendDeadline,
- )
- clients[configGroupName][configBlockName].ReceiveChunkSize = config.If(
- clients[configGroupName][configBlockName].ReceiveChunkSize > 0,
- clients[configGroupName][configBlockName].ReceiveChunkSize,
- config.DefaultChunkSize,
- )
- clients[configGroupName][configBlockName].DialTimeout = config.If(
- clients[configGroupName][configBlockName].DialTimeout > 0,
- clients[configGroupName][configBlockName].DialTimeout,
- config.DefaultDialTimeout,
- )
-
- // Add clients to the pool.
- for range currentPoolSize {
- clientConfig := clients[configGroupName][configBlockName]
- clientConfig.GroupName = configGroupName
- clientConfig.BlockName = configBlockName
- client := network.NewClient(
- runCtx, clientConfig, logger,
- network.NewRetry(
- network.Retry{
- Retries: clientConfig.Retries,
- Backoff: config.If(
- clientConfig.Backoff > 0,
- clientConfig.Backoff,
- config.DefaultBackoff,
- ),
- BackoffMultiplier: clientConfig.BackoffMultiplier,
- DisableBackoffCaps: clientConfig.DisableBackoffCaps,
- Logger: loggers[configBlockName],
- },
- ),
- )
-
- if client != nil {
- eventOptions := trace.WithAttributes(
- attribute.String("name", configBlockName),
- attribute.String("group", configGroupName),
- attribute.String("network", client.Network),
- attribute.String("address", client.Address),
- attribute.Int("receiveChunkSize", client.ReceiveChunkSize),
- attribute.String("receiveDeadline", client.ReceiveDeadline.String()),
- attribute.String("receiveTimeout", client.ReceiveTimeout.String()),
- attribute.String("sendDeadline", client.SendDeadline.String()),
- attribute.String("dialTimeout", client.DialTimeout.String()),
- attribute.Bool("tcpKeepAlive", client.TCPKeepAlive),
- attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()),
- attribute.String("localAddress", client.LocalAddr()),
- attribute.String("remoteAddress", client.RemoteAddr()),
- attribute.Int("retries", clientConfig.Retries),
- attribute.String("backoff", client.Retry().Backoff.String()),
- attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier),
- attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps),
- )
- if client.ID != "" {
- eventOptions = trace.WithAttributes(
- attribute.String("id", client.ID),
- )
- }
-
- span.AddEvent("Create client", eventOptions)
-
- pluginTimeoutCtx, cancel = context.WithTimeout(
- context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- clientCfg := map[string]any{
- "id": client.ID,
- "name": configBlockName,
- "group": configGroupName,
- "network": client.Network,
- "address": client.Address,
- "receiveChunkSize": client.ReceiveChunkSize,
- "receiveDeadline": client.ReceiveDeadline.String(),
- "receiveTimeout": client.ReceiveTimeout.String(),
- "sendDeadline": client.SendDeadline.String(),
- "dialTimeout": client.DialTimeout.String(),
- "tcpKeepAlive": client.TCPKeepAlive,
- "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(),
- "localAddress": client.LocalAddr(),
- "remoteAddress": client.RemoteAddr(),
- "retries": clientConfig.Retries,
- "backoff": client.Retry().Backoff.String(),
- "backoffMultiplier": clientConfig.BackoffMultiplier,
- "disableBackoffCaps": clientConfig.DisableBackoffCaps,
- }
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result)
- }
-
- err = pools[configGroupName][configBlockName].Put(client.ID, client)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to add client to the pool")
- span.RecordError(err)
- }
- } else {
- logger.Error().Msg("Failed to create client, please check the configuration")
- go func() {
- // Wait for the stop signal to exit gracefully.
- // This prevents the program from waiting indefinitely
- // after the StopGracefully function is called.
- <-stopChan
- os.Exit(gerr.FailedToCreateClient)
- }()
- StopGracefully(
- runCtx,
- nil,
- metricsMerger,
- metricsServer,
- pluginRegistry,
- logger,
- servers,
- stopChan,
- httpServer,
- grpcServer,
- )
- }
- }
-
- // Verify that the pool is properly populated.
- logger.Info().Fields(map[string]any{
- "name": configBlockName,
- "count": strconv.Itoa(pools[configGroupName][configBlockName].Size()),
- }).Msg("There are clients available in the pool")
-
- if pools[configGroupName][configBlockName].Size() != currentPoolSize {
- logger.Error().Msg(
- "The pool size is incorrect, either because " +
- "the clients cannot connect due to no network connectivity " +
- "or the server is not running. exiting...")
- pluginRegistry.Shutdown()
- os.Exit(gerr.FailedToInitializePool)
- }
-
- pluginTimeoutCtx, cancel = context.WithTimeout(
- context.Background(), conf.Plugin.Timeout)
- defer cancel()
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx,
- map[string]any{"name": configBlockName, "size": currentPoolSize},
- v1.HookName_HOOK_NAME_ON_NEW_POOL)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result)
- }
- }
+ // Create pools and clients.
+ if err := app.createPoolAndClients(runCtx, span); err != nil {
+ logger.Error().Err(err).Msg("Failed to create pools and clients")
+ span.RecordError(err)
+ app.stopGracefully(runCtx, nil)
+ os.Exit(gerr.FailedToCreatePoolAndClients)
}
span.End()
_, span = otel.Tracer(config.TracerName).Start(runCtx, "Create proxies")
- // Create and initialize prefork proxies with each pool of clients.
- for configGroupName, configGroup := range conf.Global.Proxies {
- for configBlockName, cfg := range configGroup {
- logger := loggers[configGroupName]
- clientConfig := clients[configGroupName][configBlockName]
-
- // Fill the missing and zero value with the default one.
- cfg.HealthCheckPeriod = config.If(
- cfg.HealthCheckPeriod > 0,
- cfg.HealthCheckPeriod,
- config.DefaultHealthCheckPeriod,
- )
-
- if _, ok := proxies[configGroupName]; !ok {
- proxies[configGroupName] = make(map[string]*network.Proxy)
- }
-
- proxies[configGroupName][configBlockName] = network.NewProxy(
- runCtx,
- network.Proxy{
- GroupName: configGroupName,
- BlockName: configBlockName,
- AvailableConnections: pools[configGroupName][configBlockName],
- PluginRegistry: pluginRegistry,
- HealthCheckPeriod: cfg.HealthCheckPeriod,
- ClientConfig: clientConfig,
- Logger: logger,
- PluginTimeout: conf.Plugin.Timeout,
- },
- )
-
- span.AddEvent("Create proxy", trace.WithAttributes(
- attribute.String("name", configBlockName),
- attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()),
- ))
-
- pluginTimeoutCtx, cancel = context.WithTimeout(
- context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result)
- }
- } else {
- logger.Error().Msg("Failed to get proxy from config")
- }
- }
- }
+ // Create proxies.
+ app.createProxies(runCtx, span)
span.End()
_, span = otel.Tracer(config.TracerName).Start(runCtx, "Create Raft Node")
defer span.End()
- raftNode, originalErr := raft.NewRaftNode(logger, conf.Global.Raft)
+ // Create the Raft node.
+ raftNode, originalErr := raft.NewRaftNode(logger, app.conf.Global.Raft)
if originalErr != nil {
logger.Error().Err(originalErr).Msg("Failed to start raft node")
span.RecordError(originalErr)
- pluginRegistry.Shutdown()
+ app.stopGracefully(runCtx, nil)
os.Exit(gerr.FailedToStartRaftNode)
}
_, span = otel.Tracer(config.TracerName).Start(runCtx, "Create servers")
- // Create and initialize servers.
- for name, cfg := range conf.Global.Servers {
- logger := loggers[name]
-
- var serverProxies []network.IProxy
- for _, proxy := range proxies[name] {
- serverProxies = append(serverProxies, proxy)
- }
-
- servers[name] = network.NewServer(
- runCtx,
- network.Server{
- GroupName: name,
- Network: cfg.Network,
- Address: cfg.Address,
- TickInterval: config.If(
- cfg.TickInterval > 0,
- cfg.TickInterval,
- config.DefaultTickInterval,
- ),
- Options: network.Option{
- // Can be used to send keepalive messages to the client.
- EnableTicker: cfg.EnableTicker,
- },
- Proxies: serverProxies,
- Logger: logger,
- PluginRegistry: pluginRegistry,
- PluginTimeout: conf.Plugin.Timeout,
- EnableTLS: cfg.EnableTLS,
- CertFile: cfg.CertFile,
- KeyFile: cfg.KeyFile,
- HandshakeTimeout: cfg.HandshakeTimeout,
- LoadbalancerStrategyName: cfg.LoadBalancer.Strategy,
- LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules,
- LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash,
- RaftNode: raftNode,
- },
- )
- span.AddEvent("Create server", trace.WithAttributes(
- attribute.String("name", name),
- attribute.String("network", cfg.Network),
- attribute.String("address", cfg.Address),
- attribute.String("tickInterval", cfg.TickInterval.String()),
- attribute.String("pluginTimeout", conf.Plugin.Timeout.String()),
- attribute.Bool("enableTLS", cfg.EnableTLS),
- attribute.String("certFile", cfg.CertFile),
- attribute.String("keyFile", cfg.KeyFile),
- attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()),
- ))
-
- pluginTimeoutCtx, cancel = context.WithTimeout(
- context.Background(), conf.Plugin.Timeout)
- defer cancel()
-
- if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok {
- result, err := pluginRegistry.Run(
- pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
- if err != nil {
- logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
- span.RecordError(err)
- }
- if result != nil {
- _ = pluginRegistry.ActRegistry.RunAll(result)
- }
- } else {
- logger.Error().Msg("Failed to get the servers configuration")
- }
- }
+ // Create servers.
+ app.createServers(runCtx, span, raftNode)
span.End()
- // Start the HTTP and gRPC APIs.
- if conf.Global.API.Enabled {
- apiOptions := api.Options{
- Logger: logger,
- GRPCNetwork: conf.Global.API.GRPCNetwork,
- GRPCAddress: conf.Global.API.GRPCAddress,
- HTTPAddress: conf.Global.API.HTTPAddress,
- Servers: servers,
- RaftNode: raftNode,
- }
-
- apiObj := &api.API{
- Options: &apiOptions,
- Config: conf,
- PluginRegistry: pluginRegistry,
- Pools: pools,
- Proxies: proxies,
- Servers: servers,
- }
- grpcServer = api.NewGRPCServer(
- runCtx,
- api.GRPCServer{
- API: apiObj,
- HealthChecker: &api.HealthChecker{Servers: servers},
- },
- )
- if grpcServer != nil {
- go grpcServer.Start()
- logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API")
- httpServer = api.NewHTTPServer(&apiOptions)
- go httpServer.Start()
-
- logger.Info().Fields(
- map[string]any{
- "network": apiOptions.GRPCNetwork,
- "address": apiOptions.GRPCAddress,
- },
- ).Msg("Started the gRPC Server")
- }
- }
-
- // Report usage statistics.
- if enableUsageReport {
- go func() {
- conn, err := grpc.NewClient(
- UsageReportURL,
- grpc.WithTransportCredentials(
- credentials.NewTLS(
- &tls.Config{
- MinVersion: tls.VersionTLS12,
- },
- ),
- ),
- )
- if err != nil {
- logger.Trace().Err(err).Msg(
- "Failed to dial to the gRPC server for usage reporting")
- }
- defer func(conn *grpc.ClientConn) {
- err := conn.Close()
- if err != nil {
- logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service")
- }
- }(conn)
-
- client := usage.NewUsageReportServiceClient(conn)
- report := usage.UsageReportRequest{
- Version: config.Version,
- RuntimeVersion: runtime.Version(),
- Goos: runtime.GOOS,
- Goarch: runtime.GOARCH,
- Service: "gatewayd",
- DevMode: devMode,
- Plugins: []*usage.Plugin{},
- }
- pluginRegistry.ForEach(
- func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) {
- report.Plugins = append(report.GetPlugins(), &usage.Plugin{
- Name: identifier.Name,
- Version: identifier.Version,
- Checksum: identifier.Checksum,
- })
- },
- )
- _, err = client.Report(context.Background(), &report)
- if err != nil {
- logger.Trace().Err(err).Msg("Failed to report usage statistics")
- }
- }()
- }
-
- // Shutdown the server gracefully.
- signals := []os.Signal{
- os.Interrupt,
- os.Kill,
- syscall.SIGTERM,
- syscall.SIGABRT,
- syscall.SIGQUIT,
- syscall.SIGHUP,
- syscall.SIGINT,
- }
- signalsCh := make(chan os.Signal, 1)
- signal.Notify(signalsCh, signals...)
- go func(pluginRegistry *plugin.Registry,
- logger zerolog.Logger,
- servers map[string]*network.Server,
- metricsMerger *metrics.Merger,
- metricsServer *http.Server,
- stopChan chan struct{},
- httpServer *api.HTTPServer,
- grpcServer *api.GRPCServer,
- ) {
- for sig := range signalsCh {
- for _, s := range signals {
- if sig != s {
- StopGracefully(
- runCtx,
- sig,
- metricsMerger,
- metricsServer,
- pluginRegistry,
- logger,
- servers,
- stopChan,
- httpServer,
- grpcServer,
- )
- os.Exit(0)
- }
- }
- }
- }(pluginRegistry, logger, servers, metricsMerger, metricsServer, stopChan, httpServer, grpcServer)
+ // Start the API servers.
+ app.startAPIServers(runCtx, logger, raftNode)
+ // Report usage.
+ app.reportUsage(logger)
_, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers")
- // Start the server.
- for name, server := range servers {
- logger := loggers[name]
- go func(
- span trace.Span,
- server *network.Server,
- logger zerolog.Logger,
- healthCheckScheduler *gocron.Scheduler,
- metricsMerger *metrics.Merger,
- pluginRegistry *plugin.Registry,
- ) {
- span.AddEvent("Start server")
- if err := server.Run(); err != nil {
- logger.Error().Err(err).Msg("Failed to start server")
- span.RecordError(err)
- healthCheckScheduler.Clear()
- if metricsMerger != nil {
- metricsMerger.Stop()
- }
- server.Shutdown()
- pluginRegistry.Shutdown()
- os.Exit(gerr.FailedToStartServer)
- }
- }(span, server, logger, healthCheckScheduler, metricsMerger, pluginRegistry)
- }
+ // Start the servers.
+ app.startServers(runCtx, span)
+
span.End()
// Wait for the server to shut down.
- <-stopChan
+ <-app.stopChan
},
}
func init() {
rootCmd.AddCommand(runCmd)
- runCmd.Flags().StringVarP(
- &globalConfigFile,
+ runCmd.Flags().StringP(
"config", "c", config.GetDefaultConfigFilePath(config.GlobalConfigFilename),
"Global config file")
- runCmd.Flags().StringVarP(
- &pluginConfigFile,
+ runCmd.Flags().StringP(
"plugin-config", "p", config.GetDefaultConfigFilePath(config.PluginsConfigFilename),
"Plugin config file")
- runCmd.Flags().BoolVar(
- &devMode, "dev", false, "Enable development mode for plugin development")
- runCmd.Flags().BoolVar(
- &enableTracing, "tracing", false, "Enable tracing with OpenTelemetry via gRPC")
- runCmd.Flags().StringVar(
- &collectorURL, "collector-url", "localhost:4317",
- "Collector URL of OpenTelemetry gRPC endpoint")
- runCmd.Flags().BoolVar(
- &enableSentry, "sentry", true, "Enable Sentry")
- runCmd.Flags().BoolVar(
- &enableUsageReport, "usage-report", true, "Enable usage report")
- runCmd.Flags().BoolVar(
- &enableLinting, "lint", true, "Enable linting of configuration files")
+ runCmd.Flags().Bool("dev", false, "Enable development mode for plugin development")
+ runCmd.Flags().Bool("tracing", false, "Enable tracing with OpenTelemetry via gRPC")
+ runCmd.Flags().String(
+ "collector-url", "localhost:4317", "Collector URL of OpenTelemetry gRPC endpoint")
+ runCmd.Flags().Bool("sentry", true, "Enable Sentry")
+ runCmd.Flags().Bool("usage-report", true, "Enable usage report")
+ runCmd.Flags().Bool("lint", true, "Enable linting of configuration files")
+ runCmd.Flags().Bool("metrics-merger", true, "Enable metrics merger")
}
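
The cmd/run.go hunks above replace the inline proxy, server, API, usage-report, and signal-handling blocks with calls on an application value: app.createProxies(runCtx, span), app.createServers(runCtx, span, raftNode), app.startAPIServers(runCtx, logger, raftNode), app.reportUsage(logger), app.startServers(runCtx, span), and app.stopGracefully(runCtx, nil), with shutdown reduced to waiting on app.stopChan. The type behind app is defined outside this excerpt; the sketch below only illustrates the shape those call sites imply, and every field name in it is an assumption.

```go
// Sketch only: inferred from the call sites in this diff, not the actual
// definition in cmd/run.go. Field names are assumptions.
package cmd

import (
	"context"
	"os"

	"github.com/gatewayd-io/gatewayd/config"
)

type gatewaydApp struct {
	conf     *config.Config // read as app.conf.Global.Raft above
	stopChan chan struct{}  // RunE blocks on <-app.stopChan until shutdown
}

// stopGracefully is invoked with a nil signal on internal failures (the raft
// error path above) and with os.Interrupt from the tests below.
func (app *gatewaydApp) stopGracefully(ctx context.Context, sig os.Signal) {
	// ... stop servers, metrics, the API servers, and the plugin registry ...
	close(app.stopChan) // releases the <-app.stopChan wait in RunE
}
```

Relatedly, init() now registers unbound flags (StringP/Bool/String instead of StringVarP/BoolVar/StringVar pointing at package globals), so the command body presumably reads them through the standard pflag getters, e.g. cmd.Flags().GetString("config") or cmd.Flags().GetBool("dev"), rather than through the removed globalConfigFile and devMode variables.
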
diff --git a/cmd/run_test.go b/cmd/run_test.go
index 37fd54c5..61f0b269 100644
--- a/cmd/run_test.go
+++ b/cmd/run_test.go
@@ -7,7 +7,7 @@ import (
+ "os"
"testing"
"time"
- "github.com/gatewayd-io/gatewayd/config"
"github.com/gatewayd-io/gatewayd/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -19,6 +18,12 @@ var (
)
func Test_runCmd(t *testing.T) {
+ previous := EnableTestMode()
+ defer func() {
+ testMode = previous
+ testApp = nil
+ }()
+
postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t)
postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port()
t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress)
@@ -41,14 +46,13 @@ func Test_runCmd(t *testing.T) {
// Check that the config file was created.
assert.FileExists(t, globalTestConfigFile, "configInitCmd should create a config file")
- stopChan = make(chan struct{})
-
var waitGroup sync.WaitGroup
waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
// Test run command.
- output, err := executeCommandC(rootCmd, "run", "-c", globalTestConfigFile, "-p", pluginTestConfigFile)
+ output, err := executeCommandC(
+ rootCmd, "run", "-c", globalTestConfigFile, "-p", pluginTestConfigFile)
require.NoError(t, err, "run command should not have returned an error")
// Print the output for debugging purposes.
@@ -60,23 +64,11 @@ func Test_runCmd(t *testing.T) {
waitGroup.Done()
}(&waitGroup)
+ // Stop the server after a delay.
waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
time.Sleep(waitBeforeStop)
-
- StopGracefully(
- context.Background(),
- nil,
- nil,
- metricsServer,
- nil,
- loggers[config.Default],
- servers,
- stopChan,
- nil,
- nil,
- )
-
+ testApp.stopGracefully(context.Background(), os.Interrupt)
waitGroup.Done()
}(&waitGroup)
@@ -88,6 +80,12 @@ func Test_runCmd(t *testing.T) {
// Test_runCmdWithTLS tests the run command with TLS enabled on the server.
func Test_runCmdWithTLS(t *testing.T) {
+ previous := EnableTestMode()
+ defer func() {
+ testMode = previous
+ testApp = nil
+ }()
+
postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t)
postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port()
t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress)
@@ -104,8 +102,6 @@ func Test_runCmdWithTLS(t *testing.T) {
require.NoError(t, err, "plugin init command should not have returned an error")
assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file")
- stopChan = make(chan struct{})
-
var waitGroup sync.WaitGroup
// TODO: Test client certificate authentication.
@@ -126,23 +122,11 @@ func Test_runCmdWithTLS(t *testing.T) {
waitGroup.Done()
}(&waitGroup)
+ // Stop the server after a delay.
waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
time.Sleep(waitBeforeStop)
-
- StopGracefully(
- context.Background(),
- nil,
- nil,
- metricsServer,
- nil,
- loggers[config.Default],
- servers,
- stopChan,
- nil,
- nil,
- )
-
+ testApp.stopGracefully(context.Background(), os.Interrupt)
waitGroup.Done()
}(&waitGroup)
@@ -153,6 +137,12 @@ func Test_runCmdWithTLS(t *testing.T) {
// Test_runCmdWithMultiTenancy tests the run command with multi-tenancy enabled.
func Test_runCmdWithMultiTenancy(t *testing.T) {
+ previous := EnableTestMode()
+ defer func() {
+ testMode = previous
+ testApp = nil
+ }()
+
postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t)
postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port()
t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress)
@@ -172,8 +162,6 @@ func Test_runCmdWithMultiTenancy(t *testing.T) {
require.NoError(t, err, "plugin init command should not have returned an error")
assert.FileExists(t, pluginTestConfigFile, "plugin init command should have created a config file")
- stopChan = make(chan struct{})
-
var waitGroup sync.WaitGroup
waitGroup.Add(1)
@@ -196,23 +184,11 @@ func Test_runCmdWithMultiTenancy(t *testing.T) {
waitGroup.Done()
}(&waitGroup)
+ // Stop the server after a delay.
waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
time.Sleep(waitBeforeStop)
-
- StopGracefully(
- context.Background(),
- nil,
- nil,
- metricsServer,
- nil,
- loggers[config.Default],
- servers,
- stopChan,
- nil,
- nil,
- )
-
+ testApp.stopGracefully(context.Background(), os.Interrupt)
waitGroup.Done()
}(&waitGroup)
@@ -222,6 +198,12 @@ func Test_runCmdWithMultiTenancy(t *testing.T) {
}
func Test_runCmdWithCachePlugin(t *testing.T) {
+ previous := EnableTestMode()
+ defer func() {
+ testMode = previous
+ testApp = nil
+ }()
+
postgresHostIP, postgresMappedPort := testhelpers.SetupPostgreSQLTestContainer(context.Background(), t)
postgresAddress := postgresHostIP + ":" + postgresMappedPort.Port()
t.Setenv("GATEWAYD_CLIENTS_DEFAULT_WRITES_ADDRESS", postgresAddress)
@@ -233,9 +215,6 @@ func Test_runCmdWithCachePlugin(t *testing.T) {
globalTestConfigFile := "./test_global_runCmdWithCachePlugin.yaml"
pluginTestConfigFile := "./test_plugins_runCmdWithCachePlugin.yaml"
- // TODO: Remove this once these global variables are removed from cmd/run.go.
- // https://github.com/gatewayd-io/gatewayd/issues/324
- stopChan = make(chan struct{})
// Create a test plugins config file.
_, err := executeCommandC(rootCmd, "plugin", "init", "--force", "-p", pluginTestConfigFile)
@@ -258,15 +237,16 @@ func Test_runCmdWithCachePlugin(t *testing.T) {
rootCmd, "plugin", "install", "-p", pluginTestConfigFile, "--update", "--backup",
"--overwrite-config=true", "--name", "gatewayd-plugin-cache", pluginArchivePath)
require.NoError(t, err, "plugin install should not return an error")
- assert.Equal(t, output, "Installing plugin from CLI argument\nBackup completed successfully\nPlugin binary extracted to plugins/gatewayd-plugin-cache\nPlugin installed successfully\n") //nolint:lll
+ assert.Contains(t, output, "Installing plugin from CLI argument")
+ assert.Contains(t, output, "Backup completed successfully")
+ assert.Contains(t, output, "Plugin binary extracted to plugins/gatewayd-plugin-cache")
+ assert.Contains(t, output, "Plugin installed successfully")
// See if the plugin was actually installed.
output, err = executeCommandC(rootCmd, "plugin", "list", "-p", pluginTestConfigFile)
require.NoError(t, err, "plugin list should not return an error")
assert.Contains(t, output, "Name: gatewayd-plugin-cache")
- stopChan = make(chan struct{})
-
var waitGroup sync.WaitGroup
waitGroup.Add(1)
@@ -284,23 +264,11 @@ func Test_runCmdWithCachePlugin(t *testing.T) {
waitGroup.Done()
}(&waitGroup)
+ // Stop the server after a delay.
waitGroup.Add(1)
go func(waitGroup *sync.WaitGroup) {
- time.Sleep(waitBeforeStop * 2)
-
- StopGracefully(
- context.Background(),
- nil,
- nil,
- metricsServer,
- nil,
- loggers[config.Default],
- servers,
- stopChan,
- nil,
- nil,
- )
-
+ time.Sleep(waitBeforeStop)
+ testApp.stopGracefully(context.Background(), os.Interrupt)
waitGroup.Done()
}(&waitGroup)
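
Each run-command test now brackets itself with EnableTestMode() and restores the previous state in a defer, replacing the old pattern of re-making the package-level stopChan and calling StopGracefully with a long list of nil arguments. EnableTestMode itself is outside this excerpt; a minimal sketch of what the usage above implies, assuming testMode and testApp are package variables in cmd:

```go
// Sketch, assuming these package variables exist in cmd (implied by the
// `testMode = previous` and `testApp = nil` lines in the tests above).
var (
	testMode bool
	testApp  *gatewaydApp // recorded by the run command when testMode is set
)

// EnableTestMode returns the previous flag value so callers can restore it,
// matching the `previous := EnableTestMode()` pattern in the tests.
func EnableTestMode() bool {
	previous := testMode
	testMode = true
	return previous
}
```

This lets each test stop its own instance via testApp.stopGracefully(context.Background(), os.Interrupt) instead of sharing mutable globals across tests.
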
diff --git a/config/getters_test.go b/config/getters_test.go
index 3f71eb37..a9b5e3c7 100644
--- a/config/getters_test.go
+++ b/config/getters_test.go
@@ -35,9 +35,9 @@ func TestGetDefaultConfigFilePath(t *testing.T) {
// TestFilter tests the Filter function.
func TestFilter(t *testing.T) {
// Load config from the default config file.
- conf := NewConfig(context.TODO(),
+ conf := NewConfig(context.Background(),
Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- err := conf.InitConfig(context.TODO())
+ err := conf.InitConfig(context.Background())
require.Nil(t, err)
assert.NotEmpty(t, conf.Global)
@@ -55,9 +55,9 @@ func TestFilter(t *testing.T) {
// TestFilterWithMissingGroupName tests the Filter function with a missing group name.
func TestFilterWithMissingGroupName(t *testing.T) {
// Load config from the default config file.
- conf := NewConfig(context.TODO(),
+ conf := NewConfig(context.Background(),
Config{GlobalConfigFile: "../gatewayd.yaml", PluginConfigFile: "../gatewayd_plugins.yaml"})
- err := conf.InitConfig(context.TODO())
+ err := conf.InitConfig(context.Background())
require.Nil(t, err)
assert.NotEmpty(t, conf.Global)
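
The context.TODO() to context.Background() swap in these tests (and in api/api_test.go earlier) changes intent, not behaviour: both return an empty, never-cancelled root context, but TODO marks a call site where the proper context has not been threaded through yet, which linters can flag, while Background declares a deliberate root. Roughly:

```go
// Behaviourally identical empty root contexts; only the signal to readers
// and linters differs.
ctx := context.Background() // deliberate root, e.g. in main or a test
todo := context.TODO()      // placeholder: context plumbing still pending
_, _ = ctx, todo
```
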
diff --git a/errors/errors.go b/errors/errors.go
index 61bb0e32..9cce47db 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -212,12 +212,3 @@ var (
// Unwrapped errors.
ErrLoggerRequired = errors.New("terminate action requires a logger parameter")
)
-
-const (
- FailedToCreateClient = 1
- FailedToInitializePool = 2
- FailedToStartServer = 3
- FailedToStartTracer = 4
- FailedToCreateActRegistry = 5
- FailedToStartRaftNode = 6
-)
diff --git a/errors/exit_codes.go b/errors/exit_codes.go
new file mode 100644
index 00000000..14ed89f3
--- /dev/null
+++ b/errors/exit_codes.go
@@ -0,0 +1,11 @@
+package errors
+
+const (
+ FailedToStartServer = 1
+ FailedToStartTracer = 2
+ FailedToCreateActRegistry = 3
+ FailedToLoadPolicies = 4
+ FailedToStartRaftNode = 5
+ FailedToMergeGlobalConfig = 6
+ FailedToCreatePoolAndClients = 7
+)
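
The exit codes move out of errors/errors.go into a dedicated errors/exit_codes.go and are renumbered: FailedToCreateClient and FailedToInitializePool are dropped, while FailedToLoadPolicies, FailedToMergeGlobalConfig, and FailedToCreatePoolAndClients are added. They remain plain process exit statuses consumed via os.Exit, as in the raft path earlier in this diff (condensed here; gerr is the alias the run command uses for this errors package):

```go
raftNode, originalErr := raft.NewRaftNode(logger, app.conf.Global.Raft)
if originalErr != nil {
	logger.Error().Err(originalErr).Msg("Failed to start raft node")
	app.stopGracefully(runCtx, nil)
	os.Exit(gerr.FailedToStartRaftNode) // process exits with status 5 after this change
}
```
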
diff --git a/network/server_test.go b/network/server_test.go
index 0ea89696..fcb65de4 100644
--- a/network/server_test.go
+++ b/network/server_test.go
@@ -109,11 +109,10 @@ func TestRunServer(t *testing.T) {
},
)
assert.NotNil(t, server)
- assert.Zero(t, server.connections)
assert.Zero(t, server.CountConnections())
assert.Empty(t, server.host)
assert.Empty(t, server.port)
- assert.False(t, server.running.Load())
+ assert.False(t, server.IsRunning())
var waitGroup sync.WaitGroup
waitGroup.Add(2)
@@ -215,8 +214,8 @@ func TestRunServer(t *testing.T) {
<-time.After(100 * time.Millisecond)
// check server status and connections
- assert.False(t, server.running.Load())
- assert.Zero(t, server.connections)
+ assert.False(t, server.IsRunning())
+ assert.Zero(t, server.CountConnections())
// Read the log file and check if the log file contains the expected log messages.
require.FileExists(t, "server_test.log")