Skip to content

Commit 03bd46c

Browse files
authored
Merge pull request #14 from thejoeejoee/feat-sub-http-response-status
[watermill-http] Custom HTTP response status
2 parents 92154bf + 0103e46 commit 03bd46c

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

pkg/http/context.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package http
2+
3+
import (
4+
"context"
5+
"github.com/ThreeDotsLabs/watermill/message"
6+
)
7+
8+
// ctxResponseStatusCodeKey is a context key for the http status code in the message context
9+
type ctxResponseStatusCodeKey struct{}
10+
11+
// StatusCodeFromContext returns the status code from the context.
12+
func StatusCodeFromContext(ctx context.Context, otherwise int) int {
13+
if v := ctx.Value(ctxResponseStatusCodeKey{}); v != nil {
14+
if code, ok := v.(int); ok {
15+
return code
16+
}
17+
}
18+
return otherwise
19+
}
20+
21+
// WithResponseStatusCode returns a new context with the status code.
22+
func WithResponseStatusCode(ctx context.Context, code int) context.Context {
23+
return context.WithValue(ctx, ctxResponseStatusCodeKey{}, code)
24+
}
25+
26+
// SetResponseStatusCode sets a http status code to the given message.
27+
func SetResponseStatusCode(m *message.Message, code int) *message.Message {
28+
m.SetContext(WithResponseStatusCode(m.Context(), code))
29+
return m
30+
}

pkg/http/pubsub_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,44 @@ func TestHttpPubSub(t *testing.T) {
9595
})
9696
}
9797

98+
func TestHttpSubStatusCode(t *testing.T) {
99+
pub, sub := createPubSub(t)
100+
101+
defer func() {
102+
require.NoError(t, pub.Close())
103+
require.NoError(t, sub.Close())
104+
}()
105+
106+
msgs, err := sub.Subscribe(context.Background(), "/test")
107+
require.NoError(t, err)
108+
109+
go func() {
110+
_ = sub.StartHTTPServer()
111+
}()
112+
113+
waitForHTTP(t, sub, time.Second*10)
114+
115+
t.Run("response with custom http status code", func(t *testing.T) {
116+
go func() {
117+
select {
118+
case <-time.After(time.Second * 10):
119+
return
120+
case msg := <-msgs:
121+
http.SetResponseStatusCode(msg, nethttp.StatusForbidden)
122+
msg.Nack()
123+
}
124+
}()
125+
126+
req, err := nethttp.NewRequest(nethttp.MethodPost, fmt.Sprintf("http://%s/test", sub.Addr()), nil)
127+
require.NoError(t, err)
128+
129+
resp, err := nethttp.DefaultClient.Do(req)
130+
require.NoError(t, err)
131+
132+
require.Equal(t, nethttp.StatusForbidden, resp.StatusCode)
133+
})
134+
}
135+
98136
func waitForHTTP(t *testing.T, sub *http.Subscriber, timeoutTime time.Duration) {
99137
timeout := time.After(timeoutTime)
100138
for {

pkg/http/subscriber.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,17 @@ func (s *Subscriber) Subscribe(ctx context.Context, url string) (<-chan *message
139139
s.logger.Trace("Waiting for ACK", logFields)
140140
select {
141141
case <-msg.Acked():
142-
s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err}))
143-
w.WriteHeader(http.StatusOK)
142+
code := StatusCodeFromContext(msg.Context(), http.StatusOK)
143+
s.logger.Trace("Message acknowledged", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code}))
144+
w.WriteHeader(code)
144145
case <-msg.Nacked():
145-
s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err}))
146-
w.WriteHeader(http.StatusInternalServerError)
146+
code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError)
147+
s.logger.Trace("Message nacked", logFields.Add(watermill.LogFields{"err": err, "http_status_code": code}))
148+
w.WriteHeader(code)
147149
case <-r.Context().Done():
148-
s.logger.Info("Request stopped without ACK received", logFields)
149-
w.WriteHeader(http.StatusInternalServerError)
150+
code := StatusCodeFromContext(msg.Context(), http.StatusInternalServerError)
151+
s.logger.Info("Request stopped without ACK received", logFields.Add(watermill.LogFields{"http_status_code": code}))
152+
w.WriteHeader(code)
150153
}
151154
})
152155

0 commit comments

Comments
 (0)