Skip to content

Commit bc37424

Browse files
committed
Updates
1 parent dbcfd03 commit bc37424

File tree

3 files changed

+58
-117
lines changed

3 files changed

+58
-117
lines changed

pkg/http/sse.go

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,6 @@ func (h sseHandler) handleEventStream(w http.ResponseWriter, r *http.Request) {
167167
return
168168
}
169169

170-
// Disable proxy buffering for stream responses
171-
w.Header().Set("X-Accel-Buffering", "no")
172-
173170
responsesChan := make(chan interface{})
174171

175172
go func() {
@@ -209,72 +206,8 @@ func (h sseHandler) handleEventStream(w http.ResponseWriter, r *http.Request) {
209206
}
210207
}()
211208

212-
responder := DefaultSSEResponder{
209+
responder := sseResponder{
213210
marshaler: h.config.Marshaler,
214211
}
215212
responder.Respond(w, r, responsesChan)
216213
}
217-
218-
type OneModelResponseFunc[T any] func(r *http.Request) (response T, err error)
219-
type EventsResponseFunc[T any] func(r *http.Request, msg *message.Message) (response T, err error)
220-
type ValidFunc func(r *http.Request, msg *message.Message) (ok bool)
221-
222-
type OneModelStreamAdapter[T any] struct {
223-
responseFunc OneModelResponseFunc[T]
224-
}
225-
226-
func NewOneModelStreamAdapter[T any](
227-
responseFunc OneModelResponseFunc[T],
228-
) OneModelStreamAdapter[T] {
229-
return OneModelStreamAdapter[T]{
230-
responseFunc: responseFunc,
231-
}
232-
}
233-
234-
func (a OneModelStreamAdapter[T]) InitialStreamResponse(w http.ResponseWriter, r *http.Request) (response T, ok bool) {
235-
resp, err := a.responseFunc(r)
236-
if err != nil {
237-
w.WriteHeader(http.StatusInternalServerError)
238-
var empty T
239-
return empty, false
240-
}
241-
242-
return resp, true
243-
}
244-
245-
func (a OneModelStreamAdapter[T]) NextStreamResponse(r *http.Request, msg *message.Message) (response T, ok bool) {
246-
resp, err := a.responseFunc(r)
247-
if err != nil {
248-
var empty T
249-
return empty, false
250-
}
251-
252-
return resp, true
253-
}
254-
255-
type EventsStreamAdapter[T any] struct {
256-
responseFunc EventsResponseFunc[T]
257-
}
258-
259-
func NewEventsStreamAdapter[T any](
260-
responseFunc EventsResponseFunc[T],
261-
) EventsStreamAdapter[T] {
262-
return EventsStreamAdapter[T]{
263-
responseFunc: responseFunc,
264-
}
265-
}
266-
267-
func (a EventsStreamAdapter[T]) InitialStreamResponse(w http.ResponseWriter, r *http.Request) (response T, ok bool) {
268-
var empty T
269-
return empty, true
270-
}
271-
272-
func (a EventsStreamAdapter[T]) NextStreamResponse(r *http.Request, msg *message.Message) (response T, ok bool) {
273-
resp, err := a.responseFunc(r, msg)
274-
if err != nil {
275-
var empty T
276-
return empty, false
277-
}
278-
279-
return resp, true
280-
}

pkg/http/sse_marshaler.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package http
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
)
8+
9+
type ServerSentEvent struct {
10+
Event string
11+
Data []byte
12+
}
13+
14+
type SSEMarshaler interface {
15+
Marshal(ctx context.Context, payload any) (ServerSentEvent, error)
16+
}
17+
18+
type JSONSSEMarshaler struct{}
19+
20+
func (j JSONSSEMarshaler) Marshal(ctx context.Context, payload any) (ServerSentEvent, error) {
21+
data, err := json.Marshal(payload)
22+
if err != nil {
23+
return ServerSentEvent{}, err
24+
}
25+
26+
return ServerSentEvent{
27+
Event: "data",
28+
Data: data,
29+
}, nil
30+
}
31+
32+
type StringSSEMarshaler struct{}
33+
34+
func (b StringSSEMarshaler) Marshal(ctx context.Context, payload any) (ServerSentEvent, error) {
35+
data := fmt.Sprint(payload)
36+
37+
return ServerSentEvent{
38+
Event: "data",
39+
Data: []byte(data),
40+
}, nil
41+
}

pkg/http/responder.go renamed to pkg/http/sse_responder.go

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package http
22

