Skip to content

Commit

Permalink
Enable TLS interception support even when proxy pool plugin is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavsingh committed Apr 13, 2024
1 parent 7026c13 commit d578ae7
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 21 deletions.
11 changes: 11 additions & 0 deletions proxy/core/connection/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ def __init__(self, host: str, port: int) -> None:
self._conn: Optional[TcpOrTlsSocket] = None
self.addr: HostPort = (host, port)
self.closed = True
self._proxy = False

def is_secure(self) -> bool:
return isinstance(self._conn, ssl.SSLSocket)

def mark_as_proxy(self) -> None:
self._proxy = True

@property
def is_proxy(self) -> bool:
return self._proxy

@property
def connection(self) -> TcpOrTlsSocket:
Expand Down
8 changes: 7 additions & 1 deletion proxy/http/proxy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...core.event import EventQueue
from ..descriptors import DescriptorsHandlerMixin
from ...common.utils import tls_interception_enabled

from ...core.connection import TcpServerConnection

if TYPE_CHECKING: # pragma: no cover
from ...common.types import HostPort
Expand Down Expand Up @@ -69,6 +69,12 @@ def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['Ho
"""
return None, None

def upstream_connection(
self,
request: HttpParser,
) -> Optional[TcpServerConnection]:
return None

# No longer abstract since 2.4.0
#
# @abstractmethod
Expand Down
91 changes: 71 additions & 20 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import threading
import subprocess
from typing import Any, Dict, List, Union, Optional, cast
from typing import Any, Dict, List, Union, Optional

from .plugin import HttpProxyBasePlugin
from ..parser import HttpParser, httpParserTypes, httpParserStates
Expand Down Expand Up @@ -487,6 +487,14 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
# Connect to upstream
if do_connect:
self.connect_upstream()
else:
# If a plugin asked us not to connect to upstream
# check if any plugin is managing an upstream connection.
for plugin in self.plugins.values():
up = plugin.upstream_connection(self.request)
if up is not None:
self.upstream = up
break

# Invoke plugin.handle_client_request
for plugin in self.plugins.values():
Expand Down Expand Up @@ -756,54 +764,97 @@ def intercept(self) -> Union[socket.socket, bool]:
return self.client.connection

def wrap_server(self) -> bool:
assert self.upstream is not None
assert isinstance(self.upstream.connection, socket.socket)
assert self.upstream is not None and self.request.host
return self._wrap_server(
self.upstream,
host=self.request.host,
ca_file=self.flags.ca_file,
)

@staticmethod
def _wrap_server(
upstream: TcpServerConnection,
host: bytes,
ca_file: Optional[str] = None,
) -> bool:
assert isinstance(upstream.connection, socket.socket)
do_close = False
if upstream.is_proxy:
# Don't wrap upstream if its part of proxy chain
return do_close
try:
self.upstream.wrap(
text_(self.request.host),
self.flags.ca_file,
upstream.wrap(
text_(host),
ca_file,
as_non_blocking=True,
)
except ssl.SSLCertVerificationError: # Server raised certificate verification error
# When --disable-interception-on-ssl-cert-verification-error flag is on,
# we will cache such upstream hosts and avoid intercepting them for future
# requests.
logger.warning(
'ssl.SSLCertVerificationError: ' +
'Server raised cert verification error for upstream: {0}'.format(
self.upstream.addr[0],
"ssl.SSLCertVerificationError: "
+ "Server raised cert verification error for upstream: {0}".format(
upstream.addr[0],
),
)
do_close = True
except ssl.SSLError as e:
if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE':
logger.warning(
'{0}: '.format(e.reason) +
'Server raised handshake alert failure for upstream: {0}'.format(
self.upstream.addr[0],
"{0}: ".format(e.reason)
+ "Server raised handshake alert failure for upstream: {0}".format(
upstream.addr[0],
),
)
else:
logger.exception(
'SSLError when wrapping client for upstream: {0}'.format(
self.upstream.addr[0],
), exc_info=e,
"SSLError when wrapping client for upstream: {0}".format(
upstream.addr[0],
),
exc_info=e,
)
do_close = True
if not do_close:
assert isinstance(self.upstream.connection, ssl.SSLSocket)
assert isinstance(upstream.connection, ssl.SSLSocket)
return do_close

def wrap_client(self) -> bool:
assert self.upstream is not None and self.flags.ca_signing_key_file is not None
assert isinstance(self.upstream.connection, ssl.SSLSocket)
certificate: Optional[Dict[str, Any]] = None
if isinstance(self.upstream.connection, ssl.SSLSocket):
certificate = self.upstream.connection.getpeercert()
else:
assert self.upstream.is_proxy and self.request.host and self.request.port
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
_, upstream = self.upstream_conn_pool.acquire(
(text_(self.request.host), self.request.port),
)
else:
_, upstream = True, TcpServerConnection(
text_(self.request.host),
self.request.port,
)
# Connect with overridden upstream IP and source address
# if any of the plugin returned a non-null value.
upstream.connect()
upstream.connection.setblocking(False)
do_close = self._wrap_server(
upstream,
host=self.request.host,
ca_file=self.flags.ca_file,
)
if do_close:
return do_close
assert isinstance(upstream.connection, ssl.SSLSocket)
certificate = upstream.connection.getpeercert()
assert certificate
do_close = False
try:
# TODO: Perform async certificate generation
generated_cert = self.generate_upstream_certificate(
cast(Dict[str, Any], self.upstream.connection.getpeercert()),
)
generated_cert = self.generate_upstream_certificate(certificate)
self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
except subprocess.TimeoutExpired as e: # Popen communicate timeout
logger.exception(
Expand Down
6 changes: 6 additions & 0 deletions proxy/plugin/proxy_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import ipaddress
from typing import Any, Dict, List, Optional

from proxy.core.connection import TcpServerConnection

from ..http import Url, httpHeaders, httpMethods
from ..core.base import TcpUpstreamConnectionHandler
from ..http.proxy import HttpProxyBasePlugin
Expand Down Expand Up @@ -78,6 +80,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
def handle_upstream_data(self, raw: memoryview) -> None:
self.client.queue(raw)

def upstream_connection(self, request: HttpParser) -> Optional[TcpServerConnection]:
return self.upstream

def before_upstream_connection(
self, request: HttpParser,
) -> Optional[HttpParser]:
Expand Down Expand Up @@ -107,6 +112,7 @@ def before_upstream_connection(
logger.debug('Using endpoint: {0}:{1}'.format(*endpoint_tuple))
self.initialize_upstream(*endpoint_tuple)
assert self.upstream
self.upstream.mark_as_proxy()
try:
self.upstream.connect()
except TimeoutError:
Expand Down

0 comments on commit d578ae7

Please sign in to comment.