Skip to content

Commit f5c8c3f

Browse files
authored
Merge pull request #470 from uselagoon/nats-update
feat: avoid deprecated NATS API
2 parents 11395a4 + bb058b4 commit f5c8c3f

File tree

6 files changed

+83
-27
lines changed

6 files changed

+83
-27
lines changed

cmd/ssh-portal/serve.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
3636
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
3737
defer stop()
3838
// get nats server connection
39-
nconn, err := nats.Connect(cmd.NATSServer,
39+
nc, err := nats.Connect(cmd.NATSServer,
4040
nats.Name("ssh-portal"),
4141
// exit on connection close
4242
nats.ClosedHandler(func(_ *nats.Conn) {
@@ -52,10 +52,6 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
5252
if err != nil {
5353
return fmt.Errorf("couldn't connect to NATS server: %v", err)
5454
}
55-
nc, err := nats.NewEncodedConn(nconn, "json")
56-
if err != nil {
57-
return fmt.Errorf("couldn't get encoded conn: %v", err)
58-
}
5955
defer nc.Close()
6056
// start listening on TCP port
6157
l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort))
@@ -83,7 +79,15 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
8379
eg.Go(func() error {
8480
// start serving SSH connection requests
8581
return sshserver.Serve(
86-
ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled, cmd.Banner)
82+
ctx,
83+
log,
84+
nc,
85+
l,
86+
c,
87+
hostkeys,
88+
cmd.LogAccessEnabled,
89+
cmd.Banner,
90+
)
8791
})
8892
return eg.Wait()
8993
}

internal/sshportalapi/server.go

