From 7b621f9d2528b249e9a67c23054756c3f7bc2b2b Mon Sep 17 00:00:00 2001 From: Hartmut Goebel Date: Mon, 29 Apr 2019 12:02:07 +0200 Subject: [PATCH] Add missing test-cases for option -D/--delimiter-special. Also move insert_special_delimiter into a function so it can be tested. --- diceware/__init__.py | 24 +++++++++++++++++++----- tests/test_diceware.py | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/diceware/__init__.py b/diceware/__init__.py index e9002b1..685e880 100644 --- a/diceware/__init__.py +++ b/diceware/__init__.py @@ -163,6 +163,24 @@ def insert_special_char(word, specials=SPECIAL_CHARS, rnd=None): char_list[rnd.choice(range(len(char_list)))] = rnd.choice(specials) return ''.join(char_list) +def insert_special_delimiter(words, max_delimiter_chars, + specials=SPECIAL_CHARS, rnd=None): + """Insert a char out of `specials` into `word`. + + `rnd`, if passed in, will be used as a (pseudo) random number + generator. We use `.choice()` only. + + Returns the modified word. + """ + if rnd is None: + rnd = SystemRandom() + words = words[:] + lengths = list(range(1, max_delimiter_chars + 1)) + for pos in range(len(words)-1, 0, -1): + num_chars = rnd.choice(lengths) # choose number of chars to insert + deli = "".join(rnd.choice(specials) for j in range(num_chars)) + words.insert(pos, deli) + return words def get_passphrase(options=None): """Get a diceware passphrase. @@ -194,11 +212,7 @@ def get_passphrase(options=None): if options.caps: words = [x.capitalize() for x in words] if options.delimiter_special: - lengths = list(range(1, options.delimiter_special + 1)) - for pos in range(len(words)-1, 0, -1): - l = rnd.choice(lengths) - deli = "".join(rnd.choice(SPECIAL_CHARS) for j in range(l)) - words.insert(pos, deli) + words = insert_special_delimiter(words, options.delimiter_special) result = "".join(words) else: result = options.delimiter.join(words) diff --git a/tests/test_diceware.py b/tests/test_diceware.py index 319e919..1299bfe 100644 --- a/tests/test_diceware.py +++ b/tests/test_diceware.py @@ -9,7 +9,7 @@ from diceware import ( get_wordlists_dir, SPECIAL_CHARS, insert_special_char, get_passphrase, handle_options, main, __version__, print_version, get_random_sources, - get_wordlist_names + get_wordlist_names, insert_special_delimiter ) @@ -86,6 +86,15 @@ def test_handle_options_delimiter(self): options = handle_options(['-d', 'WOW']) assert options.delimiter == 'WOW' + def test_handle_options_delimiter_special(self): + # we can set number of special characters to be used as delimiter + options = handle_options([]) + assert options.delimiter_special == 0 + options = handle_options(['-D', '3']) + assert options.delimiter_special == 3 + options = handle_options(['--delimiter-special', '1']) + assert options.delimiter_special == 1 + def test_handle_options_randomsource(self): # we can choose the source of randomness source_names = get_random_sources().keys() @@ -244,6 +253,37 @@ def test_get_passphrase_delimiters(self): phrase = get_passphrase(options) assert " " in phrase + def test_get_passphrase_special_delimiter(self): + # delimiter_special overrules delemiter + options = handle_options(args=[]) + options.delimiter = " " + options.delimiter_special = 2 + phrase = get_passphrase(options) + assert " " not in phrase + + def test_insert_special_delimiter(self): + # we can insert special chars between the words. + fake_rnd = FakeRandom() + fake_rnd.nums_to_draw = [1, 2, 1, # (num of chars)-1, char-idx + 0, 1, + 2, 1, 2, 0] + words_in = ['aaa', 'bbb', 'ccc', 'ddd'] + result1 = insert_special_delimiter(words_in, 3, + specials='!$&', rnd=fake_rnd) + assert result1 == ['aaa', '$&!', 'bbb', '$', 'ccc', '&$', 'ddd'] + assert words_in == ['aaa', 'bbb', 'ccc', 'ddd'] # unchanges + + def test_insert_special_delimiter_defaults(self): + # defaults are respected + words_in = ['aaa', 'bbb'] + result1 = insert_special_delimiter(words_in, 2) + assert result1[0] == 'aaa' + assert result1[2] == 'bbb' + assert 1 <= len(result1[1]) <= 2 + assert result1[1][0] in SPECIAL_CHARS + if len(result1[1]) == 2: + result1[1][1] in SPECIAL_CHARS + def test_print_version(self, capsys): # we can print version infos print_version()