33
import (
4-
"context"
5-
"encoding/json"
64
"fmt"
75
"net/http"
86
"reflect"
@@ -11,59 +9,23 @@ import (
119
"github.com/go-chi/render"
1210
)
1311

14-
// Respond handles streaming JSON and XML responses, automatically setting the
15-
// Content-Type based on request headers. It will default to a JSON response.
16-
17-
type ServerSentEvent struct {
18-
Event string
19-
Data []byte
20-
}
21-
22-
type SSEMarshaler interface {
23-
Marshal(ctx context.Context, payload any) (ServerSentEvent, error)
24-
}
25-
26-
type JSONSSEMarshaler struct{}
27-
28-
func (j JSONSSEMarshaler) Marshal(ctx context.Context, payload any) (ServerSentEvent, error) {
29-
data, err := json.Marshal(payload)
30-
if err != nil {
31-
return ServerSentEvent{}, err
32-
}
33-
34-
return ServerSentEvent{
35-
Event: "data",
36-
Data: data,
37-
}, nil
38-
}
39-
40-
type BytesSSEMarshaler struct{}
41-
42-
func (b BytesSSEMarshaler) Marshal(ctx context.Context, payload any) (ServerSentEvent, error) {
43-
payloadStr := fmt.Sprint(payload)
44-
45-
data := strings.Join(strings.Split(payloadStr, "\n"), "\ndata: ")
46-
47-
return ServerSentEvent{
48-
Event: "data",
49-
Data: []byte(data),
50-
}, nil
51-
}
52-
53-
type DefaultSSEResponder struct {
12+
type sseResponder struct {
5413
marshaler SSEMarshaler
5514
}
5615

57-
func (d DefaultSSEResponder) Respond(w http.ResponseWriter, r *http.Request, v interface{}) {
16+
// Respond handles streaming JSON and XML responses, automatically setting the
17+
// Content-Type based on request headers.
18+
// Based on go-chi/render.
19+
func (s sseResponder) Respond(w http.ResponseWriter, r *http.Request, v interface{}) {
5820
if v != nil {
5921
switch reflect.TypeOf(v).Kind() {
6022
case reflect.Chan:
6123
switch render.GetAcceptedContentType(r) {
6224
case render.ContentTypeEventStream:
63-
d.channelEventStream(w, r, v)
25+
s.channelEventStream(w, r, v)
6426
return
6527
default:
66-
v = d.channelIntoSlice(w, r, v)
28+
v = s.channelIntoSlice(w, r, v)
6729
}
6830
}
6931
}
@@ -79,14 +41,17 @@ func (d DefaultSSEResponder) Respond(w http.ResponseWriter, r *http.Request, v i
7941
}
8042
}
8143

82-
func (d DefaultSSEResponder) channelEventStream(w http.ResponseWriter, r *http.Request, v interface{}) {
44+
func (s sseResponder) channelEventStream(w http.ResponseWriter, r *http.Request, v interface{}) {
8345
if reflect.TypeOf(v).Kind() != reflect.Chan {
8446
panic(fmt.Sprintf("render: event stream expects a channel, not %v", reflect.TypeOf(v).Kind()))
8547
}
8648

8749
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
8850
w.Header().Set("Cache-Control", "no-cache")
8951

52+
// Disable proxy buffering for stream responses
53+
w.Header().Set("X-Accel-Buffering", "no")
54+
9055
if r.ProtoMajor == 1 {
9156
// An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields.
9257
// Source: RFC7540
@@ -125,7 +90,7 @@ func (d DefaultSSEResponder) channelEventStream(w http.ResponseWriter, r *http.R
12590
event, ok := v.(ServerSentEvent)
12691
if !ok {
12792
var err error
128-
event, err = d.marshaler.Marshal(ctx, v)
93+
event, err = s.marshaler.Marshal(ctx, v)
12994
if err != nil {
13095
_, _ = w.Write([]byte(fmt.Sprintf("event: error\ndata: {\"error\":\"%v\"}\n\n", err)))
13196
if f, ok := w.(http.Flusher); ok {
@@ -135,7 +100,9 @@ func (d DefaultSSEResponder) channelEventStream(w http.ResponseWriter, r *http.R
135100
}
136101
}
137102

138-
_, _ = w.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", event.Event, event.Data)))
103+
data := strings.Join(strings.Split(string(event.Data), "\n"), "\ndata: ")
104+
105+
_, _ = w.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", event.Event, data)))
139106
if f, ok := w.(http.Flusher); ok {
140107
f.Flush()
141108
}
@@ -144,7 +111,7 @@ func (d DefaultSSEResponder) channelEventStream(w http.ResponseWriter, r *http.R
144111
}
145112

146113
// channelIntoSlice buffers channel data into a slice.
147-
func (d DefaultSSEResponder) channelIntoSlice(w http.ResponseWriter, r *http.Request, from interface{}) interface{} {
114+
func (s sseResponder) channelIntoSlice(w http.ResponseWriter, r *http.Request, from interface{}) interface{} {
148115
ctx := r.Context()
149116

150117
var to []interface{}

0 commit comments

Comments
 (0)