Skip to content

Commit 523d05e

Browse files
committed
Add support for wildcard-sni
1 parent 2d5886b commit 523d05e

File tree

3 files changed

+62
-24
lines changed

3 files changed

+62
-24
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module github.com/sagernet/sing-shadowtls
33
go 1.20
44

55
require (
6-
github.com/sagernet/sing v0.6.0
6+
github.com/sagernet/sing v0.6.3
77
golang.org/x/crypto v0.32.0
88
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
99
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
3-
github.com/sagernet/sing v0.6.0 h1:jT55zAXrG7H3x+s/FlrC15xQy3LcmuZ2GGA9+8IJdt0=
4-
github.com/sagernet/sing v0.6.0/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
3+
github.com/sagernet/sing v0.6.3 h1:J1spMc6LMlqUvRjWjvNMAcbvACDneqxB9zxfLuS0UTE=
4+
github.com/sagernet/sing v0.6.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
55
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
66
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
77
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=

service.go

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,27 @@ type Service struct {
2727
handshake HandshakeConfig
2828
handshakeForServerName map[string]HandshakeConfig
2929
strictMode bool
30+
wildcardSNI WildcardSNI
3031
handler N.TCPConnectionHandlerEx
3132
logger logger.ContextLogger
3233
}
3334

35+
type WildcardSNI int
36+
37+
const (
38+
WildcardSNIOff WildcardSNI = iota
39+
WildcardSNIAuthed
40+
WildcardSNIAll
41+
)
42+
3443
type ServiceConfig struct {
3544
Version int
3645
Password string // for protocol version 2
3746
Users []User // for protocol version 3
3847
Handshake HandshakeConfig
3948
HandshakeForServerName map[string]HandshakeConfig // for protocol version 2/3
4049
StrictMode bool // for protocol version 3
50+
WildcardSNI WildcardSNI // for protocol version 3
4151
Handler N.TCPConnectionHandlerEx
4252
Logger logger.ContextLogger
4353
}
@@ -60,6 +70,7 @@ func NewService(config ServiceConfig) (*Service, error) {
6070
handshake: config.Handshake,
6171
handshakeForServerName: config.HandshakeForServerName,
6272
strictMode: config.StrictMode,
73+
wildcardSNI: config.WildcardSNI,
6374
handler: config.Handler,
6475
logger: config.Logger,
6576
}
@@ -84,16 +95,6 @@ func NewService(config ServiceConfig) (*Service, error) {
8495
return service, nil
8596
}
8697

87-
func (s *Service) selectHandshake(clientHelloFrame *buf.Buffer) HandshakeConfig {
88-
serverName, err := extractServerName(clientHelloFrame.Bytes())
89-
if err == nil {
90-
if customHandshake, found := s.handshakeForServerName[serverName]; found {
91-
return customHandshake
92-
}
93-
}
94-
return s.handshake
95-
}
96-
9798
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) error {
9899
switch s.version {
99100
default:
@@ -127,8 +128,17 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc
127128
if err != nil {
128129
return E.Cause(err, "read client handshake")
129130
}
130-
131-
handshakeConfig := s.selectHandshake(clientHelloFrame)
131+
serverName, err := extractServerName(clientHelloFrame.Bytes())
132+
var handshakeConfig HandshakeConfig
133+
if err == nil {
134+
if customHandshake, found := s.handshakeForServerName[serverName]; found {
135+
handshakeConfig = customHandshake
136+
} else {
137+
handshakeConfig = s.handshake
138+
}
139+
} else {
140+
handshakeConfig = s.handshake
141+
}
132142
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
133143
if err != nil {
134144
return E.Cause(err, "server handshake")
@@ -154,27 +164,55 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc
154164
if err != nil {
155165
return E.Cause(err, "read client handshake")
156166
}
157-
158-
handshakeConfig := s.selectHandshake(clientHelloFrame)
159-
handshakeConn, err := handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
167+
serverName, err := extractServerName(clientHelloFrame.Bytes())
160168
if err != nil {
161-
return E.Cause(err, "server handshake")
169+
return E.Cause(err, "extract server name")
162170
}
163-
164-
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
165-
if err != nil {
166-
clientHelloFrame.Release()
167-
return E.Cause(err, "write client handshake")
171+
var (
172+
handshakeConfig HandshakeConfig
173+
isCustom bool
174+
)
175+
if customHandshake, found := s.handshakeForServerName[serverName]; found {
176+
handshakeConfig = customHandshake
177+
isCustom = true
178+
} else {
179+
handshakeConfig = s.handshake
180+
if s.wildcardSNI != WildcardSNIOff {
181+
handshakeConfig.Server = M.Socksaddr{
182+
Fqdn: serverName,
183+
Port: s.handshake.Server.Port,
184+
}
185+
}
168186
}
187+
var handshakeConn net.Conn
169188
user, err := verifyClientHello(clientHelloFrame.Bytes(), s.users)
170189
if err != nil {
171190
s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed"))
191+
if s.wildcardSNI == WildcardSNIAll || isCustom {
192+
handshakeConn, err = handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
193+
} else {
194+
handshakeConn, err = s.handshake.Dialer.DialContext(ctx, N.NetworkTCP, s.handshake.Server)
195+
}
196+
if err != nil {
197+
return E.Cause(err, "server handshake")
198+
}
172199
return bufio.CopyConn(ctx, conn, handshakeConn)
173200
}
174201
if user.Name != "" {
175202
ctx = auth.ContextWithUser(ctx, user.Name)
176203
}
177204
s.logger.TraceContext(ctx, "client hello verify success")
205+
206+
handshakeConn, err = handshakeConfig.Dialer.DialContext(ctx, N.NetworkTCP, handshakeConfig.Server)
207+
if err != nil {
208+
return E.Cause(err, "server handshake")
209+
}
210+
211+
_, err = handshakeConn.Write(clientHelloFrame.Bytes())
212+
if err != nil {
213+
clientHelloFrame.Release()
214+
return E.Cause(err, "write client handshake")
215+
}
178216
clientHelloFrame.Release()
179217

180218
var serverHelloFrame *buf.Buffer

0 commit comments

Comments
 (0)