@@ -27,17 +27,27 @@ type Service struct {
27
27
handshake HandshakeConfig
28
28
handshakeForServerName map [string ]HandshakeConfig
29
29
strictMode bool
30
+ wildcardSNI WildcardSNI
30
31
handler N.TCPConnectionHandlerEx
31
32
logger logger.ContextLogger
32
33
}
33
34
35
+ type WildcardSNI int
36
+
37
+ const (
38
+ WildcardSNIOff WildcardSNI = iota
39
+ WildcardSNIAuthed
40
+ WildcardSNIAll
41
+ )
42
+
34
43
type ServiceConfig struct {
35
44
Version int
36
45
Password string // for protocol version 2
37
46
Users []User // for protocol version 3
38
47
Handshake HandshakeConfig
39
48
HandshakeForServerName map [string ]HandshakeConfig // for protocol version 2/3
40
49
StrictMode bool // for protocol version 3
50
+ WildcardSNI WildcardSNI // for protocol version 3
41
51
Handler N.TCPConnectionHandlerEx
42
52
Logger logger.ContextLogger
43
53
}
@@ -60,6 +70,7 @@ func NewService(config ServiceConfig) (*Service, error) {
60
70
handshake : config .Handshake ,
61
71
handshakeForServerName : config .HandshakeForServerName ,
62
72
strictMode : config .StrictMode ,
73
+ wildcardSNI : config .WildcardSNI ,
63
74
handler : config .Handler ,
64
75
logger : config .Logger ,
65
76
}
@@ -84,16 +95,6 @@ func NewService(config ServiceConfig) (*Service, error) {
84
95
return service , nil
85
96
}
86
97
87
- func (s * Service ) selectHandshake (clientHelloFrame * buf.Buffer ) HandshakeConfig {
88
- serverName , err := extractServerName (clientHelloFrame .Bytes ())
89
- if err == nil {
90
- if customHandshake , found := s .handshakeForServerName [serverName ]; found {
91
- return customHandshake
92
- }
93
- }
94
- return s .handshake
95
- }
96
-
97
98
func (s * Service ) NewConnection (ctx context.Context , conn net.Conn , source M.Socksaddr , destination M.Socksaddr , onClose N.CloseHandlerFunc ) error {
98
99
switch s .version {
99
100
default :
@@ -127,8 +128,17 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc
127
128
if err != nil {
128
129
return E .Cause (err , "read client handshake" )
129
130
}
130
-
131
- handshakeConfig := s .selectHandshake (clientHelloFrame )
131
+ serverName , err := extractServerName (clientHelloFrame .Bytes ())
132
+ var handshakeConfig HandshakeConfig
133
+ if err == nil {
134
+ if customHandshake , found := s .handshakeForServerName [serverName ]; found {
135
+ handshakeConfig = customHandshake
136
+ } else {
137
+ handshakeConfig = s .handshake
138
+ }
139
+ } else {
140
+ handshakeConfig = s .handshake
141
+ }
132
142
handshakeConn , err := handshakeConfig .Dialer .DialContext (ctx , N .NetworkTCP , handshakeConfig .Server )
133
143
if err != nil {
134
144
return E .Cause (err , "server handshake" )
@@ -154,27 +164,55 @@ func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Soc
154
164
if err != nil {
155
165
return E .Cause (err , "read client handshake" )
156
166
}
157
-
158
- handshakeConfig := s .selectHandshake (clientHelloFrame )
159
- handshakeConn , err := handshakeConfig .Dialer .DialContext (ctx , N .NetworkTCP , handshakeConfig .Server )
167
+ serverName , err := extractServerName (clientHelloFrame .Bytes ())
160
168
if err != nil {
161
- return E .Cause (err , "server handshake " )
169
+ return E .Cause (err , "extract server name " )
162
170
}
163
-
164
- _ , err = handshakeConn .Write (clientHelloFrame .Bytes ())
165
- if err != nil {
166
- clientHelloFrame .Release ()
167
- return E .Cause (err , "write client handshake" )
171
+ var (
172
+ handshakeConfig HandshakeConfig
173
+ isCustom bool
174
+ )
175
+ if customHandshake , found := s .handshakeForServerName [serverName ]; found {
176
+ handshakeConfig = customHandshake
177
+ isCustom = true
178
+ } else {
179
+ handshakeConfig = s .handshake
180
+ if s .wildcardSNI != WildcardSNIOff {
181
+ handshakeConfig .Server = M.Socksaddr {
182
+ Fqdn : serverName ,
183
+ Port : s .handshake .Server .Port ,
184
+ }
185
+ }
168
186
}
187
+ var handshakeConn net.Conn
169
188
user , err := verifyClientHello (clientHelloFrame .Bytes (), s .users )
170
189
if err != nil {
171
190
s .logger .WarnContext (ctx , E .Cause (err , "client hello verify failed" ))
191
+ if s .wildcardSNI == WildcardSNIAll || isCustom {
192
+ handshakeConn , err = handshakeConfig .Dialer .DialContext (ctx , N .NetworkTCP , handshakeConfig .Server )
193
+ } else {
194
+ handshakeConn , err = s .handshake .Dialer .DialContext (ctx , N .NetworkTCP , s .handshake .Server )
195
+ }
196
+ if err != nil {
197
+ return E .Cause (err , "server handshake" )
198
+ }
172
199
return bufio .CopyConn (ctx , conn , handshakeConn )
173
200
}
174
201
if user .Name != "" {
175
202
ctx = auth .ContextWithUser (ctx , user .Name )
176
203
}
177
204
s .logger .TraceContext (ctx , "client hello verify success" )
205
+
206
+ handshakeConn , err = handshakeConfig .Dialer .DialContext (ctx , N .NetworkTCP , handshakeConfig .Server )
207
+ if err != nil {
208
+ return E .Cause (err , "server handshake" )
209
+ }
210
+
211
+ _ , err = handshakeConn .Write (clientHelloFrame .Bytes ())
212
+ if err != nil {
213
+ clientHelloFrame .Release ()
214
+ return E .Cause (err , "write client handshake" )
215
+ }
178
216
clientHelloFrame .Release ()
179
217
180
218
var serverHelloFrame * buf.Buffer
0 commit comments