From 30574fd0414005dfa8792a6e797023e862bdcf43 Mon Sep 17 00:00:00 2001 From: Alexey Pelykh Date: Mon, 17 Apr 2023 06:12:18 +0200 Subject: [PATCH] Support --hostnames (#1325) * Support --hostnames * Support python binaries that are not called "python" * Update log statement to use self.hostname now --------- Co-authored-by: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> --- Makefile | 37 +++++++------- README.md | 10 ++-- docs/changelog-fragments.d/1325.feature.md | 1 + proxy/common/flag.py | 24 +++------ proxy/core/acceptor/acceptor.py | 8 +-- proxy/core/listener/pool.py | 11 +++-- proxy/core/listener/tcp.py | 31 +++++++++--- proxy/core/work/fd/fd.py | 5 +- proxy/proxy.py | 10 ++-- tests/core/test_acceptor.py | 31 +++++++----- tests/core/test_listener.py | 6 +-- tests/core/test_listener_pool.py | 49 +++++++++---------- .../exceptions/test_http_proxy_auth_failed.py | 9 +++- tests/http/proxy/test_http_proxy.py | 9 +++- .../proxy/test_http_proxy_tls_interception.py | 8 +-- tests/http/test_protocol_handler.py | 16 +++--- tests/http/web/test_web_server.py | 22 +++++---- tests/integration/__init__.py | 10 ++++ tests/integration/test_integration.py | 42 ++++++++++------ tests/plugin/test_http_proxy_plugins.py | 5 +- ...ttp_proxy_plugins_with_tls_interception.py | 6 ++- tests/socks/test_handler.py | 5 +- 22 files changed, 206 insertions(+), 149 deletions(-) create mode 100644 docs/changelog-fragments.d/1325.feature.md create mode 100644 tests/integration/__init__.py diff --git a/Makefile b/Makefile index da423c5537..f111b2f9e6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ SHELL := /bin/bash +PYTHON ?= python NS ?= abhinavsingh IMAGE_NAME ?= proxy.py @@ -40,23 +41,23 @@ all: lib-test https-certificates: # Generate server key - python -m proxy.common.pki gen_private_key \ + $(PYTHON) -m proxy.common.pki gen_private_key \ --private-key-path $(HTTPS_KEY_FILE_PATH) - python -m proxy.common.pki remove_passphrase \ + $(PYTHON) -m proxy.common.pki remove_passphrase \ --private-key-path $(HTTPS_KEY_FILE_PATH) # Generate server certificate - python -m proxy.common.pki gen_public_key \ + $(PYTHON) -m proxy.common.pki gen_public_key \ --private-key-path $(HTTPS_KEY_FILE_PATH) \ --public-key-path $(HTTPS_CERT_FILE_PATH) sign-https-certificates: # Generate CSR request - python -m proxy.common.pki gen_csr \ + $(PYTHON) -m proxy.common.pki gen_csr \ --csr-path $(HTTPS_CSR_FILE_PATH) \ --private-key-path $(HTTPS_KEY_FILE_PATH) \ --public-key-path $(HTTPS_CERT_FILE_PATH) # Sign CSR with CA - python -m proxy.common.pki sign_csr \ + $(PYTHON) -m proxy.common.pki sign_csr \ --csr-path $(HTTPS_CSR_FILE_PATH) \ --crt-path $(HTTPS_SIGNED_CERT_FILE_PATH) \ --hostname localhost \ @@ -65,23 +66,23 @@ sign-https-certificates: ca-certificates: # Generate CA key - python -m proxy.common.pki gen_private_key \ + $(PYTHON) -m proxy.common.pki gen_private_key \ --private-key-path $(CA_KEY_FILE_PATH) - python -m proxy.common.pki remove_passphrase \ + $(PYTHON) -m proxy.common.pki remove_passphrase \ --private-key-path $(CA_KEY_FILE_PATH) # Generate CA certificate - python -m proxy.common.pki gen_public_key \ + $(PYTHON) -m proxy.common.pki gen_public_key \ --private-key-path $(CA_KEY_FILE_PATH) \ --public-key-path $(CA_CERT_FILE_PATH) # Generate key that will be used to generate domain certificates on the fly # Generated certificates are then signed with CA certificate / key generated above - python -m proxy.common.pki gen_private_key \ + $(PYTHON) -m proxy.common.pki gen_private_key \ --private-key-path $(CA_SIGNING_KEY_FILE_PATH) - python -m proxy.common.pki remove_passphrase \ + $(PYTHON) -m proxy.common.pki remove_passphrase \ --private-key-path $(CA_SIGNING_KEY_FILE_PATH) lib-check: - python check.py + $(PYTHON) check.py lib-clean: find . -name '*.pyc' -exec rm -f {} + @@ -107,10 +108,10 @@ lib-dep: pip install "setuptools>=42" lib-pre-commit: - python -m pre_commit run --hook-stage manual --all-files -v + $(PYTHON) -m pre_commit run --hook-stage manual --all-files -v lib-lint: - python -m tox -e lint + $(PYTHON) -m tox -e lint lib-flake8: tox -e lint -- flake8 --all-files @@ -119,12 +120,12 @@ lib-mypy: tox -e lint -- mypy --all-files lib-pytest: - python -m tox -e python -- -v + $(PYTHON) -m tox -e python -- -v lib-test: lib-clean lib-check lib-lint lib-pytest lib-package: lib-clean lib-check - python -m tox -e cleanup-dists,build-dists,metadata-validation + $(PYTHON) -m tox -e cleanup-dists,build-dists,metadata-validation lib-release-test: lib-package twine upload --verbose --repository-url https://test.pypi.org/legacy/ dist/* @@ -133,7 +134,7 @@ lib-release: lib-package twine upload dist/* lib-doc: - python -m tox -e build-docs && \ + $(PYTHON) -m tox -e build-docs && \ $(OPEN) .tox/build-docs/docs_out/index.html || true lib-coverage: lib-clean @@ -145,7 +146,7 @@ lib-profile: sudo py-spy record \ -o profile.svg \ -t -F -s -- \ - python -m proxy \ + $(PYTHON) -m proxy \ --hostname 127.0.0.1 \ --num-acceptors 1 \ --num-workers 1 \ @@ -161,7 +162,7 @@ lib-speedscope: -o profile.speedscope.json \ -f speedscope \ -t -F -s -- \ - python -m proxy \ + $(PYTHON) -m proxy \ --hostname 127.0.0.1 \ --num-acceptors 1 \ --num-workers 1 \ diff --git a/README.md b/README.md index 7d16b00a5d..90141417e6 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,8 @@ - `--enable-reverse-proxy --plugins proxy.plugin.ReverseProxyPlugin` - Plugin API is currently in *development phase*. Expect breaking changes. See [Deploying proxy.py in production](#deploying-proxypy-in-production) on how to ensure reliability across code changes. -- Can listen on multiple ports +- Can listen on multiple addresses and ports + - Use `--hostnames` flag to provide additional addresses - Use `--ports` flag to provide additional ports - Optionally, use `--port` flag to override default port `8899` - Capable of serving multiple protocols over the same port @@ -2335,8 +2336,9 @@ usage: -m [-h] [--tunnel-hostname TUNNEL_HOSTNAME] [--tunnel-port TUNNEL_PORT] [--tunnel-remote-port TUNNEL_REMOTE_PORT] [--threadless] [--threaded] [--num-workers NUM_WORKERS] [--enable-events] [--local-executor LOCAL_EXECUTOR] [--backlog BACKLOG] - [--hostname HOSTNAME] [--port PORT] [--ports PORTS [PORTS ...]] - [--port-file PORT_FILE] [--unix-socket-path UNIX_SOCKET_PATH] + [--hostname HOSTNAME] [--hostnames HOSTNAMES [HOSTNAMES ...]] + [--port PORT] [--ports PORTS [PORTS ...]] [--port-file PORT_FILE] + [--unix-socket-path UNIX_SOCKET_PATH] [--num-acceptors NUM_ACCEPTORS] [--version] [--log-level LOG_LEVEL] [--log-file LOG_FILE] [--log-format LOG_FORMAT] [--open-file-limit OPEN_FILE_LIMIT] @@ -2405,6 +2407,8 @@ options: --backlog BACKLOG Default: 100. Maximum number of pending connections to proxy server. --hostname HOSTNAME Default: 127.0.0.1. Server IP address. + --hostnames HOSTNAMES [HOSTNAMES ...] + Default: None. Additional IP addresses to listen on. --port PORT Default: 8899. Server port. To listen on more ports, pass them using --ports flag. --ports PORTS [PORTS ...] diff --git a/docs/changelog-fragments.d/1325.feature.md b/docs/changelog-fragments.d/1325.feature.md new file mode 100644 index 0000000000..65d6af632a --- /dev/null +++ b/docs/changelog-fragments.d/1325.feature.md @@ -0,0 +1 @@ +Support `--hostnames` to specify multiple IP addresses to listen on. diff --git a/proxy/common/flag.py b/proxy/common/flag.py index f8395a6f62..cffb97223e 100644 --- a/proxy/common/flag.py +++ b/proxy/common/flag.py @@ -11,7 +11,6 @@ import os import sys import base64 -import socket import argparse import ipaddress import itertools @@ -25,7 +24,7 @@ from .plugins import Plugins from .version import __version__ from .constants import ( - COMMA, IS_WINDOWS, PLUGIN_PAC_FILE, PLUGIN_DASHBOARD, PLUGIN_HTTP_PROXY, + COMMA, PLUGIN_PAC_FILE, PLUGIN_DASHBOARD, PLUGIN_HTTP_PROXY, PLUGIN_PROXY_AUTH, PLUGIN_WEB_SERVER, DEFAULT_NUM_WORKERS, PLUGIN_REVERSE_PROXY, DEFAULT_NUM_ACCEPTORS, PLUGIN_INSPECT_TRAFFIC, DEFAULT_DISABLE_HEADERS, PY2_DEPRECATION_MESSAGE, DEFAULT_DEVTOOLS_WS_PATH, @@ -291,24 +290,15 @@ def initialize( IpAddress, opts.get('hostname', ipaddress.ip_address(args.hostname)), ) + hostnames: List[List[str]] = opts.get('hostnames', args.hostnames) + args.hostnames = [ + ipaddress.ip_address(hostname) for hostname in list( + itertools.chain.from_iterable([] if hostnames is None else hostnames), + ) + ] args.unix_socket_path = opts.get( 'unix_socket_path', args.unix_socket_path, ) - # AF_UNIX is not available on Windows - # See https://bugs.python.org/issue33408 - if not IS_WINDOWS: - args.family = socket.AF_UNIX if args.unix_socket_path else ( - socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET - ) - else: - # FIXME: Not true for tests, as this value will be a mock. - # - # It's a problem only on Windows. Instead of a proper - # fix in the tests, simply commenting this line of assertion - # for now. - # - # assert args.unix_socket_path is None - args.family = socket.AF_INET6 if args.hostname.version == 6 else socket.AF_INET args.port = cast(int, opts.get('port', args.port)) ports: List[List[int]] = opts.get('ports', args.ports) args.ports = [ diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index e6db855ee2..2d43e0c3e0 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -186,12 +186,8 @@ def _recv_and_setup_socks(self) -> None: # dynamically accept from new fds. for _ in range(self.fd_queue.recv()): fileno = recv_handle(self.fd_queue) - # TODO: Convert to socks i.e. list of fds - self.socks[fileno] = socket.fromfd( - fileno, - family=self.flags.family, - type=socket.SOCK_STREAM, - ) + sock = socket.socket(fileno=socket.dup(fileno)) # type: ignore[attr-defined] + self.socks[fileno] = sock self.fd_queue.close() def _start_local(self) -> None: diff --git a/proxy/core/listener/pool.py b/proxy/core/listener/pool.py index b362ae558c..f9befa9c17 100644 --- a/proxy/core/listener/pool.py +++ b/proxy/core/listener/pool.py @@ -9,6 +9,7 @@ :license: BSD, see LICENSE for more details. """ import argparse +import itertools from typing import TYPE_CHECKING, Any, List, Type from .tcp import TcpSocketListener @@ -37,10 +38,12 @@ def __exit__(self, *args: Any) -> None: def setup(self) -> None: if self.flags.unix_socket_path: self.add(UnixSocketListener) - else: - self.add(TcpSocketListener) - for port in self.flags.ports: - self.add(TcpSocketListener, port=port) + hostnames = {self.flags.hostname, *self.flags.hostnames} + ports = set(self.flags.ports) + if not self.flags.unix_socket_path: + ports.add(self.flags.port) + for hostname, port in itertools.product(hostnames, ports): + self.add(TcpSocketListener, hostname=hostname, port=port) def shutdown(self) -> None: for listener in self.pool: diff --git a/proxy/core/listener/tcp.py b/proxy/core/listener/tcp.py index f841183fd8..b6dc15e8ef 100644 --- a/proxy/core/listener/tcp.py +++ b/proxy/core/listener/tcp.py @@ -10,7 +10,8 @@ """ import socket import logging -from typing import Any, Optional +import ipaddress +from typing import Any, Union, Optional from .base import BaseListener from ...common.flag import flags @@ -26,6 +27,15 @@ help='Default: 127.0.0.1. Server IP address.', ) +flags.add_argument( + '--hostnames', + action='append', + nargs='+', + type=str, + default=None, + help='Default: None. Additional IP addresses to listen on.', +) + flags.add_argument( '--port', type=int, @@ -37,6 +47,7 @@ '--ports', action='append', nargs='+', + type=int, default=None, help='Default: None. Additional ports to listen on.', ) @@ -54,9 +65,14 @@ class TcpSocketListener(BaseListener): """Tcp listener.""" - def __init__(self, *args: Any, port: Optional[int] = None, **kwargs: Any) -> None: - # Port if passed will be used, otherwise - # flag port value will be used. + def __init__( + self, + hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address], + port: int, + *args: Any, + **kwargs: Any, + ) -> None: + self.hostname = hostname self.port = port # Set after binding to a port. # @@ -66,19 +82,18 @@ def __init__(self, *args: Any, port: Optional[int] = None, **kwargs: Any) -> Non def listen(self) -> socket.socket: sock = socket.socket( - socket.AF_INET6 if self.flags.hostname.version == 6 else socket.AF_INET, + socket.AF_INET6 if self.hostname.version == 6 else socket.AF_INET, socket.SOCK_STREAM, ) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # s.setsockopt(socket.SOL_TCP, socket.TCP_FASTOPEN, 5) - port = self.port if self.port is not None else self.flags.port - sock.bind((str(self.flags.hostname), port)) + sock.bind((str(self.hostname), self.port)) sock.listen(self.flags.backlog) sock.setblocking(False) self._port = sock.getsockname()[1] logger.info( 'Listening on %s:%s' % - (self.flags.hostname, self._port), + (self.hostname, self._port), ) return sock diff --git a/proxy/core/work/fd/fd.py b/proxy/core/work/fd/fd.py index 577e5c7c37..730019123c 100644 --- a/proxy/core/work/fd/fd.py +++ b/proxy/core/work/fd/fd.py @@ -30,10 +30,7 @@ def work(self, *args: Any) -> None: fileno: int = args[0] addr: Optional[HostPort] = args[1] conn: Optional[TcpOrTlsSocket] = args[2] - conn = conn or socket.fromfd( - fileno, family=socket.AF_INET if self.flags.hostname.version == 4 else socket.AF_INET6, - type=socket.SOCK_STREAM, - ) + conn = conn or socket.socket(fileno=socket.dup(fileno)) # type: ignore[attr-defined] uid = '%s-%s-%s' % (self.iid, self._total, fileno) self.works[fileno] = self.create(uid, conn, addr) self.works[fileno].publish_event( diff --git a/proxy/proxy.py b/proxy/proxy.py index d9d9f89798..62ecc52e48 100644 --- a/proxy/proxy.py +++ b/proxy/proxy.py @@ -211,16 +211,18 @@ def setup(self) -> None: )._port # --ports flag can also use 0 as value for ephemeral port selection. # Here, we override flags.ports to reflect actual listening ports. - ports = [] - offset = 1 if self.flags.unix_socket_path or self.flags.port else 0 + ports = set() + offset = 1 if self.flags.unix_socket_path else 0 for index in range(offset, offset + len(self.flags.ports)): - ports.append( + ports.add( cast( 'TcpSocketListener', self.listeners.pool[index], )._port, ) - self.flags.ports = ports + if self.flags.port in ports: + ports.remove(self.flags.port) + self.flags.ports = list(ports) # Write ports to port file self._write_port_file() # Setup EventManager diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index acfe0eab21..e52eaffca3 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -8,7 +8,6 @@ :copyright: (c) 2013-present by Abhinav Singh and contributors. :license: BSD, see LICENSE for more details. """ -import socket import selectors import multiprocessing @@ -41,19 +40,23 @@ def setUp(self) -> None: ) @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') + @mock.patch('socket.socket') + @mock.patch('socket.dup') @mock.patch('proxy.core.acceptor.acceptor.recv_handle') def test_continues_when_no_events( self, mock_recv_handle: mock.Mock, - mock_fromfd: mock.Mock, + mock_socket_dup: mock.Mock, + mock_socket: mock.Mock, mock_selector: mock.Mock, ) -> None: fileno = 10 + mock_socket_dup.side_effect = lambda fd: fd conn = mock.MagicMock() addr = mock.MagicMock() - sock = mock_fromfd.return_value - mock_fromfd.return_value.accept.return_value = (conn, addr) + sock = mock.MagicMock() + sock.accept.return_value = (conn, addr) + mock_socket.side_effect = lambda **kwargs: sock if kwargs.get('fileno') == fileno else mock.DEFAULT mock_recv_handle.return_value = fileno selector = mock_selector.return_value @@ -66,20 +69,24 @@ def test_continues_when_no_events( @mock.patch('threading.Thread') @mock.patch('selectors.DefaultSelector') - @mock.patch('socket.fromfd') + @mock.patch('socket.dup') + @mock.patch('socket.socket') @mock.patch('proxy.core.acceptor.acceptor.recv_handle') def test_accepts_client_from_server_socket( self, mock_recv_handle: mock.Mock, - mock_fromfd: mock.Mock, + mock_socket: mock.Mock, + mock_socket_dup: mock.Mock, mock_selector: mock.Mock, mock_thread: mock.Mock, ) -> None: fileno = 10 + mock_socket_dup.side_effect = lambda fd: fd conn = mock.MagicMock() addr = mock.MagicMock() - sock = mock_fromfd.return_value - mock_fromfd.return_value.accept.return_value = (conn, addr) + sock = mock.MagicMock() + sock.accept.return_value = (conn, addr) + mock_socket.side_effect = lambda **kwargs: sock if kwargs.get('fileno') == fileno else mock.DEFAULT mock_recv_handle.return_value = fileno self.pipe[1].recv.return_value = 1 @@ -100,10 +107,8 @@ def test_accepts_client_from_server_socket( ) selector.unregister.assert_called_with(fileno) mock_recv_handle.assert_called_with(self.pipe[1]) - mock_fromfd.assert_called_with( - fileno, - family=socket.AF_INET, - type=socket.SOCK_STREAM, + mock_socket.assert_called_with( + fileno=fileno, ) self.flags.work_klass.assert_called_with( self.work_klass.create.return_value, diff --git a/tests/core/test_listener.py b/tests/core/test_listener.py index 4f3adf241a..5dea87851e 100644 --- a/tests/core/test_listener.py +++ b/tests/core/test_listener.py @@ -26,8 +26,8 @@ class TestListener(unittest.TestCase): @mock.patch('socket.socket') def test_setup_and_teardown(self, mock_socket: mock.Mock) -> None: sock = mock_socket.return_value - flags = FlagParser.initialize(port=0) - with TcpSocketListener(flags=flags) as listener: + flags = FlagParser.initialize() + with TcpSocketListener(flags=flags, hostname=flags.hostname, port=flags.port) as listener: mock_socket.assert_called_with( socket.AF_INET6 if flags.hostname.version == 6 else socket.AF_INET, socket.SOCK_STREAM, @@ -42,7 +42,7 @@ def test_setup_and_teardown(self, mock_socket: mock.Mock) -> None: (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1), ) sock.bind.assert_called_with( - (str(flags.hostname), 0), + (str(flags.hostname), flags.port), ) sock.listen.assert_called_with(flags.backlog) sock.setblocking.assert_called_with(False) diff --git a/tests/core/test_listener_pool.py b/tests/core/test_listener_pool.py index afa47985db..dd776b585f 100644 --- a/tests/core/test_listener_pool.py +++ b/tests/core/test_listener_pool.py @@ -10,6 +10,8 @@ """ import os import tempfile +import ipaddress +import itertools import pytest import unittest @@ -29,9 +31,9 @@ def test_setup_and_teardown( mock_unix_listener: mock.Mock, mock_tcp_listener: mock.Mock, ) -> None: - flags = FlagParser.initialize(port=0) + flags = FlagParser.initialize() with ListenerPool(flags=flags) as pool: - mock_tcp_listener.assert_called_once_with(flags=flags) + mock_tcp_listener.assert_called_once_with(flags=flags, hostname=flags.hostname, port=flags.port) mock_unix_listener.assert_not_called() mock_tcp_listener.return_value.setup.assert_called_once() self.assertEqual(pool.pool[0], mock_tcp_listener.return_value) @@ -60,36 +62,33 @@ def test_unix_socket_listener( @mock.patch('proxy.core.listener.pool.TcpSocketListener') @mock.patch('proxy.core.listener.pool.UnixSocketListener') - def test_multi_listener_on_ports( + def test_multi_listener( self, mock_unix_listener: mock.Mock, mock_tcp_listener: mock.Mock, ) -> None: flags = FlagParser.initialize( - ['--ports', '9000', '--ports', '9001'], - port=0, + ['--hostnames', '127.0.0.2', '--ports', '9000', '--ports', '9001'], ) with ListenerPool(flags=flags) as pool: mock_unix_listener.assert_not_called() - self.assertEqual(len(pool.pool), 3) - self.assertEqual(mock_tcp_listener.call_count, 3) - self.assertEqual( - mock_tcp_listener.call_args_list[0][1]['flags'], - flags, + self.assertEqual(len(pool.pool), 6) + self.assertEqual(mock_tcp_listener.call_count, 6) + self.assertSetEqual( + { + ( + mock_tcp_listener.call_args_list[call][1]['hostname'], + mock_tcp_listener.call_args_list[call][1]['port'], + ) for call in range(6) + }, + set( + itertools.product( + [ipaddress.IPv4Address('127.0.0.1'), ipaddress.IPv4Address('127.0.0.2')], + [8899, 9000, 9001], + ), + ), ) - self.assertEqual( - mock_tcp_listener.call_args_list[1][1]['flags'], - flags, - ) - self.assertEqual( - mock_tcp_listener.call_args_list[1][1]['port'], - 9000, - ) - self.assertEqual( - mock_tcp_listener.call_args_list[2][1]['flags'], - flags, - ) - self.assertEqual( - mock_tcp_listener.call_args_list[2][1]['port'], - 9001, + self.assertListEqual( + [mock_tcp_listener.call_args_list[call][1]['flags'] for call in range(6)], + [flags, flags, flags, flags, flags, flags], ) diff --git a/tests/http/exceptions/test_http_proxy_auth_failed.py b/tests/http/exceptions/test_http_proxy_auth_failed.py index 9dfbe2142e..7b35ae740c 100644 --- a/tests/http/exceptions/test_http_proxy_auth_failed.py +++ b/tests/http/exceptions/test_http_proxy_auth_failed.py @@ -25,18 +25,23 @@ class TestHttpProxyAuthFailed(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup') self.mock_selector = mocker.patch('selectors.DefaultSelector') self.mock_server_conn = mocker.patch( 'proxy.http.proxy.server.TcpServerConnection', ) self.fileno = 10 + self.mock_socket_dup.side_effect = lambda fd: fd + self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize( ["--basic-auth", "user:pass"], threaded=True, ) - self._conn = self.mock_fromfd.return_value + self._conn = mocker.MagicMock() + self.mock_socket.side_effect = \ + lambda **kwargs: self._conn if kwargs.get('fileno') == self.fileno else mocker.DEFAULT self.protocol_handler = HttpProtocolHandler( HttpClientConnection(self._conn, self._addr), flags=self.flags, diff --git a/tests/http/proxy/test_http_proxy.py b/tests/http/proxy/test_http_proxy.py index 68f6bfd999..80ad7a9a38 100644 --- a/tests/http/proxy/test_http_proxy.py +++ b/tests/http/proxy/test_http_proxy.py @@ -30,9 +30,12 @@ def _setUp(self, mocker: MockerFixture) -> None: 'proxy.http.proxy.server.TcpServerConnection', ) self.mock_selector = mocker.patch('selectors.DefaultSelector') - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup') self.fileno = 10 + self.mock_socket_dup.side_effect = lambda fd: fd + self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize(threaded=True) self.plugin = mocker.MagicMock() @@ -40,7 +43,9 @@ def _setUp(self, mocker: MockerFixture) -> None: b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], b'HttpProxyBasePlugin': [self.plugin], } - self._conn = self.mock_fromfd.return_value + self._conn = mocker.MagicMock() + self.mock_socket.side_effect = \ + lambda **kwargs: self._conn if kwargs.get('fileno') == self.fileno else mocker.DEFAULT self.protocol_handler = HttpProtocolHandler( HttpClientConnection(self._conn, self._addr), flags=self.flags, diff --git a/tests/http/proxy/test_http_proxy_tls_interception.py b/tests/http/proxy/test_http_proxy_tls_interception.py index b3734ca23f..654bbc5fcd 100644 --- a/tests/http/proxy/test_http_proxy_tls_interception.py +++ b/tests/http/proxy/test_http_proxy_tls_interception.py @@ -39,7 +39,7 @@ async def test_e2e(self, mocker: MockerFixture) -> None: host, port = uuid.uuid4().hex, 443 netloc = '{0}:{1}'.format(host, port) - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket_dup = mocker.patch('socket.dup') self.mock_selector = mocker.patch('selectors.DefaultSelector') self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr') self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr') @@ -53,6 +53,9 @@ async def test_e2e(self, mocker: MockerFixture) -> None: self.mock_gen_csr.return_value = True self.mock_gen_public_key.return_value = True + self.fileno = 10 + self.mock_socket_dup.side_effect = lambda fd: fd + # Used for server side wrapping self.mock_ssl_context = mocker.patch('ssl.create_default_context') upstream_tls_sock = mock.MagicMock(spec=ssl.SSLSocket) @@ -82,7 +85,6 @@ def mock_connection() -> Any: type(self.mock_server_conn.return_value).closed = \ mock.PropertyMock(return_value=False) - self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize( ca_cert_file='ca-cert.pem', @@ -101,7 +103,7 @@ def mock_connection() -> Any: b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin], b'HttpProxyBasePlugin': [self.proxy_plugin], } - self._conn = self.mock_fromfd.return_value + self._conn = mock.MagicMock(spec=socket.socket) self.protocol_handler = HttpProtocolHandler( HttpClientConnection(self._conn, self._addr), flags=self.flags, diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index 0ba0b44284..62829695ae 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -52,12 +52,13 @@ class TestHttpProtocolHandlerWithoutServerMock(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.http_server_port = 65535 self.flags = FlagParser.initialize(threaded=True) @@ -88,7 +89,7 @@ async def test_proxy_connection_failed(self) -> None: @pytest.mark.asyncio # type: ignore[misc] async def test_proxy_authentication_failed(self) -> None: - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value mock_selector_for_client_read(self) flags = FlagParser.initialize( auth_code=base64.b64encode(b'user:pass'), @@ -147,7 +148,8 @@ class TestHttpProtocolHandler(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.mock_server_connection = mocker.patch( 'proxy.http.proxy.server.TcpServerConnection', @@ -155,7 +157,7 @@ def _setUp(self, mocker: MockerFixture) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.http_server_port = 65535 self.flags = FlagParser.initialize(threaded=True) @@ -311,7 +313,7 @@ def has_buffer() -> bool: @pytest.mark.asyncio # type: ignore[misc] async def test_authenticated_proxy_http_get(self) -> None: - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value mock_selector_for_client_read(self) server = self.mock_server_connection.return_value @@ -363,7 +365,7 @@ async def test_authenticated_proxy_http_tunnel(self) -> None: server = self.mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.mock_selector_for_client_read_and_server_write(server) flags = FlagParser.initialize( diff --git a/tests/http/web/test_web_server.py b/tests/http/web/test_web_server.py index e5bbabb39b..71baac4fd6 100644 --- a/tests/http/web/test_web_server.py +++ b/tests/http/web/test_web_server.py @@ -41,10 +41,11 @@ def test_on_client_connection_called_on_teardown(mocker: MockerFixture) -> None: plugin = mocker.MagicMock() - mock_fromfd = mocker.patch('socket.fromfd') + mock_socket_dup = mocker.patch('socket.dup') + mock_socket_dup.side_effect = lambda fd: fd flags = FlagParser.initialize(threaded=True) flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]} - _conn = mock_fromfd.return_value + _conn = mocker.MagicMock() _addr = ('127.0.0.1', 54382) protocol_handler = HttpProtocolHandler( HttpClientConnection(_conn, _addr), @@ -141,11 +142,12 @@ class TestWebServerPluginWithPacFilePlugin(Assertions): ], ) # type: ignore[misc] def _setUp(self, request: Any, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.pac_file = request.param if isinstance(self.pac_file, str): with open(self.pac_file, 'rb') as f: @@ -195,11 +197,12 @@ class TestStaticWebServerPlugin(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value # Setup a static directory self.static_server_dir = os.path.join(tempfile.gettempdir(), 'static') self.index_file_path = os.path.join( @@ -316,11 +319,12 @@ class TestWebServerPlugin(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.flags = FlagParser.initialize(threaded=True) self.flags.plugins = Plugins.load([ bytes_(PLUGIN_HTTP_PROXY), @@ -334,7 +338,7 @@ def _setUp(self, mocker: MockerFixture) -> None: @pytest.mark.asyncio # type: ignore[misc] async def test_default_web_server_returns_404(self) -> None: - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.mock_selector.return_value.select.return_value = [ ( selectors.SelectorKey( diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..232621f0b5 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" + proxy.py + ~~~~~~~~ + ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on + Network monitoring, controls & Application development, testing, debugging. + + :copyright: (c) 2013-present by Abhinav Singh and contributors. + :license: BSD, see LICENSE for more details. +""" diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index ca0012b6a6..1a14374423 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -11,6 +11,7 @@ Test the simplest proxy use scenario for smoke. """ import os +import sys import time import tempfile import subprocess @@ -18,7 +19,7 @@ from typing import Any, List, Generator from pathlib import Path from subprocess import Popen -from subprocess import check_output as _check_output +from subprocess import run as _run import pytest @@ -30,9 +31,9 @@ os.makedirs(CERT_DIR, exist_ok=True) -def check_output(args: List[Any]) -> bytes: # pragma: no cover +def run(args: List[Any], **kwargs: Any) -> None: args = args if not IS_WINDOWS else ['powershell'] + args - return _check_output(args) + _run(args, check=True, stderr=subprocess.STDOUT, **kwargs) def _https_server_flags() -> str: @@ -128,29 +129,34 @@ def _tls_interception_flags(ca_cert_suffix: str = '') -> str: @pytest.fixture(scope='session', autouse=not IS_WINDOWS) # type: ignore[misc] def _gen_https_certificates(request: Any) -> None: - check_output([ + run([ 'make', 'https-certificates', + '-e', 'PYTHON="%s"' % (sys.executable,), '-e', 'CERT_DIR=%s/' % (str(CERT_DIR)), ]) - check_output([ + run([ 'make', 'sign-https-certificates', + '-e', 'PYTHON="%s"' % (sys.executable,), '-e', 'CERT_DIR=%s/' % (str(CERT_DIR)), ]) @pytest.fixture(scope='session', autouse=not IS_WINDOWS) # type: ignore[misc] def _gen_ca_certificates(request: Any) -> None: - check_output([ + run([ 'make', 'ca-certificates', + '-e', 'PYTHON="%s"' % (sys.executable,), '-e', 'CERT_DIR=%s/' % (str(CERT_DIR)), ]) - check_output([ + run([ 'make', 'ca-certificates', + '-e', 'PYTHON="%s"' % (sys.executable,), '-e', 'CA_CERT_SUFFIX=-chunk', '-e', 'CERT_DIR=%s/' % (str(CERT_DIR)), ]) - check_output([ + run([ 'make', 'ca-certificates', + '-e', 'PYTHON="%s"' % (sys.executable,), '-e', 'CA_CERT_SUFFIX=-post', '-e', 'CERT_DIR=%s/' % (str(CERT_DIR)), ]) @@ -175,7 +181,7 @@ def proxy_py_subprocess(request: Any) -> Generator[int, None, None]: ca_cert_dir = TEMP_DIR / ('certificates-%s' % run_id) os.makedirs(ca_cert_dir, exist_ok=True) proxy_cmd = ( - 'python', '-m', 'proxy', + sys.executable, '-m', 'proxy', '--hostname', '127.0.0.1', '--port', '0', '--port-file', str(port_file), @@ -190,8 +196,12 @@ def proxy_py_subprocess(request: Any) -> Generator[int, None, None]: ) + tuple(request.param.split()) proxy_proc = Popen(proxy_cmd, stderr=subprocess.STDOUT) # Needed because port file might not be available immediately - while not port_file.exists(): + retries = 0 + while not port_file.exists() and retries < 8: time.sleep(1) + retries += 1 + if not port_file.exists(): + raise RuntimeError('proxy.py failed to boot up') try: yield int(port_file.read_text()) finally: @@ -218,7 +228,9 @@ def test_integration(proxy_py_subprocess: int) -> None: """An acceptance test using ``curl`` through proxy.py.""" this_test_module = Path(__file__) shell_script_test = this_test_module.with_suffix('.sh') - check_output([str(shell_script_test), str(proxy_py_subprocess)]) + print('shell_script_test %s' % shell_script_test) + print('proxy_py_subprocess %s' % proxy_py_subprocess) + run([str(shell_script_test), str(proxy_py_subprocess)], stdout=sys.stdout.buffer) @pytest.mark.smoke # type: ignore[misc] @@ -236,7 +248,7 @@ def test_https_integration(proxy_py_subprocess: int) -> None: this_test_module = Path(__file__) shell_script_test = this_test_module.with_suffix('.sh') # "1" means use-https scheme for requests to instance - check_output([str(shell_script_test), str(proxy_py_subprocess), '1']) + run([str(shell_script_test), str(proxy_py_subprocess), '1']) @pytest.mark.smoke # type: ignore[misc] @@ -252,7 +264,7 @@ def test_https_integration(proxy_py_subprocess: int) -> None: def test_integration_with_interception_flags(proxy_py_subprocess: int) -> None: """An acceptance test for TLS interception using ``curl`` through proxy.py.""" shell_script_test = Path(__file__).parent / 'test_interception.sh' - check_output([ + run([ str(shell_script_test), str(proxy_py_subprocess), str(CERT_DIR), @@ -273,7 +285,7 @@ def test_modify_chunk_response_integration(proxy_py_subprocess: int) -> None: """An acceptance test for :py:class:`~proxy.plugin.ModifyChunkResponsePlugin` interception using ``curl`` through proxy.py.""" shell_script_test = Path(__file__).parent / 'test_modify_chunk_response.sh' - check_output([ + run([ str(shell_script_test), str(proxy_py_subprocess), str(CERT_DIR), @@ -294,7 +306,7 @@ def test_modify_post_response_integration(proxy_py_subprocess: int) -> None: """An acceptance test for :py:class:`~proxy.plugin.ModifyPostDataPlugin` interception using ``curl`` through proxy.py.""" shell_script_test = Path(__file__).parent / 'test_modify_post_data.sh' - check_output([ + run([ str(shell_script_test), str(proxy_py_subprocess), str(CERT_DIR), diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index 259c8bb846..94c72cc540 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -40,7 +40,8 @@ class TestHttpProxyPluginExamples(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, request: Any, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.mock_server_conn = mocker.patch( 'proxy.http.proxy.server.TcpServerConnection', @@ -66,7 +67,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: b'HttpProtocolHandlerPlugin': [HttpProxyPlugin], b'HttpProxyBasePlugin': [plugin], } - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.protocol_handler = HttpProtocolHandler( HttpClientConnection(self._conn, self._addr), flags=self.flags, diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 64dd9440f6..3d8d6a28f4 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -35,7 +35,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, request: Any, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket_dup = mocker.patch('socket.dup') self.mock_selector = mocker.patch('selectors.DefaultSelector') self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr') self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr') @@ -53,6 +53,8 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: self.mock_gen_public_key.return_value = True self.fileno = 10 + self.mock_socket_dup.side_effect = lambda fd: fd + self._addr = ('127.0.0.1', 54382) self.flags = FlagParser.initialize( ca_cert_file='ca-cert.pem', @@ -69,7 +71,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: b'HttpProxyBasePlugin': [plugin], } self._conn = mocker.MagicMock(spec=socket.socket) - self.mock_fromfd.return_value = self._conn + self.protocol_handler = HttpProtocolHandler( HttpClientConnection(self._conn, self._addr), flags=self.flags, ) diff --git a/tests/socks/test_handler.py b/tests/socks/test_handler.py index cd5364954b..fb9ca1fcd9 100644 --- a/tests/socks/test_handler.py +++ b/tests/socks/test_handler.py @@ -21,12 +21,13 @@ class TestHttpProtocolHandlerWithoutServerMock(Assertions): @pytest.fixture(autouse=True) # type: ignore[misc] def _setUp(self, mocker: MockerFixture) -> None: - self.mock_fromfd = mocker.patch('socket.fromfd') + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) self.mock_selector = mocker.patch('selectors.DefaultSelector') self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self._conn = self.mock_fromfd.return_value + self._conn = self.mock_socket.return_value self.flags = FlagParser.initialize(threaded=True)