From b79fb589d38d1ac339c204ab9e2237764a4f3aa1 Mon Sep 17 00:00:00 2001 From: Philipp korber Date: Fri, 2 Sep 2016 23:37:03 +0200 Subject: [PATCH] Fixes #1099, use `.rsplit(..., 1)` instead of split+join, so that names where the separator appears more than once are handled correctly (also add tests for `split_suffix()`) --- blocks/bricks/recurrent/misc.py | 4 +-- tests/bricks/test_recurrent.py | 62 +++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/blocks/bricks/recurrent/misc.py b/blocks/bricks/recurrent/misc.py index 2b1b80c3..c9a53944 100644 --- a/blocks/bricks/recurrent/misc.py +++ b/blocks/bricks/recurrent/misc.py @@ -186,9 +186,9 @@ def suffixes(names, level): @staticmethod def split_suffix(name): # Target name with suffix to the correct layer - name_level = name.split(RECURRENTSTACK_SEPARATOR) + name_level = name.rsplit(RECURRENTSTACK_SEPARATOR, 1) if len(name_level) == 2 and name_level[-1].isdigit(): - name = RECURRENTSTACK_SEPARATOR.join(name_level[:-1]) + name = name_level[0] level = int(name_level[-1]) else: # It must be from bottom layer diff --git a/tests/bricks/test_recurrent.py b/tests/bricks/test_recurrent.py index dad2b15d..0cf3f4ad 100644 --- a/tests/bricks/test_recurrent.py +++ b/tests/bricks/test_recurrent.py @@ -435,6 +435,68 @@ def test_many_steps(self): self.do_many_steps(self.stack2, skip_connections=True) self.do_many_steps(self.stack2, skip_connections=True, low_memory=True) +class TestRecurrentStackHelperMethodes(unittest.TestCase): + # Separated from TestRecurrentStack because it doesn't depend on setUp + # and covers a different area than the other tests in TestRecurrentStack + + def test_suffix(self): + # level >= 0 !! + level1, = numpy.random.randint(1, 150, size=(1,)) + # name1 != "mask" !! 
+ name1 = "somepart" + + test_cases = [ + ("mask", level1, "mask"), + ("{name}", 0, "{name}"), + ("{name}", level1, "{name}{sep}{level}") + ] + + for _name, level, _expected_result in test_cases: + name = _name.format(name=name1, level=level1, sep=RECURRENTSTACK_SEPARATOR) + expected_result = _expected_result.format(name=name1, level=level1, sep=RECURRENTSTACK_SEPARATOR) + + resut = RecurrentStack.suffix(name, level) + + assert resut == expected_result, "expected suffix(\"{}\",{}) -> \"{}\" got \"{}\"".format(name, level, + expected_result, + resut) + + def test_split_suffix(self): + # generate some numbers + level1, level2 = numpy.random.randint(1, 150, size=(2,)) + name1 = "somepart" + + # test cases like (name_template, expected_name_part_template, expected_level) + # name, level, level2 and sep will be provided + test_cases = [ + # case layer == 0 + ("{name}", "{name}", 0), + # case empty name part + ("{sep}{level}", "", level1), + # normal case + ("{name}{sep}{level}","{name}",level1), + # case nested recurrent stacks + ("{name}{sep}{level}{sep}{level2}","{name}{sep}{level}", level2), + # some more edge cases... 
+ ("{sep}{name}{sep}{level}", "{sep}{name}", level1), + ("{name}{sep}","{name}{sep}", 0), + ("{name}{sep}{name}","{name}{sep}{name}", 0), + ("{name}{sep}{level}{sep}{name}", "{name}{sep}{level}{sep}{name}", 0) + ] + + # check all test cases + for _name, _expected_name_part, expected_level in test_cases: + # fill in actual details like the current RECURRENTSTACK_SEPARATOR + name = _name.format(name=name1, level=level1, level2=level2, sep=RECURRENTSTACK_SEPARATOR) + expected_name_part = _expected_name_part.format(name=name1, level=level1, level2=level2, + sep=RECURRENTSTACK_SEPARATOR) + + name_part, level = RecurrentStack.split_suffix(name) + + assert name_part == expected_name_part and level == expected_level, \ + "expected split_suffex(\"{}\") -> name(\"{}\"), level({}) got name(\"{}\"), level({})".format( + name, expected_name_part, expected_level, name_part, level) + class TestGatedRecurrent(unittest.TestCase): def setUp(self):