From 7b8ed0128303dcb79d5998e7c173df8784037995 Mon Sep 17 00:00:00 2001 From: Cristian G Date: Wed, 5 Jul 2023 08:59:59 -0400 Subject: [PATCH] fix unit test --- .../atala/prism/apollo/derivation/HDKey.kt | 28 +++++-- .../prism/apollo/derivation/HDKeyTest.kt | 74 ++++++++++++++++--- 2 files changed, 85 insertions(+), 17 deletions(-) diff --git a/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/derivation/HDKey.kt b/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/derivation/HDKey.kt index a442cbcf4..c80ee263d 100644 --- a/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/derivation/HDKey.kt +++ b/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/derivation/HDKey.kt @@ -8,6 +8,9 @@ import io.iohk.atala.prism.apollo.utils.ECConfig import io.iohk.atala.prism.apollo.utils.ECPrivateKeyDecodingException import io.iohk.atala.prism.apollo.utils.KMMECSecp256k1PrivateKey +/** + * Represents and HDKey with its derive methods + */ class HDKey( val privateKey: ByteArray? = null, val publicKey: ByteArray? = null, @@ -18,11 +21,20 @@ class HDKey( constructor(seed: ByteArray, depth: Int, childIndex: BigInteger) : this( privateKey = seed.sliceArray(IntRange(0, 31)), - chainCode = seed.sliceArray(listOf(32)), + chainCode = seed.sliceArray(32 until seed.size), depth = depth, childIndex = childIndex - ) + ) { + require(seed.size == 64) { + "Seed expected byte length to be ${ECConfig.PRIVATE_KEY_BYTE_SIZE}" + } + } + /** + * Method to derive an HDKey by a path + * + * @param path value used to derive a key + */ fun derive(path: String): HDKey { if (!path.matches(Regex("^[mM].*"))) { throw Error("Path must start with \"m\" or \"M\"") @@ -37,7 +49,6 @@ class HDKey( if (m == null || m.size != 3) { throw Error("Invalid child index: $c") } - // TODO: Null check?? val idx = m[1].toBigInteger() if (idx >= HARDENED_OFFSET) { throw Error("Invalid index") @@ -48,9 +59,14 @@ class HDKey( return child } + /** + * Method to derive an HDKey child by index + * + * @param index value used to derive a key + */ fun deriveChild(index: BigInteger): HDKey { if (chainCode == null) { - throw Error("No chainCode set") + throw Exception("No chainCode set") } val data = if (index >= HARDENED_OFFSET) { val priv = privateKey ?: throw Error("Could not derive hardened child key") @@ -93,11 +109,11 @@ class HDKey( fun getKMMSecp256k1PrivateKey(): KMMECSecp256k1PrivateKey { privateKey?.let { return KMMECSecp256k1PrivateKey.secp256k1FromBytes(privateKey) - } ?: throw Exception("") + } ?: throw Exception("Private key not available") } private fun isValidPrivateKey(data: ByteArray): Boolean { - return (data.size != ECConfig.PRIVATE_KEY_BYTE_SIZE) + return (data.size == ECConfig.PRIVATE_KEY_BYTE_SIZE) } companion object { diff --git a/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/derivation/HDKeyTest.kt b/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/derivation/HDKeyTest.kt index 9b484dae4..8ff58fe6e 100644 --- a/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/derivation/HDKeyTest.kt +++ b/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/derivation/HDKeyTest.kt @@ -1,14 +1,28 @@ package io.iohk.atala.prism.apollo.derivation import com.ionspin.kotlin.bignum.integer.BigInteger +import com.ionspin.kotlin.bignum.integer.toBigInteger +import io.iohk.atala.prism.apollo.derivation.HDKey.Companion.HARDENED_OFFSET +import io.iohk.atala.prism.apollo.utils.KMMECSecp256k1KeyPair import kotlin.random.Random -import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertNotNull class HDKeyTest { + + @Test + fun testConstructor_whenSeedIncorrectLength_thenThrowException() { + val seed = Random.Default.nextBytes(32) + val depth = 1 + val childIndex = BigInteger(HARDENED_OFFSET) + + assertFailsWith(IllegalArgumentException::class) { + HDKey(seed, depth, childIndex) + } + } + @Test fun testConstructorWithSeed_thenNonNullValues() { val seed = Random.Default.nextBytes(64) @@ -24,7 +38,7 @@ class HDKeyTest { } @Test - fun testDerive_whenIncorrectPath_thenThrowError() { + fun testDerive_whenIncorrectPath_thenThrowException() { val seed = Random.Default.nextBytes(64) val depth = 1 val childIndex = BigInteger(0) @@ -37,30 +51,68 @@ class HDKeyTest { } } - @Ignore @Test fun testDerive_thenHDDeriveOk() { - val seed = Random.Default.nextBytes(64) + val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair() + val chainCode = Random.Default.nextBytes(32) + val seed = keyPair.privateKey.getEncoded() + chainCode val depth = 1 - val childIndex = BigInteger(0) + val childIndex: BigInteger = HARDENED_OFFSET.toBigInteger() val hdKey = HDKey(seed, depth, childIndex) - val path = "m/44'/0'/0'/0/0" + val path = "m/44'/0'/0'" val hdKeyResult = hdKey.derive(path) assertNotNull(hdKeyResult.privateKey) assertNotNull(hdKeyResult.chainCode) - assertEquals(depth, hdKeyResult.depth) + assertEquals(4, hdKeyResult.depth) assertEquals(childIndex, hdKeyResult.childIndex) } @Test - fun testDeriveChild_whenIncorrectPath_thenThrowError() { - val seed = Random.Default.nextBytes(64) + fun testDeriveChild_whenNoChainCode_thenThrowException() { + val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair() val depth = 1 - val childIndex = BigInteger(HDKey.HARDENED_OFFSET) + val childIndex = BigInteger(HARDENED_OFFSET) - val hdKey = HDKey(seed, depth, childIndex) + val hdKey = HDKey( + privateKey = keyPair.privateKey.getEncoded(), + depth = depth, + childIndex = childIndex + ) + + assertFailsWith(Exception::class) { + hdKey.deriveChild(childIndex) + } + } + + @Test + fun testDeriveChild_whenPrivateKeyNotHardened_thenThrowException() { + val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair() + val depth = 1 + val childIndex = BigInteger(1) + + val hdKey = HDKey( + privateKey = keyPair.privateKey.getEncoded(), + depth = depth, + childIndex = childIndex + ) + + assertFailsWith(Exception::class) { + hdKey.deriveChild(childIndex) + } + } + + @Test + fun testDeriveChild_whenPrivateKeyNotRightLength_thenThrowException() { + val depth = 1 + val childIndex = BigInteger(1) + + val hdKey = HDKey( + privateKey = Random.Default.nextBytes(33), + depth = depth, + childIndex = childIndex + ) assertFailsWith(Exception::class) { hdKey.deriveChild(childIndex)