+7-8
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func ServeNATS(
5050
wg := sync.WaitGroup{}
5151
wg.Add(1)
5252
// connect to NATS server
53-
nconn, err := nats.Connect(natsURL,
53+
nc, err := nats.Connect(natsURL,
5454
nats.Name("ssh-portal-api"),
5555
// synchronise exiting ServeNATS()
5656
nats.ClosedHandler(func(_ *nats.Conn) {
@@ -67,14 +67,13 @@ func ServeNATS(
6767
if err != nil {
6868
return fmt.Errorf("couldn't connect to NATS server: %v", err)
6969
}
70-
nc, err := nats.NewEncodedConn(nconn, "json")
71-
if err != nil {
72-
return fmt.Errorf("couldn't get encoded conn: %v", err)
73-
}
7470
defer nc.Close()
75-
// set up request/response callback for sshportal
76-
_, err = nc.QueueSubscribe(bus.SubjectSSHAccessQuery, queue,
77-
sshportal(ctx, log, nc, p, l, k))
71+
// configure callback
72+
_, err = nc.QueueSubscribe(
73+
bus.SubjectSSHAccessQuery,
74+
queue,
75+
sshportal(ctx, log, nc, p, l, k),
76+
)
7877
if err != nil {
7978
return fmt.Errorf("couldn't subscribe to queue: %v", err)
8079
}

internal/sshportalapi/sshportal.go

+21-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sshportalapi
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"log/slog"
78
"time"
@@ -23,20 +24,30 @@ var (
2324
})
2425
)
2526

27+
var (
28+
falseResponse = []byte(`false`)
29+
trueResponse = []byte(`true`)
30+
)
31+
2632
func sshportal(
2733
ctx context.Context,
2834
log *slog.Logger,
29-
c *nats.EncodedConn,
35+
c *nats.Conn,
3036
p *rbac.Permission,
3137
l LagoonDBService,
3238
k KeycloakService,
33-
) nats.Handler {
34-
return func(_, replySubject string, query *bus.SSHAccessQuery) {
39+
) nats.MsgHandler {
40+
return func(msg *nats.Msg) {
3541
var realmRoles, userGroups []string
3642
// set up tracing and update metrics
3743
ctx, span := otel.Tracer(pkgName).Start(ctx, bus.SubjectSSHAccessQuery)
3844
defer span.End()
3945
requestsCounter.Inc()
46+
var query bus.SSHAccessQuery
47+
if err := json.Unmarshal(msg.Data, &query); err != nil {
48+
log.Warn("couldn't unmarshal query", slog.Any("query", msg.Data))
49+
return
50+
}
4051
log := log.With(slog.Any("query", query))
4152
// sanity check the query
4253
if query.SSHFingerprint == "" || query.NamespaceName == "" {
@@ -48,7 +59,7 @@ func sshportal(
4859
if err != nil {
4960
if errors.Is(err, lagoondb.ErrNoResult) {
5061
log.Warn("unknown namespace name", slog.Any("error", err))
51-
if err = c.Publish(replySubject, false); err != nil {
62+
if err = c.Publish(msg.Reply, falseResponse); err != nil {
5263
log.Error("couldn't publish reply", slog.Any("error", err))
5364
}
5465
return
@@ -65,7 +76,7 @@ func sshportal(
6576
log.Warn("ID mismatch in environment identification",
6677
slog.Any("env", env),
6778
slog.Any("error", err))
68-
if err = c.Publish(replySubject, false); err != nil {
79+
if err = c.Publish(msg.Reply, falseResponse); err != nil {
6980
log.Error("couldn't publish reply", slog.Any("error", err))
7081
}
7182
return
@@ -75,7 +86,7 @@ func sshportal(
7586
if err != nil {
7687
if errors.Is(err, lagoondb.ErrNoResult) {
7788
log.Debug("unknown SSH Fingerprint", slog.Any("error", err))
78-
if err = c.Publish(replySubject, false); err != nil {
89+
if err = c.Publish(msg.Reply, falseResponse); err != nil {
7990
log.Error("couldn't publish reply", slog.Any("error", err))
8091
}
8192
return
@@ -115,10 +126,13 @@ func sshportal(
115126
ok := p.UserCanSSHToEnvironment(
116127
ctx, env, realmRoles, userGroups, groupNameProjectIDsMap)
117128
var logMsg string
129+
var response []byte
118130
if ok {
119131
logMsg = "SSH access authorized"
132+
response = trueResponse
120133
} else {
121134
logMsg = "SSH access not authorized"
135+
response = falseResponse
122136
}
123137
log.Info(logMsg,
124138
slog.Int("environmentID", env.ID),
@@ -127,7 +141,7 @@ func sshportal(
127141
slog.String("projectName", env.ProjectName),
128142
slog.String("userUUID", user.UUID.String()),
129143
)
130-
if err = c.Publish(replySubject, ok); err != nil {
144+
if err = c.Publish(msg.Reply, response); err != nil {
131145
log.Error("couldn't publish reply",
132146
slog.String("userUUID", user.UUID.String()),
133147
slog.Any("error", err))
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package sshportalapi
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
)
7+
8+
func TestResponseMarshal(t *testing.T) {
9+
var testCases = map[string]struct {
10+
input []byte
11+
expect bool
12+
}{
13+
"true": {input: trueResponse, expect: true},
14+
"false": {input: falseResponse, expect: false},
15+
}
16+
for name, tc := range testCases {
17+
t.Run(name, func(tt *testing.T) {
18+
var value bool
19+
if err := json.Unmarshal(tc.input, &value); err != nil {
20+
tt.Fatalf("error unmarshaling data %v to bool", tc.input)
21+
}
22+
if value != tc.expect {
23+
tt.Fatalf("expected %v, got %v", tc.expect, value)
24+
}
25+
})
26+
}
27+
}

internal/sshserver/authhandler.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sshserver
22

33
import (
4+
"encoding/json"
45
"log/slog"
56
"time"
67

@@ -40,8 +41,11 @@ var (
4041

4142
// pubKeyAuth returns a ssh.PublicKeyHandler which queries the remote
4243
// ssh-portal-api for Lagoon SSH authorization.
43-
func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
44-
c *k8s.Client) ssh.PublicKeyHandler {
44+
func pubKeyAuth(
45+
log *slog.Logger,
46+
nc *nats.Conn,
47+
c *k8s.Client,
48+
) ssh.PublicKeyHandler {
4549
return func(ctx ssh.Context, key ssh.PublicKey) bool {
4650
authAttemptsTotal.Inc()
4751
log := log.With(slog.String("sessionID", ctx.SessionID()))
@@ -60,21 +64,29 @@ func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
6064
}
6165
// construct ssh access query
6266
fingerprint := gossh.FingerprintSHA256(pubKey)
63-
q := bus.SSHAccessQuery{
67+
queryData, err := json.Marshal(bus.SSHAccessQuery{
6468
SSHFingerprint: fingerprint,
6569
NamespaceName: ctx.User(),
6670
ProjectID: pid,
6771
EnvironmentID: eid,
6872
SessionID: ctx.SessionID(),
73+
})
74+
if err != nil {
75+
log.Warn("couldn't marshal NATS request", slog.Any("error", err))
76+
return false
6977
}
7078
// send query
71-
var ok bool
72-
err = nc.Request(bus.SubjectSSHAccessQuery, q, &ok, natsTimeout)
79+
msg, err := nc.Request(bus.SubjectSSHAccessQuery, queryData, natsTimeout)
7380
if err != nil {
7481
log.Warn("couldn't make NATS request", slog.Any("error", err))
7582
return false
7683
}
7784
// handle response
85+
var ok bool
86+
if err := json.Unmarshal(msg.Data, &ok); err != nil {
87+
log.Warn("couldn't unmarshal response", slog.Any("response", msg.Data))
88+
return false
89+
}
7890
if !ok {
7991
log.Debug("SSH access not authorized",
8092
slog.String("fingerprint", fingerprint),

internal/sshserver/serve.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig {
4040
func Serve(
4141
ctx context.Context,
4242
log *slog.Logger,
43-
nc *nats.EncodedConn,
43+
nc *nats.Conn,
4444
l net.Listener,
4545
c *k8s.Client,
4646
hostKeys [][]byte,

0 commit comments

Comments
 (0)