Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amirhosv committed Aug 8, 2024
1 parent 4dbd6c0 commit 84fc611
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 52 deletions.
33 changes: 22 additions & 11 deletions src/com/amazon/corretto/crypto/provider/ConcatenationKdfSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,49 @@

import java.security.spec.KeySpec;
import java.util.Objects;
import java.util.Optional;

/**
* Represents the inputs to ConcatenationKdf algorithms.
*
* <p>When using HMAC variants, salt must be provided. The algorithmName is the name of algorithm
* used to create SecretKeySpec.
* <p>If info or salt is not provided, empty byte arrays are used.
*
* <p>The algorithmName is the name of algorithm used to create SecretKeySpec.
*/
public class ConcatenationKdfSpec implements KeySpec {
private static final byte[] EMPTY = new byte[0];
private final byte[] secret;
private final byte[] info;
private final Optional<byte[]> salt;
private final int outputLen;
private final String algorithmName;
private final byte[] info;
private final byte[] salt;

public ConcatenationKdfSpec(
final byte[] secret,
final byte[] info,
final byte[] salt,
final int outputLen,
final String algorithmName) {
final String algorithmName,
final byte[] info,
final byte[] salt) {
this.secret = Objects.requireNonNull(secret);
if (this.secret.length == 0) {
throw new IllegalArgumentException("Secret must be byte array with non-zero length.");
}
this.info = Objects.requireNonNull(info);
this.salt = Optional.ofNullable(salt);
if (outputLen <= 0) {
throw new IllegalArgumentException("Output size must be greater than zero.");
}
this.outputLen = outputLen;
this.algorithmName = Objects.requireNonNull(algorithmName);
this.info = Objects.requireNonNull(info);
this.salt = Objects.requireNonNull(salt);
}

public ConcatenationKdfSpec(
final byte[] secret, final int outputLen, final String algorithmName) {
this(secret, outputLen, algorithmName, EMPTY, EMPTY);
}

public ConcatenationKdfSpec(
final byte[] secret, final int outputLen, final String algorithmName, final byte[] info) {
this(secret, outputLen, algorithmName, info, EMPTY);
}

public byte[] getSecret() {
Expand All @@ -50,7 +61,7 @@ public int getOutputLen() {
return outputLen;
}

public Optional<byte[]> getSalt() {
public byte[] getSalt() {
return salt;
}

Expand Down
12 changes: 3 additions & 9 deletions src/com/amazon/corretto/crypto/provider/ConcatenationKdfSpi.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ConcatenationKdfSpi extends KdfSpi {
@Override
protected SecretKey engineGenerateSecret(final KeySpec keySpec) throws InvalidKeySpecException {
if (!(keySpec instanceof ConcatenationKdfSpec)) {
throw new InvalidKeySpecException("Expected a key spec of type GenericSpec.");
throw new InvalidKeySpecException("Expected a key spec of type ConcatenationKdfSpi.");
}
final ConcatenationKdfSpec spec = (ConcatenationKdfSpec) keySpec;

Expand All @@ -39,20 +39,14 @@ protected SecretKey engineGenerateSecret(final KeySpec keySpec) throws InvalidKe
output,
output.length);
} else {
final byte[] salt =
spec.getSalt()
.orElseThrow(
() ->
new InvalidKeySpecException(
"Salt cannot be null for HMAC variation of Concatenation KDF."));
nSskdfHmac(
digestCode,
spec.getSecret(),
spec.getSecret().length,
spec.getInfo(),
spec.getInfo().length,
salt,
salt.length,
spec.getSalt(),
spec.getSalt().length,
output,
output.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import static com.amazon.corretto.crypto.provider.test.TestUtil.getEntriesFromFile;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assumptions.assumeFalse;
Expand Down Expand Up @@ -57,18 +58,15 @@ public void concatenationKdfsAreNotAvailableInFipsMode() {
@Test
public void secretLengthCannotBeZero() {
assertThrows(
IllegalArgumentException.class,
() -> new ConcatenationKdfSpec(new byte[0], new byte[0], new byte[0], 1, "name"));
IllegalArgumentException.class, () -> new ConcatenationKdfSpec(new byte[0], 1, "name"));
}

@Test
public void outputLengthCannotBeZeroOrNegative() {
assertThrows(
IllegalArgumentException.class,
() -> new ConcatenationKdfSpec(new byte[1], new byte[0], new byte[0], 0, "name"));
IllegalArgumentException.class, () -> new ConcatenationKdfSpec(new byte[0], 0, "name"));
assertThrows(
IllegalArgumentException.class,
() -> new ConcatenationKdfSpec(new byte[1], new byte[0], new byte[0], -1, "name"));
IllegalArgumentException.class, () -> new ConcatenationKdfSpec(new byte[0], -1, "name"));
}

// The rest of the tests are only available in non-FIPS mode.
Expand All @@ -81,35 +79,31 @@ public void concatenationKdfExpectsConcatenationKdfSpecAsKeySpec() throws Except
InvalidKeySpecException.class, () -> skf.generateSecret(new PBEKeySpec(new char[4])));
}

@Test
public void concatenationKdfWithHmacExpectsSalt() throws Exception {
assumeFalse(TestUtil.isFips());
final SecretKeyFactory skf =
SecretKeyFactory.getInstance("ConcatenationKdfWithHmacSha256", TestUtil.NATIVE_PROVIDER);
assertThrows(
InvalidKeySpecException.class,
() ->
skf.generateSecret(
new ConcatenationKdfSpec(new byte[1], new byte[0], null, 10, "name")));
}

@Test
public void concatenationKdfWithEmptyInfoIsFine() throws Exception {
assumeFalse(TestUtil.isFips());
final SecretKeyFactory skf =
SecretKeyFactory.getInstance("ConcatenationKdfWithSha256", TestUtil.NATIVE_PROVIDER);
assertNotNull(
skf.generateSecret(new ConcatenationKdfSpec(new byte[1], new byte[0], null, 1, "name")));
final ConcatenationKdfSpec spec = new ConcatenationKdfSpec(new byte[1], 10, "name");
assertEquals(0, spec.getInfo().length);
assertNotNull(skf.generateSecret(spec));
}

@Test
public void concatenationKdfWithHmacEmptySaltIsFine() throws Exception {
public void concatenationKdfHmacWithEmptySaltIsFine() throws Exception {
assumeFalse(TestUtil.isFips());
final SecretKeyFactory skf =
SecretKeyFactory.getInstance("ConcatenationKdfWithHmacSha256", TestUtil.NATIVE_PROVIDER);
assertNotNull(
skf.generateSecret(
new ConcatenationKdfSpec(new byte[1], new byte[0], new byte[0], 1, "name")));
final ConcatenationKdfSpec spec1 = new ConcatenationKdfSpec(new byte[1], 10, "name");
assertEquals(0, spec1.getInfo().length);
assertEquals(0, spec1.getSalt().length);
assertNotNull(skf.generateSecret(spec1));

final ConcatenationKdfSpec spec2 =
new ConcatenationKdfSpec(new byte[1], 10, "name", new byte[10]);
assertEquals(10, spec2.getInfo().length);
assertEquals(0, spec2.getSalt().length);
assertNotNull(skf.generateSecret(spec2));
}

@ParameterizedTest(name = "{0}")
Expand All @@ -120,14 +114,17 @@ public void concatenationKdfKatTests(final RspTestEntry entry) throws Exception
assumeFalse("SHA1".equals(digest));
final boolean digestPrf = entry.getInstance("VARIANT").equals("DIGEST");
final byte[] expected = entry.getInstanceFromHex("EXPECT");

final ConcatenationKdfSpec spec =
new ConcatenationKdfSpec(
entry.getInstanceFromHex("SECRET"),
entry.getInstanceFromHex("INFO"),
entry.getInstanceFromHex("SALT"),
expected.length,
"SECRET_KEY");
final byte[] secret = entry.getInstanceFromHex("SECRET");
final byte[] info = entry.getInstanceFromHex("INFO");

final ConcatenationKdfSpec spec;
if (entry.contains("SALT")) {
spec =
new ConcatenationKdfSpec(
secret, expected.length, "SECRET_KEY", info, entry.getInstanceFromHex("SALT"));
} else {
spec = new ConcatenationKdfSpec(secret, expected.length, "SECRET_KEY", info);
}

final String alg = "ConcatenationKdfWith" + (digestPrf ? "" : "Hmac") + digest;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ public String getInstance(final String field) {
return getInstance().get(field);
}

/**
* Returns true if the specific entry has the provided field.
*
* @see {@link #getInstance()}
*/
public boolean contains(final String field) {
return getInstance().containsKey(field);
}

/**
* Returns a specific entry from this test case after interpreting it as hex-encoded binary.
*
Expand Down

0 comments on commit 84fc611

Please sign in to comment.