Skip to content

Commit

Permalink
Merge pull request #1143 from dathinab/fix_bug_1099_correct_split_suffix
Browse files Browse the repository at this point in the history
Fixes #1099, use `.rsplit(..., 1)`
  • Loading branch information
rizar authored Sep 4, 2016
2 parents 46c03f6 + b79fb58 commit e1fedb0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
4 changes: 2 additions & 2 deletions blocks/bricks/recurrent/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ def suffixes(names, level):
@staticmethod
def split_suffix(name):
# Target name with suffix to the correct layer
name_level = name.split(RECURRENTSTACK_SEPARATOR)
name_level = name.rsplit(RECURRENTSTACK_SEPARATOR, 1)
if len(name_level) == 2 and name_level[-1].isdigit():
name = RECURRENTSTACK_SEPARATOR.join(name_level[:-1])
name = name_level[0]
level = int(name_level[-1])
else:
# It must be from bottom layer
Expand Down
62 changes: 62 additions & 0 deletions tests/bricks/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,68 @@ def test_many_steps(self):
self.do_many_steps(self.stack2, skip_connections=True)
self.do_many_steps(self.stack2, skip_connections=True, low_memory=True)

class TestRecurrentStackHelperMethodes(unittest.TestCase):
# Separated from TestRecurrentStack because it doesn't depend on setUp
# and covers a different area then the other tests in TestRecurrentStack

def test_suffix(self):
# level >= 0 !!
level1, = numpy.random.randint(1, 150, size=(1,))
# name1 != "mask" !!
name1 = "somepart"

test_cases = [
("mask", level1, "mask"),
("{name}", 0, "{name}"),
("{name}", level1, "{name}{sep}{level}")
]

for _name, level, _expected_result in test_cases:
name = _name.format(name=name1, level=level1, sep=RECURRENTSTACK_SEPARATOR)
expected_result = _expected_result.format(name=name1, level=level1, sep=RECURRENTSTACK_SEPARATOR)

resut = RecurrentStack.suffix(name, level)

assert resut == expected_result, "expected suffix(\"{}\",{}) -> \"{}\" got \"{}\"".format(name, level,
expected_result,
resut)

def test_split_suffix(self):
# generate some numbers
level1, level2 = numpy.random.randint(1, 150, size=(2,))
name1 = "somepart"

# test cases like (<given_name>, <expected_name>, <expected_level>)
# name, level, level2 and sep will be provided
test_cases = [
# case layer == 0
("{name}", "{name}", 0),
# case empty name part
("{sep}{level}", "", level1),
# normal case
("{name}{sep}{level}","{name}",level1),
# case nested recurrent stacks
("{name}{sep}{level}{sep}{level2}","{name}{sep}{level}", level2),
# some more edge cases...
("{sep}{name}{sep}{level}", "{sep}{name}", level1),
("{name}{sep}","{name}{sep}", 0),
("{name}{sep}{name}","{name}{sep}{name}", 0),
("{name}{sep}{level}{sep}{name}", "{name}{sep}{level}{sep}{name}", 0)
]

# check all test cases
for _name, _expected_name_part, expected_level in test_cases:
# fill in aktual details like the currend RECURRENTSTACK_SEPARATOR
name = _name.format(name=name1, level=level1, level2=level2, sep=RECURRENTSTACK_SEPARATOR)
expected_name_part = _expected_name_part.format(name=name1, level=level1, level2=level2,
sep=RECURRENTSTACK_SEPARATOR)

name_part, level = RecurrentStack.split_suffix(name)

assert name_part == expected_name_part and level == expected_level, \
"expected split_suffex(\"{}\") -> name(\"{}\"), level({}) got name(\"{}\"), level({})".format(
name, expected_name_part, expected_level, name_part, level)


class TestGatedRecurrent(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit e1fedb0

Please sign in to comment.