diff --git a/pyhsslms/hsslms.py b/pyhsslms/hsslms.py old mode 100644 new mode 100755 index 50b417c..4cbc5d4 --- a/pyhsslms/hsslms.py +++ b/pyhsslms/hsslms.py @@ -7,7 +7,7 @@ # in RFC 8554. # # -# Copyright (c) 2020, Vigil Security, LLC +# Copyright (c) 2020-2021, Vigil Security, LLC # All rights reserved. # # Redistribution and use, with or without modification, are permitted @@ -78,7 +78,7 @@ def usage(name): print(" -w LMOTS_TYPE, --lmots LMOTS_TYPE") print(" Winternitz number") print(" -a HASH_ALG, --alg HASH_ALG") - print(" Hash algorithm (only sha256)") + print(" Hash algorithm (sha256 or shake)") print(" ") print("optional command arguments:") print(" -h, --help") @@ -131,24 +131,56 @@ def main(): type=int, choices=[1, 2, 4, 8], metavar='LMOTS_TYPE', help='Winternitz number') parser.add_argument('-a', '--alg', dest='alg', default='sha256', - type=str, choices=['sha256'], - metavar='HASH_ALG', help='Hash algorithm (only sha256)') + type=str, choices=['sha256', 'shake'], + metavar='HASH_ALG', help='Hash algorithm (sha256 or shake)') + parser.add_argument('-t', '--trunc', dest='trunc', default='32', + type=str, choices=[32, 24], + metavar='TRUNC', help='Hash algorithm truncation size') args = parser.parse_args(sys.argv[3:]) - lms_dict = { - 5: pyhsslms.lms_sha256_m32_h5, - 10: pyhsslms.lms_sha256_m32_h10, - 15: pyhsslms.lms_sha256_m32_h15, - 20: pyhsslms.lms_sha256_m32_h20, - 25: pyhsslms.lms_sha256_m32_h25, } - lmots_dict = { - 1: pyhsslms.lmots_sha256_n32_w1, - 2: pyhsslms.lmots_sha256_n32_w2, - 4: pyhsslms.lmots_sha256_n32_w4, - 8: pyhsslms.lmots_sha256_n32_w8, } levels = args.levels - lms_type = lms_dict[args.lms] - lmots_type = lmots_dict[args.lmots] + if args.alg == 'sha256': + if args.trunc == 32: + if args.lms == 5: lms_type = pyhsslms.lms_sha256_m32_h5 + if args.lms == 10: lms_type = pyhsslms.lms_sha256_m32_h10 + if args.lms == 15: lms_type = pyhsslms.lms_sha256_m32_h15 + if args.lms == 20: lms_type = pyhsslms.lms_sha256_m32_h20 + if args.lms == 25: lms_type = pyhsslms.lms_sha256_m32_h25 + if args.lmots == 1: lmots_type = pyhsslms.lmots_sha256_n32_w1 + if args.lmots == 2: lmots_type = pyhsslms.lmots_sha256_n32_w2 + if args.lmots == 4: lmots_type = pyhsslms.lmots_sha256_n32_w4 + if args.lmots == 8: lmots_type = pyhsslms.lmots_sha256_n32_w8 + else: # args.trunc == 24 + if args.lms == 5: lms_type = pyhsslms.lms_sha256_m24_h5 + if args.lms == 10: lms_type = pyhsslms.lms_sha256_m24_h10 + if args.lms == 15: lms_type = pyhsslms.lms_sha256_m24_h15 + if args.lms == 20: lms_type = pyhsslms.lms_sha256_m24_h20 + if args.lms == 25: lms_type = pyhsslms.lms_sha256_m24_h25 + if args.lmots == 1: lmots_type = pyhsslms.lmots_sha256_n24_w1 + if args.lmots == 2: lmots_type = pyhsslms.lmots_sha256_n24_w2 + if args.lmots == 4: lmots_type = pyhsslms.lmots_sha256_n24_w4 + if args.lmots == 8: lmots_type = pyhsslms.lmots_sha256_n24_w8 + else: # args.alg == 'shake' + if args.trunc == 32: + if args.lms == 5: lms_type = pyhsslms.lms_shake_m32_h5 + if args.lms == 10: lms_type = pyhsslms.lms_shake_m32_h10 + if args.lms == 15: lms_type = pyhsslms.lms_shake_m32_h15 + if args.lms == 20: lms_type = pyhsslms.lms_shake_m32_h20 + if args.lms == 25: lms_type = pyhsslms.lms_shake_m32_h25 + if args.lmots == 1: lmots_type = pyhsslms.lmots_shake_n32_w1 + if args.lmots == 2: lmots_type = pyhsslms.lmots_shake_n32_w2 + if args.lmots == 4: lmots_type = pyhsslms.lmots_shake_n32_w4 + if args.lmots == 8: lmots_type = pyhsslms.lmots_shake_n32_w8 + else: # args.trunc == 24 + if args.lms == 5: lms_type = pyhsslms.lms_shake_m24_h5 + if args.lms == 10: lms_type = pyhsslms.lms_shake_m24_h10 + if args.lms == 15: lms_type = pyhsslms.lms_shake_m24_h15 + if args.lms == 20: lms_type = pyhsslms.lms_shake_m24_h20 + if args.lms == 25: lms_type = pyhsslms.lms_shake_m24_h25 + if args.lmots == 1: lmots_type = pyhsslms.lmots_shake_n24_w1 + if args.lmots == 2: lmots_type = pyhsslms.lmots_shake_n24_w2 + if args.lmots == 4: lmots_type = pyhsslms.lmots_shake_n24_w4 + if args.lmots == 8: lmots_type = pyhsslms.lmots_shake_n24_w8 pyhsslms.HssLmsPrivateKey.genkey(keyname, levels=levels, lms_type=lms_type, lmots_type=lmots_type)