diff --git a/cffi_src/factory.go b/cffi_src/factory.go index e973e27..3f53cd8 100644 --- a/cffi_src/factory.go +++ b/cffi_src/factory.go @@ -309,6 +309,10 @@ func getTlsClient(requestInput RequestInput, sessionId string, withSession bool) options = append(options, tls_client.WithDisableIPV6()) } + if requestInput.DisableIPV4 { + options = append(options, tls_client.WithDisableIPV4()) + } + if requestInput.TransportOptions != nil { transportOptions := &tls_client.TransportOptions{ DisableKeepAlives: requestInput.TransportOptions.DisableKeepAlives, @@ -320,7 +324,7 @@ func getTlsClient(requestInput RequestInput, sessionId string, withSession bool) WriteBufferSize: requestInput.TransportOptions.WriteBufferSize, ReadBufferSize: requestInput.TransportOptions.ReadBufferSize, IdleConnTimeout: requestInput.TransportOptions.IdleConnTimeout, - // RootCAs: requestInput.TransportOptions.RootCAs, + //RootCAs: requestInput.TransportOptions.RootCAs, } options = append(options, tls_client.WithTransportOptions(transportOptions)) diff --git a/cffi_src/types.go b/cffi_src/types.go index 1c9ec01..f712356 100644 --- a/cffi_src/types.go +++ b/cffi_src/types.go @@ -63,6 +63,7 @@ type RequestInput struct { IsByteResponse bool `json:"isByteResponse"` IsRotatingProxy bool `json:"isRotatingProxy"` DisableIPV6 bool `json:"disableIPV6"` + DisableIPV4 bool `json:"disableIPV4"` LocalAddress *string `json:"localAddress"` ServerNameOverwrite *string `json:"serverNameOverwrite"` ProxyUrl *string `json:"proxyUrl"` diff --git a/client.go b/client.go index 23386b6..486530f 100644 --- a/client.go +++ b/client.go @@ -150,7 +150,7 @@ func buildFromConfig(logger Logger, config *httpClientConfig) (*http.Client, ban clientProfile := config.clientProfile - transport, err := newRoundTripper(clientProfile, config.transportOptions, config.serverNameOverwrite, config.insecureSkipVerify, config.withRandomTlsExtensionOrder, config.forceHttp1, config.certificatePins, config.badPinHandler, config.disableIPV6, bandwidthTracker, dialer) + transport, err := newRoundTripper(clientProfile, config.transportOptions, config.serverNameOverwrite, config.insecureSkipVerify, config.withRandomTlsExtensionOrder, config.forceHttp1, config.certificatePins, config.badPinHandler, config.disableIPV6, config.disableIPV4, bandwidthTracker, dialer) if err != nil { return nil, nil, clientProfile, err } @@ -242,7 +242,7 @@ func (c *httpClient) applyProxy() error { dialer = proxyDialer } - transport, err := newRoundTripper(c.config.clientProfile, c.config.transportOptions, c.config.serverNameOverwrite, c.config.insecureSkipVerify, c.config.withRandomTlsExtensionOrder, c.config.forceHttp1, c.config.certificatePins, c.config.badPinHandler, c.config.disableIPV6, c.bandwidthTracker, dialer) + transport, err := newRoundTripper(c.config.clientProfile, c.config.transportOptions, c.config.serverNameOverwrite, c.config.insecureSkipVerify, c.config.withRandomTlsExtensionOrder, c.config.forceHttp1, c.config.certificatePins, c.config.badPinHandler, c.config.disableIPV6, c.config.disableIPV4, c.bandwidthTracker, dialer) if err != nil { return err } diff --git a/client_options.go b/client_options.go index 7f03424..5e5f3f6 100644 --- a/client_options.go +++ b/client_options.go @@ -57,6 +57,8 @@ type httpClientConfig struct { // Establish a connection to origin server via ipv4 only disableIPV6 bool + // Establish a connection to origin server via ipv6 only + disableIPV4 bool dialer net.Dialer enabledBandwidthTracker bool @@ -243,6 +245,13 @@ func WithDisableIPV6() HttpClientOption { } } +// WithDisableIPV4 configures a dialer to use tcp6 network argument +func WithDisableIPV4() HttpClientOption { + return func(config *httpClientConfig) { + config.disableIPV4 = true + } +} + // WithBandwidthTracker configures a client to track the bandwidth used by the client. func WithBandwidthTracker() HttpClientOption { return func(config *httpClientConfig) { diff --git a/roundtripper.go b/roundtripper.go index 9f1c586..479a2c7 100644 --- a/roundtripper.go +++ b/roundtripper.go @@ -50,6 +50,7 @@ type roundTripper struct { transportOptions *TransportOptions withRandomTlsExtensionOrder bool disableIPV6 bool + disableIPV4 bool } func (rt *roundTripper) CloseIdleConnections() { @@ -129,6 +130,10 @@ func (rt *roundTripper) dialTLS(ctx context.Context, network, addr string) (net. network = "tcp4" } + if network == "tcp" && rt.disableIPV4 { + network = "tcp6" + } + rawConn, err := rt.dialer.DialContext(ctx, network, addr) if err != nil { return nil, err @@ -311,7 +316,7 @@ func (rt *roundTripper) getDialTLSAddr(req *http.Request) string { return net.JoinHostPort(req.URL.Host, "443") } -func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, bandwidthTracker bandwidth.BandwidthTracker, dialer ...proxy.ContextDialer) (http.RoundTripper, error) { +func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *TransportOptions, serverNameOverwrite string, insecureSkipVerify bool, withRandomTlsExtensionOrder bool, forceHttp1 bool, certificatePins map[string][]string, badPinHandlerFunc BadPinHandlerFunc, disableIPV6 bool, disableIPV4 bool, bandwidthTracker bandwidth.BandwidthTracker, dialer ...proxy.ContextDialer) (http.RoundTripper, error) { pinner, err := NewCertificatePinner(certificatePins) if err != nil { return nil, fmt.Errorf("can not instantiate certificate pinner: %w", err) @@ -345,6 +350,7 @@ func newRoundTripper(clientProfile profiles.ClientProfile, transportOptions *Tra cachedTransports: make(map[string]http.RoundTripper), cachedConnections: make(map[string]net.Conn), disableIPV6: disableIPV6, + disableIPV4: disableIPV4, bandwidthTracker: bandwidthTracker, }