Skip to content

Commit 00e0ebe

Browse files
committed
fix: Resolve race condition in HTTP server handling
1 parent 55b3812 commit 00e0ebe

File tree

3 files changed

+70
-23
lines changed

3 files changed

+70
-23
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
runs-on: ubuntu-latest
1010
steps:
1111
- uses: actions/checkout@v3
12-
- uses: actions/setup-go@v3
12+
- uses: actions/setup-go@v4
1313
with:
1414
go-version-file: ./go.mod
1515
cache: true

diff/diff.go

+16-15
Original file line numberDiff line numberDiff line change
@@ -169,31 +169,32 @@ func formatRunes(s string) string {
169169
//
170170
// Here's an example that you can play with to better understand the behaviour:
171171
//
172-
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name()), "change me")
173-
// if err != nil {
174-
// t.Fatal(err)
175-
// }
172+
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name()), "change me")
173+
// if err != nil {
174+
// t.Fatal(err)
175+
// }
176176
//
177177
// Normally you want to use t.Name() as path for clarity but you can pass in any string.
178178
// e.g. a single test could persist two json objects into testdata with:
179179
//
180-
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name(), "1"), "change me 1")
181-
// if err != nil {
182-
// t.Fatal(err)
183-
// }
184-
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name(), "2"), "change me 2")
185-
// if err != nil {
186-
// t.Fatal(err)
187-
// }
180+
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name(), "1"), "change me 1")
181+
// if err != nil {
182+
// t.Fatal(err)
183+
// }
184+
// err = diff.TestdataJSON(filepath.Join("testdata", t.Name(), "2"), "change me 2")
185+
// if err != nil {
186+
// t.Fatal(err)
187+
// }
188188
//
189189
// These would persist in testdata/${t.Name()}/1.exp.json and testdata/${t.Name()}/2.exp.json
190190
//
191191
// It uses Files under the hood.
192192
//
193193
// note: testdata is the canonical Go directory for such persistent test only files.
194-
// It is unfortunately poorly documented. See https://pkg.go.dev/cmd/go/internal/test
195-
// So normally you'd want path to be filepath.Join("testdata", t.Name()).
196-
// This is also the reason this function is named "TestdataJSON".
194+
//
195+
// It is unfortunately poorly documented. See https://pkg.go.dev/cmd/go/internal/test
196+
// So normally you'd want path to be filepath.Join("testdata", t.Name()).
197+
// This is also the reason this function is named "TestdataJSON".
197198
func TestdataJSON(path string, got interface{}) error {
198199
gotb := xjson.Marshal(got)
199200
gotb = append(gotb, '\n')

xhttp/serve.go

+53-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package xhttp
33

44
import (
55
"context"
6+
"errors"
67
"log"
78
"net"
89
"net/http"
10+
"sync"
11+
"sync/atomic"
912
"time"
1013

1114
"oss.terrastruct.com/util-go/xcontext"
@@ -22,23 +25,66 @@ func NewServer(log *log.Logger, h http.Handler) *http.Server {
2225
}
2326
}
2427

28+
type safeServer struct {
29+
*http.Server
30+
running int32
31+
mu sync.Mutex
32+
}
33+
34+
func newSafeServer(s *http.Server) *safeServer {
35+
return &safeServer{
36+
Server: s,
37+
}
38+
}
39+
40+
func (s *safeServer) ListenAndServe(l net.Listener) error {
41+
s.mu.Lock()
42+
defer s.mu.Unlock()
43+
44+
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
45+
return errors.New("server is already running")
46+
}
47+
defer atomic.StoreInt32(&s.running, 0)
48+
49+
return s.Serve(l)
50+
}
51+
52+
func (s *safeServer) Shutdown(ctx context.Context) error {
53+
s.mu.Lock()
54+
defer s.mu.Unlock()
55+
56+
if atomic.LoadInt32(&s.running) == 0 {
57+
return nil
58+
}
59+
60+
return s.Server.Shutdown(ctx)
61+
}
62+
2563
func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error {
2664
s.BaseContext = func(net.Listener) context.Context {
2765
return ctx
2866
}
2967

30-
done := make(chan error, 1)
68+
ss := newSafeServer(s)
69+
70+
serverClosed := make(chan struct{})
71+
var serverError error
3172
go func() {
32-
done <- s.Serve(l)
73+
serverError = ss.ListenAndServe(l)
74+
close(serverClosed)
3375
}()
3476

3577
select {
36-
case err := <-done:
37-
return err
78+
case <-serverClosed:
79+
return serverError
3880
case <-ctx.Done():
39-
ctx = xcontext.WithoutCancel(ctx)
40-
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
81+
shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout)
4182
defer cancel()
42-
return s.Shutdown(ctx)
83+
err := ss.Shutdown(shutdownCtx)
84+
<-serverClosed // Wait for server to exit
85+
if err != nil {
86+
return err
87+
}
88+
return serverError
4389
}
4490
}

0 commit comments

Comments
 (0)