Skip to content

Commit d18203f

Browse files
committed
fix: Resolve race condition in HTTP server handling
1 parent 55b3812 commit d18203f

File tree

1 file changed

+55
-7
lines changed

1 file changed

+55
-7
lines changed

xhttp/serve.go

+55-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package xhttp
33

44
import (
55
"context"
6+
"errors"
67
"log"
78
"net"
89
"net/http"
10+
"sync"
11+
"sync/atomic"
912
"time"
1013

1114
"oss.terrastruct.com/util-go/xcontext"
@@ -22,23 +25,68 @@ func NewServer(log *log.Logger, h http.Handler) *http.Server {
2225
}
2326
}
2427

28+
type safeServer struct {
29+
*http.Server
30+
running int32
31+
mu sync.Mutex
32+
}
33+
34+
func newSafeServer(s *http.Server) *safeServer {
35+
return &safeServer{
36+
Server: s,
37+
}
38+
}
39+
40+
func (s *safeServer) ListenAndServe(l net.Listener) error {
41+
s.mu.Lock()
42+
defer s.mu.Unlock()
43+
44+
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
45+
return errors.New("server is already running")
46+
}
47+
defer atomic.StoreInt32(&s.running, 0)
48+
49+
return s.Serve(l)
50+
}
51+
52+
func (s *safeServer) Shutdown(ctx context.Context) error {
53+
s.mu.Lock()
54+
defer s.mu.Unlock()
55+
56+
if atomic.LoadInt32(&s.running) == 0 {
57+
return nil
58+
}
59+
60+
return s.Server.Shutdown(ctx)
61+
}
62+
2563
func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error {
2664
s.BaseContext = func(net.Listener) context.Context {
2765
return ctx
2866
}
2967

30-
done := make(chan error, 1)
68+
ss := newSafeServer(s)
69+
70+
serverClosed := make(chan struct{})
71+
var serverError error
3172
go func() {
32-
done <- s.Serve(l)
73+
serverError = ss.ListenAndServe(l)
74+
close(serverClosed)
3375
}()
3476

3577
select {
36-
case err := <-done:
37-
return err
78+
case <-serverClosed:
79+
return serverError
3880
case <-ctx.Done():
39-
ctx = xcontext.WithoutCancel(ctx)
40-
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
81+
shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout)
4182
defer cancel()
42-
return s.Shutdown(ctx)
83+
84+
err := ss.Shutdown(shutdownCtx)
85+
<-serverClosed // Wait for server to exit
86+
if err != nil {
87+
return err
88+
}
89+
return serverError
4390
}
91+
4492
}

0 commit comments

Comments
 (0)