Skip to content

Commit

Permalink
Support plugins defined as inner classes (#1318)
Browse files Browse the repository at this point in the history
* Support plugins defined as inner classes

* Prefer __qualname__ over __name__ for classes

---------

Co-authored-by: Abhinav Singh <[email protected]>
  • Loading branch information
alexey-pelykh and abhinavsingh authored Mar 14, 2023
1 parent 93f6fd6 commit f3d19ff
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 35 deletions.
1 change: 1 addition & 0 deletions docs/changelog-fragments.d/1318.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support plugins defined as inner classes
52 changes: 38 additions & 14 deletions proxy/common/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import importlib
import itertools
from types import ModuleType
from typing import Any, Dict, List, Tuple, Union, Optional

from .utils import text_, bytes_
Expand Down Expand Up @@ -75,31 +76,54 @@ def load(
# this plugin_ is implementing
base_klass = None
for k in mro:
if bytes_(k.__name__) in p:
if bytes_(k.__qualname__) in p:
base_klass = k
break
if base_klass is None:
raise ValueError('%s is NOT a valid plugin' % text_(plugin_))
if klass not in p[bytes_(base_klass.__name__)]:
p[bytes_(base_klass.__name__)].append(klass)
logger.info('Loaded plugin %s.%s', module_name, klass.__name__)
if klass not in p[bytes_(base_klass.__qualname__)]:
p[bytes_(base_klass.__qualname__)].append(klass)
logger.info('Loaded plugin %s.%s', module_name, klass.__qualname__)
# print(p)
return p

@staticmethod
def importer(plugin: Union[bytes, type]) -> Tuple[type, str]:
"""Import and returns the plugin."""
if isinstance(plugin, type):
return (plugin, '__main__')
if inspect.isclass(plugin):
return (plugin, plugin.__module__ or '__main__')
raise ValueError('%s is not a valid reference to a plugin class' % text_(plugin))
plugin_ = text_(plugin.strip())
assert plugin_ != ''
module_name, klass_name = plugin_.rsplit(text_(DOT), 1)
klass = getattr(
importlib.import_module(
module_name.replace(
os.path.sep, text_(DOT),
),
),
klass_name,
)
path = plugin_.split(text_(DOT))
klass = None

def locate_klass(klass_module_name: str, klass_path: List[str]) -> Union[type, None]:
klass_module_name = klass_module_name.replace(os.path.sep, text_(DOT))
try:
klass_module = importlib.import_module(klass_module_name)
except ModuleNotFoundError:
return None
klass_container: Union[ModuleType, type] = klass_module
for klass_path_part in klass_path:
try:
klass_container = getattr(klass_container, klass_path_part)
except AttributeError:
return None
if not isinstance(klass_container, type) or not inspect.isclass(klass_container):
return None
return klass_container

module_name = None
for module_name_parts in range(len(path) - 1, 0, -1):
module_name = '.'.join(path[0:module_name_parts])
klass = locate_klass(module_name, path[module_name_parts:])
if klass:
break
if klass is None:
module_name = '__main__'
klass = locate_klass(module_name, path)
if klass is None or module_name is None:
raise ValueError('%s is not resolvable as a plugin class' % text_(plugin))
return (klass, module_name)
2 changes: 1 addition & 1 deletion proxy/core/acceptor/acceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _work(self, conn: socket.socket, addr: Optional[HostPort]) -> None:
conn,
addr,
event_queue=self.event_queue,
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
# TODO: Move me into target method
logger.debug( # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion proxy/core/work/fd/fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def work(self, *args: Any) -> None:
self.works[fileno].publish_event(
event_name=eventNames.WORK_STARTED,
event_payload={'fileno': fileno, 'addr': addr},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
try:
self.works[fileno].initialize()
Expand Down
2 changes: 1 addition & 1 deletion proxy/core/work/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def shutdown(self) -> None:
self.publish_event(
event_name=eventNames.WORK_FINISHED,
event_payload={},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def run(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/exception/http_request_rejected.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.reason: Optional[bytes] = reason
self.headers: Optional[Dict[bytes, bytes]] = headers
self.body: Optional[bytes] = body
klass_name = self.__class__.__name__
klass_name = self.__class__.__qualname__
super().__init__(
message='%s %r' % (klass_name, reason)
if reason
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/exception/proxy_auth_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ProxyAuthenticationFailed(HttpProtocolException):
incoming request doesn't present necessary credentials."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(self.__class__.__name__, **kwargs)
super().__init__(self.__class__.__qualname__, **kwargs)

def response(self, _request: 'HttpParser') -> memoryview:
return PROXY_AUTH_FAILED_RESPONSE_PKT
2 changes: 1 addition & 1 deletion proxy/http/exception/proxy_conn_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, host: str, port: int, reason: str, **kwargs: Any):
self.host: str = host
self.port: int = port
self.reason: str = reason
super().__init__('%s %s' % (self.__class__.__name__, reason), **kwargs)
super().__init__('%s %s' % (self.__class__.__qualname__, reason), **kwargs)

def response(self, _request: 'HttpParser') -> memoryview:
return BAD_GATEWAY_RESPONSE_PKT
2 changes: 1 addition & 1 deletion proxy/http/proxy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def name(self) -> str:
Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__ # pragma: no cover
return self.__class__.__qualname__ # pragma: no cover

def resolve_dns(self, host: str, port: int) -> Tuple[Optional[str], Optional['HostPort']]:
"""Resolve upstream server host to an IP address.
Expand Down
8 changes: 4 additions & 4 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def emit_request_complete(self) -> None:
if self.request.method == httpMethods.POST
else None,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_events(self, chunk_size: int) -> None:
Expand Down Expand Up @@ -911,7 +911,7 @@ def emit_response_headers_complete(self) -> None:
for k, v in self.response.headers.items()
},
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_chunk_received(self, chunk_size: int) -> None:
Expand All @@ -925,7 +925,7 @@ def emit_response_chunk_received(self, chunk_size: int) -> None:
'chunk_size': chunk_size,
'encoded_chunk_size': chunk_size,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

def emit_response_complete(self) -> None:
Expand All @@ -938,7 +938,7 @@ def emit_response_complete(self) -> None:
event_payload={
'encoded_response_size': self.response.total_size,
},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)

#
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def name(self) -> str:
Defaults to name of the class. This helps plugin developers to directly
access a specific plugin by its name."""
return self.__class__.__name__ # pragma: no cover
return self.__class__.__qualname__ # pragma: no cover

@abstractmethod
def routes(self) -> List[Tuple[int, str]]:
Expand Down
25 changes: 25 additions & 0 deletions tests/common/my_plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Any

from proxy.http.proxy import HttpProxyPlugin


class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class OuterClass:

class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
39 changes: 38 additions & 1 deletion tests/common/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from typing import Dict, List
from typing import Any, Dict, List

import unittest
from unittest import mock
Expand All @@ -19,6 +19,7 @@
from proxy.common.utils import bytes_
from proxy.common.version import __version__
from proxy.common.constants import PLUGIN_HTTP_PROXY, PY2_DEPRECATION_MESSAGE
from . import my_plugins


class TestFlags(unittest.TestCase):
Expand Down Expand Up @@ -140,6 +141,42 @@ def test_unique_plugin_from_class(self) -> None:
],
})

def test_plugin_from_inner_class_by_type(self) -> None:
self.flags = FlagParser.initialize(
[], plugins=[
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
)
self.assert_plugins({
'HttpProtocolHandlerPlugin': [
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
})

def test_plugin_from_inner_class_by_name(self) -> None:
self.flags = FlagParser.initialize(
[], plugins=[
b'tests.common.test_flags.TestFlags.MyHttpProxyPlugin',
b'tests.common.my_plugins.MyHttpProxyPlugin',
b'tests.common.my_plugins.OuterClass.MyHttpProxyPlugin',
],
)
self.assert_plugins({
'HttpProtocolHandlerPlugin': [
TestFlags.MyHttpProxyPlugin,
my_plugins.MyHttpProxyPlugin,
my_plugins.OuterClass.MyHttpProxyPlugin,
],
})

class MyHttpProxyPlugin(HttpProxyPlugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def test_basic_auth_flag_is_base64_encoded(self) -> None:
flags = FlagParser.initialize(['--basic-auth', 'user:pass'])
self.assertEqual(flags.auth_code, b'dXNlcjpwYXNz')
Expand Down
8 changes: 4 additions & 4 deletions tests/core/test_event_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_empties_queue(self) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
with self.assertRaises(queue.Empty):
Expand All @@ -64,7 +64,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
# consume
self.dispatcher.run_once()
Expand All @@ -79,7 +79,7 @@ def subscribe(self, mock_time: mock.Mock) -> connection.Connection:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)
return relay_recv
Expand All @@ -101,7 +101,7 @@ def test_unsubscribe(self) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
with self.assertRaises(EOFError):
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.assertEqual(
evq.queue.get(), {
Expand All @@ -44,7 +44,7 @@ def test_publish(self, mock_time: mock.Mock) -> None:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)

Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_event_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_event_subscriber(self, mock_time: mock.Mock) -> None:
request_id='1234',
event_name=eventNames.WORK_STARTED,
event_payload={'hello': 'events'},
publisher_id=self.__class__.__name__,
publisher_id=self.__class__.__qualname__,
)
self.dispatcher.run_once()
self.subscriber.unsubscribe()
Expand All @@ -69,6 +69,6 @@ def callback(self, ev: Dict[str, Any]) -> None:
'event_timestamp': 1234567,
'event_name': eventNames.WORK_STARTED,
'event_payload': {'hello': 'events'},
'publisher_id': self.__class__.__name__,
'publisher_id': self.__class__.__qualname__,
},
)

0 comments on commit f3d19ff

Please sign in to comment.