From fe9bec84ecf5e0938d537f5a6f1335f82f01c2b4 Mon Sep 17 00:00:00 2001 From: roc Date: Thu, 1 Dec 2022 19:45:07 +0800 Subject: [PATCH] SetDialTLS should override dial func in EnableH2C --- internal/http2/go115.go | 20 +++++++------------- internal/http2/transport.go | 6 ++++++ transport.go | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/http2/go115.go b/internal/http2/go115.go index 629d8613..a3a4dfc5 100644 --- a/internal/http2/go115.go +++ b/internal/http2/go115.go @@ -11,24 +11,18 @@ import ( "context" "crypto/tls" reqtls "github.com/imroc/req/v3/pkg/tls" - "net" ) // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. -func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (tlsCn reqtls.Conn, err error) { - var conn net.Conn - if t.DialTLSContext != nil { - conn, err = t.DialTLSContext(ctx, network, addr) - } else { - dialer := &tls.Dialer{ - Config: cfg, - } - conn, err = dialer.DialContext(ctx, network, addr) +func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { + dialer := &tls.Dialer{ + Config: cfg, } + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { - return + return nil, err } - tlsCn = conn.(reqtls.Conn) - return + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil } diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 8b9d8a84..f6b3efe5 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -73,6 +73,7 @@ const ( // for concurrent use by multiple goroutines. type Transport struct { *transport.Options + // DialTLS specifies an optional dial function for creating // TLS connections for requests. // @@ -560,6 +561,11 @@ func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Confi if t.DialTLS != nil { return t.DialTLS } + if t.DialTLSContext != nil { + return func(network string, addr string, cfg *tls.Config) (net.Conn, error) { + return t.DialTLSContext(ctx, network, addr) + } + } return func(network, addr string, cfg *tls.Config) (net.Conn, error) { tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) if err != nil { diff --git a/transport.go b/transport.go index 7d14f86d..e6446484 100644 --- a/transport.go +++ b/transport.go @@ -367,7 +367,7 @@ func (t *Transport) EnableForceHTTP2() *Transport { func (t *Transport) EnableH2C() *Transport { t.Options.EnableH2C = true t.t2.AllowHTTP = true - t.t2.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) { + t.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return net.Dial(network, addr) } return t @@ -377,7 +377,7 @@ func (t *Transport) EnableH2C() *Transport { func (t *Transport) DisableH2C() *Transport { t.Options.EnableH2C = false t.t2.AllowHTTP = false - t.t2.DialTLS = nil + t.t2.DialTLSContext = nil return t }