From a3731d3e6a435661008ec8df790ed4cca6439a18 Mon Sep 17 00:00:00 2001
From: Scott Leggett
Date: Tue, 8 Feb 2022 16:46:16 +0800
Subject: [PATCH] feat: implement service=... and container=... argument
 parsing

For details see
https://docs.lagoon.sh/using-lagoon-advanced/ssh/#podservice-container-definition
---
 internal/k8s/client.go                      | 170 +-------------------
 internal/k8s/exec.go                        | 155 ++++++++++++++++++
 internal/k8s/finddeployment.go              |  25 +++
 internal/k8s/namespacedetails.go            |  43 +++++
 internal/k8s/validate.go                    |  17 ++
 internal/sshserver/connectionparams.go      |  51 ++++++
 internal/sshserver/connectionparams_test.go |  68 ++++++++
 internal/sshserver/helper_test.go           |   7 +
 internal/sshserver/sessionhandler.go        |  68 +++++++-
 9 files changed, 431 insertions(+), 173 deletions(-)
 create mode 100644 internal/k8s/exec.go
 create mode 100644 internal/k8s/finddeployment.go
 create mode 100644 internal/k8s/namespacedetails.go
 create mode 100644 internal/k8s/validate.go
 create mode 100644 internal/sshserver/connectionparams.go
 create mode 100644 internal/sshserver/connectionparams_test.go
 create mode 100644 internal/sshserver/helper_test.go

diff --git a/internal/k8s/client.go b/internal/k8s/client.go
index 3cc51e9c..bc5fe91b 100644
--- a/internal/k8s/client.go
+++ b/internal/k8s/client.go
@@ -1,26 +1,15 @@
 package k8s
 
 import (
-	"context"
-	"fmt"
-	"io"
-	"strconv"
 	"time"
 
-	v1 "k8s.io/api/core/v1"
-	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"k8s.io/apimachinery/pkg/labels"
-	"k8s.io/apimachinery/pkg/util/wait"
 	"k8s.io/client-go/kubernetes"
-	"k8s.io/client-go/kubernetes/scheme"
 	"k8s.io/client-go/rest"
-	"k8s.io/client-go/tools/remotecommand"
 )
 
 const (
-	timeout            = 90 * time.Second
-	projectIDLabel     = "lagoon.sh/projectId"
-	environmentIDLabel = "lagoon.sh/environmentId"
+	// timeout defines the common timeout for k8s API operations
+	timeout = 90 * time.Second
 )
 
 // Client is a k8s client.
@@ -46,158 +35,3 @@ func NewClient() (*Client, error) {
 		clientset: clientset,
 	}, nil
 }
-
-func intFromLabel(labels map[string]string, label string) (int, error) {
-	var value string
-	var ok bool
-	if value, ok = labels[label]; !ok {
-		return 0, fmt.Errorf("no such label")
-	}
-	return strconv.Atoi(value)
-}
-
-// NamespaceDetails gets the details for a Lagoon namespace.
-// It performs some sanity checks to validate that the namespace is actually a
-// Lagoon namespace.
-func (c *Client) NamespaceDetails(ctx context.Context, name string) (int, int, error) {
-	var pid, eid int
-	ctx, cancel := context.WithTimeout(ctx, timeout)
-	defer cancel()
-	ns, err := c.clientset.CoreV1().Namespaces().Get(ctx, name, metav1.GetOptions{})
-	if err != nil {
-		return 0, 0, fmt.Errorf("couldn't get namespace: %v", err)
-	}
-	if pid, err = intFromLabel(ns.Labels, projectIDLabel); err != nil {
-		return 0, 0, fmt.Errorf("couldn't get project ID from label: %v", err)
-	}
-	if eid, err = intFromLabel(ns.Labels, environmentIDLabel); err != nil {
-		return 0, 0, fmt.Errorf("couldn't get environment ID from label: %v", err)
-	}
-	return pid, eid, nil
-}
-
-func (c *Client) podName(ctx context.Context, deployment,
-	namespace string) (string, error) {
-	d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment,
-		metav1.GetOptions{})
-	if err != nil {
-		return "", err
-	}
-	pods, err := c.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
-		LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
-	})
-	if err != nil {
-		return "", err
-	}
-	if len(pods.Items) == 0 {
-		return "", fmt.Errorf("no pods for deployment: %s", deployment)
-	}
-	return pods.Items[0].Name, nil
-}
-
-func (c *Client) hasRunningPod(ctx context.Context,
-	deployment, namespace string) wait.ConditionWithContextFunc {
-	return func(context.Context) (bool, error) {
-		d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment,
-			metav1.GetOptions{})
-		if err != nil {
-			return false, err
-		}
-		pods, err := c.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
-			LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
-		})
-		if err != nil {
-			return false, err
-		}
-		if len(pods.Items) == 0 {
-			return false, nil
-		}
-		return pods.Items[0].Status.Phase == "Running", nil
-	}
-}
-
-func (c *Client) ensureScaled(ctx context.Context, deployment, namespace string) error {
-	// get current scale
-	s, err := c.clientset.AppsV1().Deployments(namespace).
-		GetScale(ctx, deployment, metav1.GetOptions{})
-	if err != nil {
-		return fmt.Errorf("couldn't get deployment scale: %v", err)
-	}
-	// exit early if no change required
-	if s.Spec.Replicas > 0 {
-		return nil
-	}
-	// scale up the deployment
-	sc := *s
-	sc.Spec.Replicas = 1
-	_, err = c.clientset.AppsV1().Deployments(namespace).
-		UpdateScale(ctx, deployment, &sc, metav1.UpdateOptions{})
-	if err != nil {
-		return fmt.Errorf("couldn't scale deployment: %v", err)
-	}
-	// wait for a pod to start running
-	return wait.PollImmediateWithContext(ctx, time.Second, timeout,
-		c.hasRunningPod(ctx, deployment, namespace))
-}
-
-// getExecutor prepares the environment by ensuring pods are scaled etc. and
-// returns an executor object.
-func (c *Client) getExecutor(ctx context.Context, deployment, namespace string,
-	command []string, stdio io.ReadWriter, stderr io.Writer, tty bool) (remotecommand.Executor, error) {
-	// If there's a tty, then animate a spinner if this function takes too long
-	// to return.
-	// Defer context cancel() after wg.Wait() because we need the context to
-	// cancel first in order to shortcut spinAfter() and avoid a spinner if shell
-	// acquisition is fast enough.
-	ctx, cancel := context.WithTimeout(ctx, timeout)
-	if tty {
-		wg := spinAfter(ctx, stderr, 2*time.Second)
-		defer wg.Wait()
-	}
-	defer cancel()
-	// ensure the deployment has at least one replica
-	if err := c.ensureScaled(ctx, deployment, namespace); err != nil {
-		return nil, fmt.Errorf("couldn't scale deployment: %v", err)
-	}
-	// get the name of the first pod in the deployment
-	podName, err := c.podName(ctx, deployment, namespace)
-	if err != nil {
-		return nil, fmt.Errorf("couldn't get pod name: %v", err)
-	}
-	// check the command. if there isn't one, give the user a shell.
-	if len(command) == 0 {
-		command = []string{"sh"}
-	}
-	// construct the request
-	req := c.clientset.CoreV1().RESTClient().Post().Namespace(namespace).
-		Resource("pods").Name(podName).SubResource("exec")
-	req.VersionedParams(
-		&v1.PodExecOptions{
-			Command: command,
-			Stdin:   true,
-			Stdout:  true,
-			Stderr:  true,
-			TTY:     tty,
-		},
-		scheme.ParameterCodec,
-	)
-	// construct the executor
-	return remotecommand.NewSPDYExecutor(c.config, "POST", req.URL())
-}
-
-// Exec joins the given streams to the command or, if command is empty, to a
-// shell running in the given pod.
-func (c *Client) Exec(ctx context.Context, deployment, namespace string,
-	command []string, stdio io.ReadWriter, stderr io.Writer, tty bool) error {
-	exec, err := c.getExecutor(ctx, deployment, namespace, command, stdio,
-		stderr, tty)
-	if err != nil {
-		return fmt.Errorf("couldn't get executor: %v", err)
-	}
-	// execute the command
-	return exec.Stream(remotecommand.StreamOptions{
-		Stdin:  stdio,
-		Stdout: stdio,
-		Stderr: stderr,
-	})
-}
diff --git a/internal/k8s/exec.go b/internal/k8s/exec.go
new file mode 100644
index 00000000..8021ca5d
--- /dev/null
+++ b/internal/k8s/exec.go
@@ -0,0 +1,155 @@
+package k8s
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"time"
+
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/kubernetes/scheme"
+	"k8s.io/client-go/tools/remotecommand"
+)
+
+// podContainer returns the first pod and first container inside that pod for
+// the given namespace and deployment.
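+// Note that "first" simply means the first pod and container returned by the
+// Kubernetes API when listing pods matching the deployment's label selector;
+// no ordering beyond that is guaranteed.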
+func (c *Client) podContainer(ctx context.Context, namespace,
+	deployment string) (string, string, error) {
+	d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment,
+		metav1.GetOptions{})
+	if err != nil {
+		return "", "", err
+	}
+	pods, err := c.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
+		LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
+	})
+	if err != nil {
+		return "", "", err
+	}
+	if len(pods.Items) == 0 {
+		return "", "", fmt.Errorf("no pods for deployment %s", deployment)
+	}
+	if len(pods.Items[0].Spec.Containers) == 0 {
+		return "", "", fmt.Errorf("no containers for pod %s in deployment %s",
+			pods.Items[0].Name, deployment)
+	}
+	return pods.Items[0].Name, pods.Items[0].Spec.Containers[0].Name, nil
+}
+
+func (c *Client) hasRunningPod(ctx context.Context,
+	namespace, deployment string) wait.ConditionWithContextFunc {
+	return func(context.Context) (bool, error) {
+		d, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, deployment,
+			metav1.GetOptions{})
+		if err != nil {
+			return false, err
+		}
+		pods, err := c.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
+			LabelSelector: labels.FormatLabels(d.Spec.Selector.MatchLabels),
+		})
+		if err != nil {
+			return false, err
+		}
+		if len(pods.Items) == 0 {
+			return false, nil
+		}
+		return pods.Items[0].Status.Phase == "Running", nil
+	}
+}
+
+func (c *Client) ensureScaled(ctx context.Context, namespace, deployment string) error {
+	// get current scale
+	s, err := c.clientset.AppsV1().Deployments(namespace).
+		GetScale(ctx, deployment, metav1.GetOptions{})
+	if err != nil {
+		return fmt.Errorf("couldn't get deployment scale: %v", err)
+	}
+	// exit early if no change required
+	if s.Spec.Replicas > 0 {
+		return nil
+	}
+	// scale up the deployment
+	sc := *s
+	sc.Spec.Replicas = 1
+	_, err = c.clientset.AppsV1().Deployments(namespace).
+		UpdateScale(ctx, deployment, &sc, metav1.UpdateOptions{})
+	if err != nil {
+		return fmt.Errorf("couldn't scale deployment: %v", err)
+	}
+	// wait for a pod to start running
+	return wait.PollImmediateWithContext(ctx, time.Second, timeout,
+		c.hasRunningPod(ctx, namespace, deployment))
+}
+
+// getExecutor prepares the environment by ensuring pods are scaled etc. and
+// returns an executor object.
+func (c *Client) getExecutor(ctx context.Context, namespace, deployment,
+	container string, command []string, stderr io.Writer,
+	tty bool) (remotecommand.Executor, error) {
+	// If there's a tty, then animate a spinner if this function takes too long
+	// to return.
+	// Defer context cancel() after wg.Wait() because we need the context to
+	// cancel first in order to shortcut spinAfter() and avoid a spinner if shell
+	// acquisition is fast enough.
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	if tty {
+		wg := spinAfter(ctx, stderr, 2*time.Second)
+		defer wg.Wait()
+	}
+	defer cancel()
+	// ensure the deployment has at least one replica
+	if err := c.ensureScaled(ctx, namespace, deployment); err != nil {
+		return nil, fmt.Errorf("couldn't scale deployment: %v", err)
+	}
+	// get the name of the first pod and first container
+	firstPod, firstContainer, err := c.podContainer(ctx, namespace, deployment)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't get pod name: %v", err)
+	}
+	// check if we were given a container. If not, use the first container found.
+	if container == "" {
+		container = firstContainer
+	}
+	// check the command. if there isn't one, give the user a shell.
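+	// (sh is presumably preferred here over a specific shell such as bash
+	// because it is the binary most likely to exist in an arbitrary image.)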
+	if len(command) == 0 {
+		command = []string{"sh"}
+	}
+	// construct the request
+	req := c.clientset.CoreV1().RESTClient().Post().Namespace(namespace).
+		Resource("pods").Name(firstPod).SubResource("exec")
+	req.VersionedParams(
+		&v1.PodExecOptions{
+			Stdin:     true,
+			Stdout:    true,
+			Stderr:    true,
+			TTY:       tty,
+			Container: container,
+			Command:   command,
+		},
+		scheme.ParameterCodec,
+	)
+	// construct the executor
+	return remotecommand.NewSPDYExecutor(c.config, "POST", req.URL())
+}
+
+// Exec takes a target namespace, deployment, container, command, and IO
+// streams, and joins the streams to the command or, if the command is empty,
+// to an interactive shell running in a pod inside the deployment.
+func (c *Client) Exec(ctx context.Context, namespace, deployment,
+	container string, command []string, stdio io.ReadWriter, stderr io.Writer,
+	tty bool) error {
+	exec, err := c.getExecutor(ctx, namespace, deployment, container, command,
+		stderr, tty)
+	if err != nil {
+		return fmt.Errorf("couldn't get executor: %v", err)
+	}
+	// execute the command
+	return exec.Stream(remotecommand.StreamOptions{
+		Stdin:  stdio,
+		Stdout: stdio,
+		Stderr: stderr,
+	})
+}
diff --git a/internal/k8s/finddeployment.go b/internal/k8s/finddeployment.go
new file mode 100644
index 00000000..7eb3a966
--- /dev/null
+++ b/internal/k8s/finddeployment.go
@@ -0,0 +1,25 @@
+package k8s
+
+import (
+	"context"
+	"fmt"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+// FindDeployment searches the given namespace for a deployment with a matching
+// lagoon.sh/service= label, and returns the name of that deployment.
+func (c *Client) FindDeployment(ctx context.Context, namespace,
+	service string) (string, error) {
+	deployments, err := c.clientset.AppsV1().Deployments(namespace).
+		List(ctx, metav1.ListOptions{
+			LabelSelector: fmt.Sprintf("lagoon.sh/service=%s", service),
+		})
+	if err != nil {
+		return "", fmt.Errorf("couldn't list deployments: %v", err)
+	}
+	if len(deployments.Items) == 0 {
+		return "", fmt.Errorf("couldn't find deployment for service %s", service)
+	}
+	return deployments.Items[0].Name, nil
+}
diff --git a/internal/k8s/namespacedetails.go b/internal/k8s/namespacedetails.go
new file mode 100644
index 00000000..c82fc83c
--- /dev/null
+++ b/internal/k8s/namespacedetails.go
@@ -0,0 +1,43 @@
+package k8s
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+const (
+	projectIDLabel     = "lagoon.sh/projectId"
+	environmentIDLabel = "lagoon.sh/environmentId"
+)
+
+func intFromLabel(labels map[string]string, label string) (int, error) {
+	var value string
+	var ok bool
+	if value, ok = labels[label]; !ok {
+		return 0, fmt.Errorf("no such label")
+	}
+	return strconv.Atoi(value)
+}
+
+// NamespaceDetails gets the details for a Lagoon namespace.
+// It performs some sanity checks to validate that the namespace is actually a
+// Lagoon namespace.
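+// It returns the project ID and environment ID, read from the
+// lagoon.sh/projectId and lagoon.sh/environmentId labels on the namespace.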
+func (c *Client) NamespaceDetails(ctx context.Context, name string) (int, int, error) {
+	var pid, eid int
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	ns, err := c.clientset.CoreV1().Namespaces().Get(ctx, name, metav1.GetOptions{})
+	if err != nil {
+		return 0, 0, fmt.Errorf("couldn't get namespace: %v", err)
+	}
+	if pid, err = intFromLabel(ns.Labels, projectIDLabel); err != nil {
+		return 0, 0, fmt.Errorf("couldn't get project ID from label: %v", err)
+	}
+	if eid, err = intFromLabel(ns.Labels, environmentIDLabel); err != nil {
+		return 0, 0, fmt.Errorf("couldn't get environment ID from label: %v", err)
+	}
+	return pid, eid, nil
+}
diff --git a/internal/k8s/validate.go b/internal/k8s/validate.go
new file mode 100644
index 00000000..08a938ba
--- /dev/null
+++ b/internal/k8s/validate.go
@@ -0,0 +1,17 @@
+package k8s
+
+import (
+	"fmt"
+
+	"k8s.io/apimachinery/pkg/util/validation"
+)
+
+// ValidateLabelValue validates the given string to determine if it is a valid
+// kubernetes label value.
+func ValidateLabelValue(s string) error {
+	errs := validation.IsValidLabelValue(s)
+	if len(errs) > 0 {
+		return fmt.Errorf("invalid label value: %v", errs)
+	}
+	return nil
+}
diff --git a/internal/sshserver/connectionparams.go b/internal/sshserver/connectionparams.go
new file mode 100644
index 00000000..31ab47ae
--- /dev/null
+++ b/internal/sshserver/connectionparams.go
@@ -0,0 +1,51 @@
+package sshserver
+
+import "regexp"
+
+var (
+	serviceRegex   = regexp.MustCompile(`^service=(.+)$`)
+	containerRegex = regexp.MustCompile(`^container=(.+)$`)
+)
+
+// parseConnectionParams takes the raw SSH command and parses out any
+// leading service=... and container=... arguments. It returns:
+// * If a service=... argument is given, the value of that argument. If no such
+// argument is given, it falls back to a default of "cli".
+// * If a container=... argument is given, the value of that argument. If no
+// such argument is given, it returns an empty string.
+// * The remaining arguments with any leading service= or container= arguments
+// removed.
+//
+// Notes about the logic implemented here:
+// * container=... may not be specified without service=...
+// * service=... must be given as the first argument to be recognised.
+// * If not given in the expected order or with empty values, these arguments
+// will be interpreted as regular command-line arguments.
+//
+// In manpage syntax:
+//
+//	[service=... [container=...]] CMD...
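+//
+// For example, ["service=nginx", "container=php", "drush", "status"] yields
+// ("nginx", "php", ["drush", "status"]), while ["drush", "status"] alone
+// yields ("cli", "", ["drush", "status"]).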
+//
+func parseConnectionParams(args []string) (string, string, []string) {
+	// exit early if we have no args
+	if len(args) == 0 {
+		return "cli", "", args
+	}
+	// check for service argument
+	serviceMatches := serviceRegex.FindStringSubmatch(args[0])
+	if len(serviceMatches) == 0 {
+		return "cli", "", args
+	}
+	service := serviceMatches[1]
+	// exit early if we are out of arguments
+	if len(args) < 2 {
+		return service, "", args[1:]
+	}
+	// check for container argument
+	containerMatches := containerRegex.FindStringSubmatch(args[1])
+	if len(containerMatches) == 0 {
+		return service, "", args[1:]
+	}
+	container := containerMatches[1]
+	return service, container, args[2:]
+}
diff --git a/internal/sshserver/connectionparams_test.go b/internal/sshserver/connectionparams_test.go
new file mode 100644
index 00000000..a0832d9c
--- /dev/null
+++ b/internal/sshserver/connectionparams_test.go
@@ -0,0 +1,68 @@
+package sshserver_test
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/uselagoon/ssh-portal/internal/sshserver"
+)
+
+type parsedParams struct {
+	service   string
+	container string
+	args      []string
+}
+
+func TestParseConnectionParams(t *testing.T) {
+	var testCases = map[string]struct {
+		input  []string
+		expect parsedParams
+	}{
+		"no special args": {
+			input: []string{"drush", "do", "something"},
+			expect: parsedParams{
+				service:   "cli",
+				container: "",
+				args:      []string{"drush", "do", "something"},
+			},
+		},
+		"service arg": {
+			input: []string{"service=mongo", "drush", "do", "something"},
+			expect: parsedParams{
+				service:   "mongo",
+				container: "",
+				args:      []string{"drush", "do", "something"},
+			},
+		},
+		"service and container args": {
+			input: []string{"service=nginx", "container=php", "drush", "do", "something"},
+			expect: parsedParams{
+				service:   "nginx",
+				container: "php",
+				args:      []string{"drush", "do", "something"},
+			},
+		},
+		"invalid order": {
+			input: []string{"container=php", "service=nginx", "drush", "do", "something"},
+			expect: parsedParams{
+				service:   "cli",
+				container: "",
+				args:      []string{"container=php", "service=nginx", "drush", "do", "something"},
+			},
+		},
+	}
+	for name, tc := range testCases {
+		t.Run(name, func(tt *testing.T) {
+			service, container, args := sshserver.ParseConnectionParams(tc.input)
+			if tc.expect.service != service {
+				tt.Fatalf("service: expected %v, got %v", tc.expect.service, service)
+			}
+			if tc.expect.container != container {
+				tt.Fatalf("container: expected %v, got %v", tc.expect.container, container)
+			}
+			if !reflect.DeepEqual(tc.expect.args, args) {
+				tt.Fatalf("args: expected %v, got %v", tc.expect.args, args)
+			}
+		})
+	}
+}
diff --git a/internal/sshserver/helper_test.go b/internal/sshserver/helper_test.go
new file mode 100644
index 00000000..471b2a4f
--- /dev/null
+++ b/internal/sshserver/helper_test.go
@@ -0,0 +1,7 @@
+package sshserver
+
+// ParseConnectionParams exposes the private parseConnectionParams for testing
+// only.
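+// Since this file has a _test.go suffix and is in package sshserver rather
+// than sshserver_test, this export is only compiled into test binaries and
+// does not become part of the package's public API.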
+func ParseConnectionParams(args []string) (string, string, []string) {
+	return parseConnectionParams(args)
+}
diff --git a/internal/sshserver/sessionhandler.go b/internal/sshserver/sessionhandler.go
index 1115ca4b..fc05afb8 100644
--- a/internal/sshserver/sessionhandler.go
+++ b/internal/sshserver/sessionhandler.go
@@ -25,18 +25,76 @@ func sessionHandler(log *zap.Logger, c *k8s.Client) ssh.Handler {
 			log.Warn("couldn't get session ID")
 			return
 		}
-		// check if a pty is required
+		// check if a pty was requested
 		_, _, pty := s.Pty()
 		// start the command
 		log.Debug("starting command exec",
-			zap.String("session-id", sid))
-		// TODO: handle the custom command parameters such as service=...
-		err := c.Exec(s.Context(), "cli", s.User(), s.Command(), s, s.Stderr(), pty)
+			zap.String("session-id", sid),
+			zap.Strings("raw command", s.Command()),
+		)
+		// parse the command line arguments to extract any service or container args
+		service, container, cmd := parseConnectionParams(s.Command())
+		// validate the service and container
+		if err := k8s.ValidateLabelValue(service); err != nil {
+			log.Debug("invalid service name",
+				zap.String("service", service),
+				zap.String("session-id", sid),
+				zap.Error(err))
+			_, err = fmt.Fprintf(s.Stderr(), "invalid service name %s. SID: %s\r\n",
+				service, sid)
+			if err != nil {
+				log.Debug("couldn't write to session stream",
+					zap.String("session-id", sid),
+					zap.Error(err))
+			}
+			return
+		}
+		if err := k8s.ValidateLabelValue(container); err != nil {
+			log.Debug("invalid container name",
+				zap.String("container", container),
+				zap.String("session-id", sid),
+				zap.Error(err))
+			_, err = fmt.Fprintf(s.Stderr(), "invalid container name %s. SID: %s\r\n",
+				container, sid)
+			if err != nil {
+				log.Debug("couldn't write to session stream",
+					zap.String("session-id", sid),
+					zap.Error(err))
+			}
+			return
+		}
+		// find the deployment name based on the given service name
+		deployment, err := c.FindDeployment(s.Context(), s.User(), service)
+		if err != nil {
+			log.Debug("couldn't find deployment for service",
+				zap.String("service", service),
+				zap.String("session-id", sid),
+				zap.Error(err))
+			_, err = fmt.Fprintf(s.Stderr(), "unknown service %s. SID: %s\r\n",
+				service, sid)
+			if err != nil {
+				log.Debug("couldn't write to session stream",
+					zap.String("session-id", sid),
+					zap.Error(err))
+			}
+			return
+		}
+		log.Info("executing command",
+			zap.String("namespace", s.User()),
+			zap.String("deployment", deployment),
+			zap.String("container", container),
+			zap.Strings("command", cmd),
+			zap.Bool("pty", pty),
+			zap.String("session-id", sid),
+		)
+		err = c.Exec(s.Context(), s.User(), deployment, container, cmd, s,
+			s.Stderr(), pty)
 		if err != nil {
 			log.Warn("couldn't execute command",
 				zap.String("session-id", sid),
 				zap.Error(err))
-			_, err = fmt.Fprintf(s, "couldn't execute command. SID: %s\n", sid)
+			_, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n",
+				sid)
 			if err != nil {
 				log.Warn("couldn't send error to client",
 					zap.String("session-id", sid),