Skip to content

Commit

Permalink
fix(e2ei): set ciphersuites when replacing KeyPackages (WPB-10238) (#…
Browse files Browse the repository at this point in the history
…2917)

* fix(e2ei): set CS when replacing KeyPackages

* update imports after renaming functions
  • Loading branch information
mchenani authored Jul 30, 2024
1 parent f712fc2 commit 57a5a7b
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,9 @@ internal class MLSConversationDataSource(
}
if (!isNewClient) {
kaliumLogger.w("enrollment for existing client: upload new keypackages and drop old ones")
keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
keyPackageRepository
.replaceKeyPackages(clientId, rotateBundle.newKeyPackages, CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()))
.flatMapLeft {
return E2EIFailure.RotationAndMigration(it).left()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ interface KeyPackageRepository {

suspend fun uploadKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>

suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>): Either<CoreFailure, Unit>
suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List<ByteArray>, cipherSuite: CipherSuite): Either<CoreFailure, Unit>

suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>

Expand Down Expand Up @@ -124,10 +124,11 @@ class KeyPackageDataSource(

override suspend fun replaceKeyPackages(
clientId: ClientId,
keyPackages: List<ByteArray>
keyPackages: List<ByteArray>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> =
wrapApiRequest {
keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() })
keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() }, cipherSuite.tag)
}

override suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ class MLSConversationRepositoryTest {
fun givenSuccessResponse_whenRotatingKeysAndMigratingConversation_thenReturnsSuccess() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withSendCommitBundleSuccessful()
.withKeyPackageLimits(10)
Expand Down Expand Up @@ -1305,6 +1306,7 @@ class MLSConversationRepositoryTest {
fun givenNewDistributionsCRL_whenRotatingKeys_thenCheckRevocationList() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful(ROTATE_BUNDLE.copy(crlNewDistributionPoints = listOf("url")))
.withSendCommitBundleSuccessful()
.withKeyPackageLimits(10)
Expand Down Expand Up @@ -1332,6 +1334,7 @@ class MLSConversationRepositoryTest {
fun givenReplacingKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withKeyPackageLimits(10)
.withReplaceKeyPackagesReturning(TEST_FAILURE)
Expand Down Expand Up @@ -1363,6 +1366,7 @@ class MLSConversationRepositoryTest {
fun givenSendingCommitBundlesFails_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest {
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetDefaultCipherSuiteSuccessful()
.withRotateAllSuccessful()
.withKeyPackageLimits(10)
.withReplaceKeyPackagesReturning(Either.Right(Unit))
Expand Down Expand Up @@ -1758,7 +1762,7 @@ class MLSConversationRepositoryTest {
fun withReplaceKeyPackagesReturning(result: Either<CoreFailure, Unit>) = apply {
given(keyPackageRepository)
.suspendFunction(keyPackageRepository::replaceKeyPackages)
.whenInvokedWith(anything(), anything())
.whenInvokedWith(anything(), anything(), anything())
.thenReturn(result)
}

Expand All @@ -1776,6 +1780,13 @@ class MLSConversationRepositoryTest {
.then { Either.Right(mlsClient) }
}

fun withGetDefaultCipherSuiteSuccessful() = apply {
given(mlsClient)
.function(mlsClient::getDefaultCipherSuite)
.whenInvoked()
.then { CIPHER_SUITE.tag.toUShort() }
}

fun withGetExternalSenderKeySuccessful() = apply {
given(mlsClient)
.suspendFunction(mlsClient::getExternalSenders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ interface KeyPackageApi {
* @param keyPackages list of key packages
*
*/
suspend fun replaceKeyPackages(clientId: String, keyPackages: List<KeyPackage>): NetworkResponse<Unit>
suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit>

/**
* Get the number of available key packages for the self client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ internal open class KeyPackageApiV0 internal constructor() : KeyPackageApi {

override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit> = NetworkResponse.Error(
APINotSupported("MLS: replaceKeyPackages api is only available on API V5")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.network.utils.handleUnsuccessfulResponse
import com.wire.kalium.network.utils.wrapFederationResponse
import com.wire.kalium.network.utils.wrapKaliumResponse
import com.wire.kalium.util.int.toHexString
import io.ktor.client.request.get
import io.ktor.client.request.parameter
import io.ktor.client.request.post
Expand Down Expand Up @@ -67,12 +68,14 @@ internal open class KeyPackageApiV5 internal constructor(

override suspend fun replaceKeyPackages(
clientId: String,
keyPackages: List<KeyPackage>
keyPackages: List<KeyPackage>,
cipherSuite: Int
): NetworkResponse<Unit> =
wrapKaliumResponse {
kaliumLogger.v("Keypackages Count to replace: ${keyPackages.size}")
httpClient.put("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId") {
setBody(KeyPackageList(keyPackages))
parameter(QUERY_CIPHER_SUITES, cipherSuite.toHexString())
}
}

Expand All @@ -86,5 +89,6 @@ internal open class KeyPackageApiV5 internal constructor(
const val PATH_COUNT = "count"
const val QUERY_SKIP_OWN = "skip_own"
const val QUERY_CIPHER_SUITE = "ciphersuite"
const val QUERY_CIPHER_SUITES = "ciphersuites"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ fun Int.toByteArray(): ByteArray {
this.toByte()
)
}

@Suppress("MagicNumber")
fun Int.toHexString(minDigits: Int = 4): String {
return "0x" + this.toString(16).padStart(minDigits, '0')
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
* along with this program. If not, see http://www.gnu.org/licenses/.
*/

package com.wire.kalium.util.string
package com.wire.kalium.util

import com.wire.kalium.util.int.toByteArray
import com.wire.kalium.util.int.toHexString
import com.wire.kalium.util.long.toByteArray
import com.wire.kalium.util.string.toHexString
import kotlin.test.Test
import kotlin.test.assertEquals

class NumberByteArrayTest {
class IntExtTests {

@Test
fun givenMaxLongValue_whenConvertingToByteArray_HexStringIsEqualToTheExpected() {
Expand Down Expand Up @@ -67,4 +69,10 @@ class NumberByteArrayTest {
assertEquals("00000002540BE400", result.toHexString().uppercase())
}

@Test
fun givenAnInteger_whenConvertingToHex_HexValueIsAsExpected(){
val given = 2
val expected= "0x000$given"
assertEquals(expected, given.toHexString())
}
}

0 comments on commit 57a5a7b

Please sign in to comment.