diff --git a/skygear/registry.py b/skygear/registry.py index a977d1e..322e9b8 100644 --- a/skygear/registry.py +++ b/skygear/registry.py @@ -44,8 +44,8 @@ def __init__(self): 'op': {}, 'hook': {}, 'timer': {}, - 'event': {}, } + self.event_map = defaultdict(list) self.param_map = { 'op': [], 'handler': [], @@ -123,12 +123,7 @@ def register_hook(self, name, func, *args, **kwargs): log.debug("Registered hook '%s' to skygear!", name) def register_event(self, name, func, *args, **kwargs): - event_funcs = self.func_map['event'].get(name) - if event_funcs is None: - event_funcs = [] - self.func_map['event'][name] = event_funcs - - event_funcs.append(func) + self.event_map[name].append(func) self._add_param('event', { 'name': name @@ -214,6 +209,9 @@ def func_list(self): def get_func(self, kind, name): return self.func_map[kind][name] + def get_event_funcs(self, name): + return self.event_map[name] + def get_provider(self, name): return self.providers[name] diff --git a/skygear/tests/test_registry.py b/skygear/tests/test_registry.py index 7b69833..b5f6f28 100644 --- a/skygear/tests/test_registry.py +++ b/skygear/tests/test_registry.py @@ -154,9 +154,9 @@ def fn(): registry = Registry() registry.register_event('plugin:event:foo', fn) - func_map = registry.func_map['event'] + func_map = registry.event_map assert len(func_map) == 1 - assert func_map['plugin:event:foo'] == fn + assert func_map['plugin:event:foo'] == [fn] param_map = registry.param_map['event'] assert len(param_map) == 1 diff --git a/skygear/transmitter/common.py b/skygear/transmitter/common.py index 1d80734..87c6388 100644 --- a/skygear/transmitter/common.py +++ b/skygear/transmitter/common.py @@ -109,16 +109,17 @@ def call_func(self, ctx, kind, name, param): @_wrap_result def call_event_func(self, name, param): try: - event_funcs = self._registry.get_func('event', name) + event_funcs = self._registry.get_event_funcs(name) except KeyError: log.warning('Missing event func named "{}"'.format(name)) - - results = [self.event(event_func, param) for event_func in event_funcs] + return if name == 'init': - return results[0] - - return results + # Only init event support returning data + return self.event(event_funcs[0], param) + else: + for event_func in event_funcs: + self.event(event_func, param) @_wrap_result def call_provider(self, ctx, name, action, param): diff --git a/skygear/transmitter/tests/test_common.py b/skygear/transmitter/tests/test_common.py index 62bbeae..d65da48 100644 --- a/skygear/transmitter/tests/test_common.py +++ b/skygear/transmitter/tests/test_common.py @@ -102,6 +102,16 @@ def testCallFuncSkygearException(self, mocker): assert result['error']['message'] == 'Error occurred' assert result['error']['code'] == 1 + @patch('skygear.registry.Registry.get_event_funcs') + def testCallEventFunc(self, mocker): + mock1 = MagicMock() + mock2 = MagicMock() + mocker.return_value = [mock1, mock2] + self.transport.call_event_func('some-event', {}) + mocker.assert_called_once_with('some-event') + mock1.assert_called_once_with() + mock2.assert_called_once_with() + def testOpDictArg(self): mock = MagicMock(return_value={'result': 'OK'}) self.transport.op(mock, dict(named='value')) @@ -180,6 +190,11 @@ def testProvider(self): self.transport.provider(mock, 'action', {'data': 'hello'}) mock.handle_action.assert_called_with('action', {'data': 'hello'}) + def testEvent(self): + mock = MagicMock() + self.transport.event(mock, {}) + mock.assert_called_once_with() + class TestBase64Encoding(unittest.TestCase): @patch.dict(os.environ, {'SKYGEAR_CONTEXT': 'e30='})