Skip to content

Commit 254fc65

Browse files
committed
Fix race issue
1 parent a1f80e0 commit 254fc65

File tree

1 file changed

+47
-11
lines changed

1 file changed

+47
-11
lines changed

xhttp/serve.go

+47-11
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ package xhttp
33

44
import (
55
"context"
6+
"errors"
67
"log"
78
"net"
89
"net/http"
9-
"time"
10-
1110
"oss.terrastruct.com/util-go/xcontext"
11+
"sync"
12+
13+
"sync/atomic"
14+
"time"
1215
)
1316

1417
func NewServer(log *log.Logger, h http.Handler) *http.Server {
@@ -22,23 +25,56 @@ func NewServer(log *log.Logger, h http.Handler) *http.Server {
2225
}
2326
}
2427

28+
type safeServer struct {
29+
*http.Server
30+
running int32
31+
mu sync.Mutex
32+
}
33+
34+
func newSafeServer(s *http.Server) *safeServer {
35+
return &safeServer{
36+
Server: s,
37+
}
38+
}
39+
func (s *safeServer) ListenAndServe(l net.Listener) error {
40+
s.mu.Lock()
41+
defer s.mu.Unlock()
42+
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
43+
return errors.New("server is already running")
44+
}
45+
defer atomic.StoreInt32(&s.running, 0)
46+
return s.Serve(l)
47+
}
48+
func (s *safeServer) Shutdown(ctx context.Context) error {
49+
s.mu.Lock()
50+
defer s.mu.Unlock()
51+
if atomic.LoadInt32(&s.running) == 0 {
52+
return nil
53+
}
54+
return s.Server.Shutdown(ctx)
55+
}
2556
func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error {
2657
s.BaseContext = func(net.Listener) context.Context {
2758
return ctx
2859
}
29-
30-
done := make(chan error, 1)
60+
ss := newSafeServer(s)
61+
serverClosed := make(chan struct{})
62+
var serverError error
3163
go func() {
32-
done <- s.Serve(l)
64+
serverError = ss.ListenAndServe(l)
65+
close(serverClosed)
3366
}()
34-
3567
select {
36-
case err := <-done:
37-
return err
68+
case <-serverClosed:
69+
return serverError
3870
case <-ctx.Done():
39-
ctx = xcontext.WithoutCancel(ctx)
40-
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
71+
shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout)
4172
defer cancel()
42-
return s.Shutdown(ctx)
73+
err := ss.Shutdown(shutdownCtx)
74+
<-serverClosed // Wait for server to exit
75+
if err != nil {
76+
return err
77+
}
78+
return serverError
4379
}
4480
}

0 commit comments

Comments
 (0)