Skip to content

Commit

Permalink
SetDialTLS should override dial func in EnableH2C
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Dec 1, 2022
1 parent d673df0 commit fe9bec8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
20 changes: 7 additions & 13 deletions internal/http2/go115.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,18 @@ import (
"context"
"crypto/tls"
reqtls "github.com/imroc/req/v3/pkg/tls"
"net"
)

// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsCn reqtls.Conn, err error) {
var conn net.Conn
if t.DialTLSContext != nil {
conn, err = t.DialTLSContext(ctx, network, addr)
} else {
dialer := &tls.Dialer{
Config: cfg,
}
conn, err = dialer.DialContext(ctx, network, addr)
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return
return nil, err
}
tlsCn = conn.(reqtls.Conn)
return
tlsCn := conn.(reqtls.Conn)
return tlsCn, nil
}
6 changes: 6 additions & 0 deletions internal/http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const (
// for concurrent use by multiple goroutines.
type Transport struct {
*transport.Options

// DialTLS specifies an optional dial function for creating
// TLS connections for requests.
//
Expand Down Expand Up @@ -560,6 +561,11 @@ func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Confi
if t.DialTLS != nil {
return t.DialTLS
}
if t.DialTLSContext != nil {
return func(network string, addr string, cfg *tls.Config) (net.Conn, error) {
return t.DialTLSContext(ctx, network, addr)
}
}
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ func (t *Transport) EnableForceHTTP2() *Transport {
func (t *Transport) EnableH2C() *Transport {
t.Options.EnableH2C = true
t.t2.AllowHTTP = true
t.t2.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
t.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
return t
Expand All @@ -377,7 +377,7 @@ func (t *Transport) EnableH2C() *Transport {
func (t *Transport) DisableH2C() *Transport {
t.Options.EnableH2C = false
t.t2.AllowHTTP = false
t.t2.DialTLS = nil
t.t2.DialTLSContext = nil
return t
}

Expand Down

0 comments on commit fe9bec8

Please sign in to comment.