From 3d8405a64bfcd25d7e4a106d433b353cb494e49b Mon Sep 17 00:00:00 2001 From: Fabrice Aneche Date: Tue, 17 Sep 2024 21:25:04 -0400 Subject: [PATCH] ensure allow is working --- cmd/sshjump/kubernetes.go | 84 +-------------------- cmd/sshjump/port_match.go | 87 +++++++++++++++++++++ cmd/sshjump/port_match_test.go | 133 +++++++++++++++++++++++++++++++++ cmd/sshjump/server.go | 4 +- 4 files changed, 224 insertions(+), 84 deletions(-) create mode 100644 cmd/sshjump/port_match.go create mode 100644 cmd/sshjump/port_match_test.go diff --git a/cmd/sshjump/kubernetes.go b/cmd/sshjump/kubernetes.go index 6801505..2c18ecc 100644 --- a/cmd/sshjump/kubernetes.go +++ b/cmd/sshjump/kubernetes.go @@ -7,17 +7,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -type Port struct { - namespace string - pod string - container string - service string - port int32 - addr string // the real addr to connect to -} - -type Ports []Port - // KubernetesPortsForUser return a list of Kubernetes services/containers the provided user is allowed to reach. func (srv *Server) KubernetesPortsForUser(ctx context.Context, user string) (Ports, error) { var kports []Port @@ -65,80 +54,11 @@ func (srv *Server) KubernetesPortsForUser(ctx context.Context, user string) (Por } } - return srv.allowed(kports, user), nil -} - -func (ps Ports) MatchingService(name, namespace string, port int32) (string, bool) { - for _, ps := range ps { - if ps.namespace == namespace && ps.service == name && ps.port == port { - return fmt.Sprintf("%s:%d", ps.addr, ps.port), true - } - } - - return "", false -} - -func (ps Ports) MatchingPod(name, namespace string, port int32) (string, bool) { - for _, ps := range ps { - if ps.namespace == namespace && ps.pod == name && ps.port == port { - return fmt.Sprintf("%s:%d", ps.addr, ps.port), true - } - } - - return "", false -} - -// allowed filter list of ports using user permissions. -func (srv *Server) allowed(ports Ports, user string) Ports { userPerms, exists := srv.permissions[user] if !exists { // If the user doesn't exist in the permissions map, return an empty slice - return []Port{} - } - - // full access - if userPerms.AllowAll { - return ports - } - - var allowed []Port - - for _, port := range ports { - for _, userNs := range userPerms.Namespaces { - if userNs.Namespace == port.namespace { - // Namespace matches - if len(userNs.Pods) == 0 { - // full access to the namespace - allowed = append(allowed, port) - - continue - } - - // check for pods & services - // check against user perms - for _, uPod := range userNs.Pods { - for _, up := range uPod.Ports { - if addr, ok := ports.MatchingPod(uPod.Name, userNs.Namespace, up); ok { - port.addr = addr - allowed = append(allowed, port) - - continue - } - } - } - for _, uService := range userNs.Services { - for _, up := range uService.Ports { - if addr, ok := ports.MatchingService(uService.Name, userNs.Namespace, up); ok { - port.addr = addr - allowed = append(allowed, port) - - continue - } - } - } - } - } + return []Port{}, nil } - return allowed + return Allowed(kports, userPerms), nil } diff --git a/cmd/sshjump/port_match.go b/cmd/sshjump/port_match.go new file mode 100644 index 0000000..7c90626 --- /dev/null +++ b/cmd/sshjump/port_match.go @@ -0,0 +1,87 @@ +package main + +import "fmt" + +type Port struct { + namespace string + pod string + container string + service string + port int32 + addr string // the real addr to connect to +} + +type Ports []Port + +func (ps Ports) MatchingService(name, namespace string, port int32) (string, bool) { + for _, ps := range ps { + if ps.namespace == namespace && ps.service == name && ps.port == port { + return fmt.Sprintf("%s:%d", ps.addr, ps.port), true + } + } + + return "", false +} + +func (ps Ports) MatchingPod(name, namespace string, port int32) (string, bool) { + for _, ps := range ps { + if ps.namespace == namespace && ps.pod == name && ps.port == port { + return fmt.Sprintf("%s:%d", ps.addr, ps.port), true + } + } + + return "", false +} + +// Allowed filter list of ports using user permissions. +func Allowed(ports Ports, userPerms Permission) Ports { + var allowed []Port + + for _, port := range ports { + fullAccess := userPerms.AllowAll + if fullAccess { + for _, p := range ports { + p.addr = fmt.Sprintf("%s:%d", p.addr, p.port) + allowed = append(allowed, p) + } + + return allowed + } + + for _, userNs := range userPerms.Namespaces { + if userNs.Namespace == port.namespace { + // check if user got the full access to the namespace no restriction + if len(userNs.Pods) == 0 { + port.addr = fmt.Sprintf("%s:%d", port.addr, port.port) + allowed = append(allowed, port) + + continue + } + + // check for pods & services + for _, uPod := range userNs.Pods { + for _, up := range uPod.Ports { + if addr, ok := ports.MatchingPod(uPod.Name, userNs.Namespace, up); ok { + port.addr = addr + allowed = append(allowed, port) + + continue + } + } + } + for _, uService := range userNs.Services { + for _, up := range uService.Ports { + if addr, ok := ports.MatchingService(uService.Name, userNs.Namespace, up); ok || fullAccess { + port.addr = addr + allowed = append(allowed, port) + + continue + } + } + } + } + } + } + + return allowed +} diff --git a/cmd/sshjump/port_match_test.go b/cmd/sshjump/port_match_test.go new file mode 100644 index 0000000..69822d6 --- /dev/null +++ b/cmd/sshjump/port_match_test.go @@ -0,0 +1,133 @@ +package main + +import ( + "reflect" + "testing" +) + +func TestAllowed(t *testing.T) { + nginxPerm := Permission{ + Namespaces: []Namespace{{ + Namespace: "default", + Pods: []Pod{{ + Name: "nginx", + Ports: []int32{8080}, + }}, + Services: nil, + }}, + AllowAll: false, + } + + allowAllPerm := Permission{ + AllowAll: true, + } + + namespacePerm := Permission{ + Namespaces: []Namespace{{ + Namespace: "default", + }}, + AllowAll: false, + } + + tests := []struct { + name string + ports Ports + userPerms Permission + want Ports + }{ + { + "simple path", + Ports{Port{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10", + }}, + nginxPerm, + []Port{{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10:8080", + }}, + }, + + { + "simple path empty not matching namespace", + Ports{Port{ + namespace: "notdefault", + pod: "nginx", + port: 8080, + addr: "10.16.0.10", + }}, + nginxPerm, + nil, + }, + + { + "simple path empty not matching pod", + Ports{Port{ + namespace: "default", + pod: "notnginx", + port: 8080, + addr: "10.16.0.10", + }}, + nginxPerm, + nil, + }, + + { + "simple path empty not matching port", + Ports{Port{ + namespace: "default", + pod: "nginx", + port: 8081, + addr: "10.16.0.10", + }}, + nginxPerm, + nil, + }, + + { + "simple path allow all", + Ports{Port{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10", + }}, + allowAllPerm, + []Port{{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10:8080", + }}, + }, + + { + "full access to namespace", + Ports{Port{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10", + }}, + namespacePerm, + []Port{{ + namespace: "default", + pod: "nginx", + port: 8080, + addr: "10.16.0.10:8080", + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Allowed(tt.ports, tt.userPerms); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Allowed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/sshjump/server.go b/cmd/sshjump/server.go index 7f0c1b3..d9960d4 100644 --- a/cmd/sshjump/server.go +++ b/cmd/sshjump/server.go @@ -328,7 +328,7 @@ func (srv *Server) StartWatchConfig(ctx context.Context, path string) error { if !ok { return } - srv.logger.Error("error watching config file: %v", err.Error()) + srv.logger.Error("error watching config file", "error", err.Error()) } } }() @@ -336,7 +336,7 @@ func (srv *Server) StartWatchConfig(ctx context.Context, path string) error { // watch parent directory for atomic updates err = watcher.Add(filepath.Dir(path)) if err != nil { - return fmt.Errorf("Error adding config file to watcher: %w", err) + return fmt.Errorf("error adding config file to watcher: %w", err) } return nil