Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianIOHK committed Jul 5, 2023
1 parent e4b6ca7 commit 7b8ed01
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import io.iohk.atala.prism.apollo.utils.ECConfig
import io.iohk.atala.prism.apollo.utils.ECPrivateKeyDecodingException
import io.iohk.atala.prism.apollo.utils.KMMECSecp256k1PrivateKey

/**
* Represents and HDKey with its derive methods
*/
class HDKey(
val privateKey: ByteArray? = null,
val publicKey: ByteArray? = null,
Expand All @@ -18,11 +21,20 @@ class HDKey(

constructor(seed: ByteArray, depth: Int, childIndex: BigInteger) : this(
privateKey = seed.sliceArray(IntRange(0, 31)),
chainCode = seed.sliceArray(listOf(32)),
chainCode = seed.sliceArray(32 until seed.size),
depth = depth,
childIndex = childIndex
)
) {
require(seed.size == 64) {
"Seed expected byte length to be ${ECConfig.PRIVATE_KEY_BYTE_SIZE}"
}
}

/**
* Method to derive an HDKey by a path
*
* @param path value used to derive a key
*/
fun derive(path: String): HDKey {
if (!path.matches(Regex("^[mM].*"))) {
throw Error("Path must start with \"m\" or \"M\"")
Expand All @@ -37,7 +49,6 @@ class HDKey(
if (m == null || m.size != 3) {
throw Error("Invalid child index: $c")
}
// TODO: Null check??
val idx = m[1].toBigInteger()
if (idx >= HARDENED_OFFSET) {
throw Error("Invalid index")
Expand All @@ -48,9 +59,14 @@ class HDKey(
return child
}

/**
* Method to derive an HDKey child by index
*
* @param index value used to derive a key
*/
fun deriveChild(index: BigInteger): HDKey {
if (chainCode == null) {
throw Error("No chainCode set")
throw Exception("No chainCode set")
}
val data = if (index >= HARDENED_OFFSET) {
val priv = privateKey ?: throw Error("Could not derive hardened child key")
Expand Down Expand Up @@ -93,11 +109,11 @@ class HDKey(
fun getKMMSecp256k1PrivateKey(): KMMECSecp256k1PrivateKey {
privateKey?.let {
return KMMECSecp256k1PrivateKey.secp256k1FromBytes(privateKey)
} ?: throw Exception("")
} ?: throw Exception("Private key not available")
}

private fun isValidPrivateKey(data: ByteArray): Boolean {
return (data.size != ECConfig.PRIVATE_KEY_BYTE_SIZE)
return (data.size == ECConfig.PRIVATE_KEY_BYTE_SIZE)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
package io.iohk.atala.prism.apollo.derivation

import com.ionspin.kotlin.bignum.integer.BigInteger
import com.ionspin.kotlin.bignum.integer.toBigInteger
import io.iohk.atala.prism.apollo.derivation.HDKey.Companion.HARDENED_OFFSET
import io.iohk.atala.prism.apollo.utils.KMMECSecp256k1KeyPair
import kotlin.random.Random
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull

class HDKeyTest {

@Test
fun testConstructor_whenSeedIncorrectLength_thenThrowException() {
val seed = Random.Default.nextBytes(32)
val depth = 1
val childIndex = BigInteger(HARDENED_OFFSET)

assertFailsWith(IllegalArgumentException::class) {
HDKey(seed, depth, childIndex)
}
}

@Test
fun testConstructorWithSeed_thenNonNullValues() {
val seed = Random.Default.nextBytes(64)
Expand All @@ -24,7 +38,7 @@ class HDKeyTest {
}

@Test
fun testDerive_whenIncorrectPath_thenThrowError() {
fun testDerive_whenIncorrectPath_thenThrowException() {
val seed = Random.Default.nextBytes(64)
val depth = 1
val childIndex = BigInteger(0)
Expand All @@ -37,30 +51,68 @@ class HDKeyTest {
}
}

@Ignore
@Test
fun testDerive_thenHDDeriveOk() {
val seed = Random.Default.nextBytes(64)
val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair()
val chainCode = Random.Default.nextBytes(32)
val seed = keyPair.privateKey.getEncoded() + chainCode
val depth = 1
val childIndex = BigInteger(0)
val childIndex: BigInteger = HARDENED_OFFSET.toBigInteger()

val hdKey = HDKey(seed, depth, childIndex)
val path = "m/44'/0'/0'/0/0"
val path = "m/44'/0'/0'"

val hdKeyResult = hdKey.derive(path)
assertNotNull(hdKeyResult.privateKey)
assertNotNull(hdKeyResult.chainCode)
assertEquals(depth, hdKeyResult.depth)
assertEquals(4, hdKeyResult.depth)
assertEquals(childIndex, hdKeyResult.childIndex)
}

@Test
fun testDeriveChild_whenIncorrectPath_thenThrowError() {
val seed = Random.Default.nextBytes(64)
fun testDeriveChild_whenNoChainCode_thenThrowException() {
val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair()
val depth = 1
val childIndex = BigInteger(HDKey.HARDENED_OFFSET)
val childIndex = BigInteger(HARDENED_OFFSET)

val hdKey = HDKey(seed, depth, childIndex)
val hdKey = HDKey(
privateKey = keyPair.privateKey.getEncoded(),
depth = depth,
childIndex = childIndex
)

assertFailsWith(Exception::class) {
hdKey.deriveChild(childIndex)
}
}

@Test
fun testDeriveChild_whenPrivateKeyNotHardened_thenThrowException() {
val keyPair = KMMECSecp256k1KeyPair.generateSecp256k1KeyPair()
val depth = 1
val childIndex = BigInteger(1)

val hdKey = HDKey(
privateKey = keyPair.privateKey.getEncoded(),
depth = depth,
childIndex = childIndex
)

assertFailsWith(Exception::class) {
hdKey.deriveChild(childIndex)
}
}

@Test
fun testDeriveChild_whenPrivateKeyNotRightLength_thenThrowException() {
val depth = 1
val childIndex = BigInteger(1)

val hdKey = HDKey(
privateKey = Random.Default.nextBytes(33),
depth = depth,
childIndex = childIndex
)

assertFailsWith(Exception::class) {
hdKey.deriveChild(childIndex)
Expand Down

0 comments on commit 7b8ed01

Please sign in to comment.