Skip to content

Commit 9b115ae

Browse files
committed
隐藏conn.Write接口
1 parent 5b6f27d commit 9b115ae

File tree

6 files changed

+49
-9
lines changed

6 files changed

+49
-9
lines changed

autobahn/server/autobahn-server.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ var keyPEMBlock []byte
2727
type echoHandler struct{}
2828

2929
func (e *echoHandler) OnOpen(c *greatws.Conn) {
30+
// err := c.WriteMessage(greatws.Binary, make([]byte, 1<<28))
31+
// if err != nil {
32+
// fmt.Printf("%s\n", err)
33+
// }
3034
// fmt.Printf("OnOpen: %p\n", c)
3135
}
3236

@@ -132,7 +136,7 @@ func (h *handler) echoRunStream2(w http.ResponseWriter, r *http.Request) {
132136
greatws.WithServerIgnorePong(),
133137
greatws.WithServerCallback(&echoHandler{}),
134138
greatws.WithServerEnableUTF8Check(),
135-
greatws.WithServerReadTimeout(5 * time.Second),
139+
// greatws.WithServerReadTimeout(5 * time.Second),
136140
greatws.WithServerMultiEventLoop(h.m),
137141
greatws.WithServerStreamMode(),
138142
greatws.WithServerCallbackInEventLoop(),

conn_core.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) {
581581
var fw fixedwriter.FixedWriter
582582

583583
c.mu.Lock()
584-
err = frame.WriteFrame(&fw, c, writeBuf, true, rsv1, c.client, op, maskValue)
584+
err = frame.WriteFrame(&fw, connToNewConn(c), writeBuf, true, rsv1, c.client, op, maskValue)
585585
c.mu.Unlock()
586586

587587
return err
@@ -619,14 +619,14 @@ func (c *Conn) writeFragment(op Opcode, writeBuf []byte, maxFragment int /*单
619619
var fw fixedwriter.FixedWriter
620620
for len(writeBuf) > 0 {
621621
if len(writeBuf) > maxFragment {
622-
if err := frame.WriteFrame(&fw, c, writeBuf[:maxFragment], false, rsv1, c.client, op, maskValue); err != nil {
622+
if err := frame.WriteFrame(&fw, connToNewConn(c), writeBuf[:maxFragment], false, rsv1, c.client, op, maskValue); err != nil {
623623
return err
624624
}
625625
writeBuf = writeBuf[maxFragment:]
626626
op = Continuation
627627
continue
628628
}
629-
return frame.WriteFrame(&fw, c, writeBuf, true, rsv1, c.client, op, maskValue)
629+
return frame.WriteFrame(&fw, connToNewConn(c), writeBuf, true, rsv1, c.client, op, maskValue)
630630
}
631631
return nil
632632
}

conn_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ func Test_ReadMessage(t *testing.T) {
253253
// err = con.WriteMessage(Binary, []byte("hello"))
254254
maskValue := rand.Uint32()
255255
var fw fixedwriter.FixedWriter
256-
err = frame.WriteFrame(&fw, con, []byte("hello"), true, true, con.client, Binary, maskValue)
256+
err = frame.WriteFrame(&fw, connToNewConn(con), []byte("hello"), true, true, con.client, Binary, maskValue)
257257
if err != nil {
258258
t.Error(err)
259259
}
@@ -314,7 +314,7 @@ func Test_ReadMessage(t *testing.T) {
314314
// err = con.WriteMessage(Binary, []byte("hello"))
315315
maskValue := rand.Uint32()
316316
var fw fixedwriter.FixedWriter
317-
err = frame.WriteFrame(&fw, con, []byte("hello"), true, true, con.client, Ping, maskValue)
317+
err = frame.WriteFrame(&fw, connToNewConn(con), []byte("hello"), true, true, con.client, Ping, maskValue)
318318
if err != nil {
319319
t.Error(err)
320320
}
@@ -461,12 +461,12 @@ func TestFragmentFrame(t *testing.T) {
461461

462462
maskValue := rand.Uint32()
463463
var fw fixedwriter.FixedWriter
464-
err = frame.WriteFrame(&fw, con, []byte("h"), false, false, con.client, Text, maskValue)
464+
err = frame.WriteFrame(&fw, connToNewConn(con), []byte("h"), false, false, con.client, Text, maskValue)
465465
if err != nil {
466466
t.Error(err)
467467
}
468468
maskValue = rand.Uint32()
469-
err = frame.WriteFrame(&fw, con, []byte{}, true, false, con.client, Text, maskValue)
469+
err = frame.WriteFrame(&fw, connToNewConn(con), []byte{}, true, false, con.client, Text, maskValue)
470470
if err != nil {
471471
t.Error(err)
472472
}

conn_unix.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ func (c *Conn) getPtr() int {
185185
return int(uintptr(unsafe.Pointer(c)))
186186
}
187187

188-
func (c *Conn) Write(b []byte) (n int, err error) {
188+
// Conn Write入口, 原始定义是func (c *Conn) Write() (n int, err error)
189+
// 会引起误用,所以隐藏起来, 作为一个websocket库,直接暴露tcp的write接口也不合适
190+
func connWrite(c *Conn, b []byte) (n int, err error) {
189191
if c.isClosed() {
190192
return 0, ErrClosed
191193
}

conn_write.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright 2023-2024 antlabs. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly
16+
// +build linux darwin netbsd freebsd openbsd dragonfly
17+
18+
package greatws
19+
20+
import "unsafe"
21+
22+
type newConnWrite Conn
23+
24+
func connToNewConn(c *Conn) *newConnWrite {
25+
return (*newConnWrite)(unsafe.Pointer(c))
26+
}
27+
28+
func (c *newConnWrite) Write(p []byte) (n int, err error) {
29+
c2 := (*Conn)(unsafe.Pointer(c))
30+
return connWrite(c2, p)
31+
}

multi_event_loops.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ func (m *MultiEventLoop) getEventLoop(fd int) *EventLoop {
186186
// 添加一个连接到多路事件循环
187187
func (m *MultiEventLoop) add(c *Conn) error {
188188
fd := c.getFd()
189+
if fd == -1 {
190+
return nil
191+
}
189192
index := fd % len(m.loops)
190193
m.safeConns.addConn(c)
191194
// m.loops[index].conns.Store(fd, c)

0 commit comments

Comments
 (0)