Skip to content

Commit

Permalink
Register multiple event handlers
Browse files Browse the repository at this point in the history
Refs #209.
  • Loading branch information
Steven-Chan authored Jun 1, 2018
2 parents 9d8a6a1 + a4c8e5a commit feccd57
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
10 changes: 5 additions & 5 deletions skygear/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self):
'op': {},
'hook': {},
'timer': {},
'event': {},
}
self.event_map = defaultdict(list)
self.param_map = {
'op': [],
'handler': [],
Expand Down Expand Up @@ -126,11 +126,8 @@ def register_hook(self, name, func, *args, **kwargs):
log.debug("Registered hook '%s' to skygear!", name)

def register_event(self, name, func, *args, **kwargs):
if name in self.func_map['event']:
log.warning("Replacing previously registered event handler '%s'",
name)
self.event_map[name].append(func)

self.func_map['event'][name] = func
self._add_param('event', {
'name': name
})
Expand Down Expand Up @@ -215,6 +212,9 @@ def func_list(self):
def get_func(self, kind, name):
return self.func_map[kind][name]

def get_event_funcs(self, name):
return self.event_map[name]

def get_provider(self, name):
return self.providers[name]

Expand Down
4 changes: 2 additions & 2 deletions skygear/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def fn():
registry = Registry()
registry.register_event('plugin:event:foo', fn)

func_map = registry.func_map['event']
func_map = registry.event_map
assert len(func_map) == 1
assert func_map['plugin:event:foo'] == fn
assert func_map['plugin:event:foo'] == [fn]

param_map = registry.param_map['event']
assert len(param_map) == 1
Expand Down
13 changes: 10 additions & 3 deletions skygear/transmitter/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,17 @@ def call_func(self, ctx, kind, name, param):
@_wrap_result
def call_event_func(self, name, param):
try:
event_func = self._registry.get_func('event', name)
return self.event(event_func, param)
except KeyError as e:
event_funcs = self._registry.get_event_funcs(name)
except KeyError:
log.warning('Missing event func named "{}"'.format(name))
return

if name == 'init':
# Only init event support returning data
return self.event(event_funcs[0], param)
else:
for event_func in event_funcs:
self.event(event_func, param)

@_wrap_result
def call_provider(self, ctx, name, action, param):
Expand Down
15 changes: 15 additions & 0 deletions skygear/transmitter/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ def testCallFuncSkygearException(self, mocker):
assert result['error']['message'] == 'Error occurred'
assert result['error']['code'] == 1

@patch('skygear.registry.Registry.get_event_funcs')
def testCallEventFunc(self, mocker):
mock1 = MagicMock()
mock2 = MagicMock()
mocker.return_value = [mock1, mock2]
self.transport.call_event_func('some-event', {})
mocker.assert_called_once_with('some-event')
mock1.assert_called_once_with()
mock2.assert_called_once_with()

def testOpDictArg(self):
mock = MagicMock(return_value={'result': 'OK'})
self.transport.op(mock, dict(named='value'))
Expand Down Expand Up @@ -180,6 +190,11 @@ def testProvider(self):
self.transport.provider(mock, 'action', {'data': 'hello'})
mock.handle_action.assert_called_with('action', {'data': 'hello'})

def testEvent(self):
mock = MagicMock()
self.transport.event(mock, {})
mock.assert_called_once_with()


class TestBase64Encoding(unittest.TestCase):
@patch.dict(os.environ, {'SKYGEAR_CONTEXT': 'e30='})
Expand Down

0 comments on commit feccd57

Please sign in to comment.