Skip to content

Commit 4b76d4e

Browse files
authored
feat: listen and shutdown (#4)
* feat: ws handler shutdown * fix
1 parent e39a360 commit 4b76d4e

File tree

5 files changed

+105
-49
lines changed

5 files changed

+105
-49
lines changed

client/wsd.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,3 @@ func (wc *Dialer) DialContext(ctx context.Context, options ...ConnectOption) (ne
273273
func (wc *Dialer) Dial(options ...ConnectOption) (net.Conn, error) {
274274
return wc.DialContext(context.Background(), options...)
275275
}
276-
277-
func (wc *Dialer) DialTCP(options ...ConnectOption) (net.Conn, error) {
278-
return wc.Dial(options...)
279-
}
280-
281-
func (wc *Dialer) DialContextTCP(ctx context.Context, options ...ConnectOption) (net.Conn, error) {
282-
return wc.DialContext(ctx, options...)
283-
}

server/healthy.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"log"
88
"net/http"
99
"os"
10-
"sync/atomic"
1110
"time"
1211
)
1312

@@ -17,9 +16,11 @@ var (
1716
jwtToken = os.Getenv("JWT_TOKEN")
1817
)
1918

20-
var ActiveNum int32 = 0
19+
type ActiveNum interface {
20+
ActiveNum() int64
21+
}
2122

22-
func HealthyCheck(ctx context.Context) {
23+
func HealthyCheck(ctx context.Context, handler ActiveNum) {
2324
shutdownDuration, _ := time.ParseDuration(interval)
2425
ticker := time.NewTicker(1 * time.Minute)
2526
defer ticker.Stop()
@@ -28,7 +29,7 @@ func HealthyCheck(ctx context.Context) {
2829
for {
2930
select {
3031
case <-ticker.C:
31-
if atomic.LoadInt32(&ActiveNum) == 0 {
32+
if handler.ActiveNum() == 0 {
3233
zeroDuration += 1 * time.Minute
3334
if zeroDuration >= shutdownDuration {
3435
sendShutdownRequest()

server/main.go

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package main
33
import (
44
"context"
55
"log"
6+
"net"
67
"os"
78
"os/signal"
89
"sync"
910
"syscall"
11+
"time"
1012
)
1113

1214
var (
@@ -16,29 +18,39 @@ var (
1618
)
1719

1820
func main() {
19-
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
20-
defer cancel()
21+
if listen == "" || target == "" {
22+
log.Fatalf("LISTEN or TARGET is not set")
23+
}
24+
25+
listener, err := net.Listen("tcp", listen)
26+
if err != nil {
27+
log.Fatalf("Failed to listen: %v", err)
28+
}
29+
30+
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
31+
defer stop()
2132

2233
var wg sync.WaitGroup
2334

35+
handler := NewHandler(target)
36+
2437
if flag == "true" {
2538
wg.Add(1)
2639
go func() {
2740
defer wg.Done()
28-
HealthyCheck(ctx)
41+
HealthyCheck(ctx, handler)
2942
}()
3043
}
3144

32-
if listen == "" || target == "" {
33-
panic("LISTEN or TARGET is not set")
34-
}
45+
server := NewServer(
46+
"/",
47+
handler,
48+
WithListener(listener),
49+
)
3550

3651
go func() {
37-
err := NewServer(
38-
listen,
39-
"/",
40-
NewHandler(target),
41-
).Serve()
52+
log.Println("Started ws server")
53+
err := server.Serve()
4254
if err != nil {
4355
log.Fatalf("Failed to ws serve: %v", err)
4456
}
@@ -47,4 +59,11 @@ func main() {
4759
<-ctx.Done()
4860
log.Println("Shutting down gracefully")
4961
wg.Wait()
62+
63+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
64+
defer cancel()
65+
err = server.Shutdown(ctx)
66+
if err != nil {
67+
log.Fatalf("Failed to shutdown gracefully: %v", err)
68+
}
5069
}

server/ws.go

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
)
1010

1111
type Server struct {
12+
listener net.Listener
1213
listenErr error
1314
shutdowned chan struct{}
1415
onListened chan struct{}
@@ -21,9 +22,20 @@ type Server struct {
2122

2223
type ServerOption func(*Server)
2324

24-
func NewServer(listenAddr, path string, wsHandler *Handler, opts ...ServerOption) *Server {
25+
func WithListener(listener net.Listener) ServerOption {
26+
return func(ps *Server) {
27+
ps.listener = listener
28+
}
29+
}
30+
31+
func WithListenAddr(listenAddr string) ServerOption {
32+
return func(ps *Server) {
33+
ps.listenAddr = listenAddr
34+
}
35+
}
36+
37+
func NewServer(path string, wsHandler *Handler, opts ...ServerOption) *Server {
2538
ps := &Server{
26-
listenAddr: listenAddr,
2739
wsHandler: wsHandler,
2840
path: path,
2941
onListened: make(chan struct{}),
@@ -43,25 +55,13 @@ func (ps *Server) closeOnListened() {
4355
})
4456
}
4557

46-
func (ps *Server) OnListened() <-chan struct{} {
47-
return ps.onListened
48-
}
49-
50-
func (ps *Server) ListenErr() error {
58+
func (ps *Server) WaitListen() error {
59+
<-ps.onListened
5160
return ps.listenErr
5261
}
5362

54-
func (ps *Server) Shutdowned() <-chan struct{} {
55-
return ps.shutdowned
56-
}
57-
58-
func (ps *Server) ShutdownedBool() bool {
59-
select {
60-
case <-ps.shutdowned:
61-
return true
62-
default:
63-
return false
64-
}
63+
func (ps *Server) WaitShutdown() {
64+
<-ps.shutdowned
6565
}
6666

6767
func (ps *Server) Serve() error {
@@ -73,12 +73,19 @@ func (ps *Server) Serve() error {
7373
return ps.listenAndServe(server)
7474
}
7575

76-
func (ps *Server) listenAndServe(server *http.Server) error {
76+
func (ps *Server) getListener() (net.Listener, error) {
77+
if ps.listener != nil {
78+
return ps.listener, nil
79+
}
7780
addr := ps.listenAddr
7881
if addr == "" {
7982
addr = ":http"
8083
}
81-
ln, err := net.Listen("tcp", addr)
84+
return net.Listen("tcp", addr)
85+
}
86+
87+
func (ps *Server) listenAndServe(server *http.Server) error {
88+
ln, err := ps.getListener()
8289
if err != nil {
8390
ps.listenErr = err
8491
return err
@@ -100,17 +107,20 @@ func (ps *Server) Server() *http.Server {
100107
ReadHeaderTimeout: time.Second * 5,
101108
MaxHeaderBytes: 16 * 1024,
102109
}
110+
ps.server.RegisterOnShutdown(func() {
111+
ps.wsHandler.Close()
112+
})
103113
}
104114
return ps.server
105115
}
106116

107117
func (ps *Server) Close() error {
108-
timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
109-
defer cancel()
110-
return ps.Shutdown(timeoutCtx)
118+
defer ps.closeOnListened()
119+
return ps.server.Close()
111120
}
112121

113122
func (ps *Server) Shutdown(ctx context.Context) error {
114-
ps.closeOnListened()
123+
defer ps.closeOnListened()
124+
defer ps.wsHandler.Wait()
115125
return ps.server.Shutdown(ctx)
116126
}

server/wsh.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,12 @@ type GetTargetFunc func(req *http.Request) (string, []string, error)
4242
type Handler struct {
4343
bufferPool *sync.Pool
4444
wsServer *websocket.Server
45+
closeChan chan struct{}
4546
defaultTargetAddr string
47+
connectionsWg sync.WaitGroup
48+
activeNum int64
4649
bufferSize int
50+
closeOnce sync.Once
4751
}
4852

4953
type HandlerOption func(*Handler)
@@ -65,6 +69,7 @@ func checkOrigin(config *websocket.Config, req *http.Request) (err error) {
6569
func NewHandler(targetAddr string, opts ...HandlerOption) *Handler {
6670
h := &Handler{
6771
defaultTargetAddr: targetAddr,
72+
closeChan: make(chan struct{}),
6873
}
6974

7075
for _, opt := range opts {
@@ -95,10 +100,26 @@ func (h *Handler) putBuffer(buffer *[]byte) {
95100
}
96101
}
97102

103+
func (h *Handler) ActiveNum() int64 {
104+
return atomic.LoadInt64(&h.activeNum)
105+
}
106+
107+
func (h *Handler) addActiveNum() {
108+
atomic.AddInt64(&h.activeNum, 1)
109+
}
110+
111+
func (h *Handler) subActiveNum() {
112+
atomic.AddInt64(&h.activeNum, -1)
113+
}
114+
98115
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
99-
atomic.AddInt32(&ActiveNum, 1)
116+
h.connectionsWg.Add(1)
117+
defer h.connectionsWg.Done()
118+
119+
h.addActiveNum()
120+
defer h.subActiveNum()
121+
100122
h.wsServer.ServeHTTP(w, req)
101-
atomic.AddInt32(&ActiveNum, -1)
102123
}
103124

104125
var pingCodec = websocket.Codec{
@@ -127,6 +148,9 @@ func (h *Handler) handleWebSocket(ws *websocket.Conn) {
127148
}
128149
_ = ws.Close()
129150
return
151+
case <-h.closeChan:
152+
_ = ws.Close()
153+
return
130154
case <-exit:
131155
return
132156
}
@@ -197,3 +221,13 @@ func CopyBufferWithWriteTimeout(dst deadlineWriter, src io.Reader, buf []byte, t
197221
}
198222
return written, err
199223
}
224+
225+
func (h *Handler) Close() {
226+
h.closeOnce.Do(func() {
227+
close(h.closeChan)
228+
})
229+
}
230+
231+
func (h *Handler) Wait() {
232+
h.connectionsWg.Wait()
233+
}

0 commit comments

Comments
 (0)