Skip to content

Commit 620373b

Browse files
committed
fix: Resolve race condition in HTTP server handling
1 parent 55b3812 commit 620373b

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

xhttp/server.go

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Package xhttp implements http helpers.
2+
package xhttp
3+
4+
import (
5+
"context"
6+
"errors"
7+
"log"
8+
"net"
9+
"net/http"
10+
"sync"
11+
"sync/atomic"
12+
"time"
13+
14+
"oss.terrastruct.com/util-go/xcontext"
15+
)
16+
17+
func NewServer(log *log.Logger, h http.Handler) *http.Server {
18+
return &http.Server{
19+
MaxHeaderBytes: 1 << 18, // 262,144B
20+
ReadTimeout: time.Minute,
21+
WriteTimeout: time.Minute,
22+
IdleTimeout: time.Hour,
23+
ErrorLog: log,
24+
Handler: http.MaxBytesHandler(h, 1<<20), // 1,048,576B
25+
}
26+
}
27+
28+
type safeServer struct {
29+
*http.Server
30+
running int32
31+
mu sync.Mutex
32+
}
33+
34+
func newSafeServer(s *http.Server) *safeServer {
35+
return &safeServer{
36+
Server: s,
37+
}
38+
}
39+
40+
func (s *safeServer) ListenAndServe(l net.Listener) error {
41+
s.mu.Lock()
42+
defer s.mu.Unlock()
43+
44+
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
45+
return errors.New("server is already running")
46+
}
47+
defer atomic.StoreInt32(&s.running, 0)
48+
49+
return s.Serve(l)
50+
}
51+
52+
func (s *safeServer) Shutdown(ctx context.Context) error {
53+
s.mu.Lock()
54+
defer s.mu.Unlock()
55+
56+
if atomic.LoadInt32(&s.running) == 0 {
57+
return nil
58+
}
59+
60+
return s.Server.Shutdown(ctx)
61+
}
62+
63+
func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error {
64+
s.BaseContext = func(net.Listener) context.Context {
65+
return ctx
66+
}
67+
68+
ss := newSafeServer(s)
69+
70+
serverClosed := make(chan struct{})
71+
var serverError error
72+
go func() {
73+
serverError = ss.ListenAndServe(l)
74+
close(serverClosed)
75+
}()
76+
77+
select {
78+
case <-serverClosed:
79+
return serverError
80+
case <-ctx.Done():
81+
shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout)
82+
defer cancel()
83+
84+
err := ss.Shutdown(shutdownCtx)
85+
<-serverClosed // Wait for server to exit
86+
if err != nil {
87+
return err
88+
}
89+
return serverError
90+
}
91+
92+
}

0 commit comments

Comments
 (0)