diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 5e026901365..7629f5d0110 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -669,7 +669,9 @@ internal class MLSConversationDataSource( } if (!isNewClient) { kaliumLogger.w("enrollment for existing client: upload new keypackages and drop old ones") - keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft { + keyPackageRepository + .replaceKeyPackages(clientId, rotateBundle.newKeyPackages, CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())) + .flatMapLeft { return E2EIFailure.RotationAndMigration(it).left() } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt index 56ad144486c..1f79a2d2a9f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt @@ -61,7 +61,7 @@ interface KeyPackageRepository { suspend fun uploadKeyPackages(clientId: ClientId, keyPackages: List): Either - suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List): Either + suspend fun replaceKeyPackages(clientId: ClientId, keyPackages: List, cipherSuite: CipherSuite): Either suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either @@ -124,10 +124,11 @@ class KeyPackageDataSource( override suspend fun replaceKeyPackages( clientId: ClientId, - keyPackages: List + keyPackages: List, + cipherSuite: CipherSuite ): Either = wrapApiRequest { - keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() }) + keyPackageApi.replaceKeyPackages(clientId.value, keyPackages.map { it.encodeBase64() }, cipherSuite.tag) } override suspend fun validKeyPackageCount(clientId: ClientId): Either = diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index 28b087c472e..a81bf905943 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -1269,6 +1269,7 @@ class MLSConversationRepositoryTest { fun givenSuccessResponse_whenRotatingKeysAndMigratingConversation_thenReturnsSuccess() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() .withGetMLSClientSuccessful() + .withGetDefaultCipherSuiteSuccessful() .withRotateAllSuccessful() .withSendCommitBundleSuccessful() .withKeyPackageLimits(10) @@ -1305,6 +1306,7 @@ class MLSConversationRepositoryTest { fun givenNewDistributionsCRL_whenRotatingKeys_thenCheckRevocationList() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() .withGetMLSClientSuccessful() + .withGetDefaultCipherSuiteSuccessful() .withRotateAllSuccessful(ROTATE_BUNDLE.copy(crlNewDistributionPoints = listOf("url"))) .withSendCommitBundleSuccessful() .withKeyPackageLimits(10) @@ -1332,6 +1334,7 @@ class MLSConversationRepositoryTest { fun givenReplacingKeypackagesFailed_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() .withGetMLSClientSuccessful() + .withGetDefaultCipherSuiteSuccessful() .withRotateAllSuccessful() .withKeyPackageLimits(10) .withReplaceKeyPackagesReturning(TEST_FAILURE) @@ -1363,6 +1366,7 @@ class MLSConversationRepositoryTest { fun givenSendingCommitBundlesFails_whenRotatingKeysAndMigratingConversation_thenReturnsFailure() = runTest { val (arrangement, mlsConversationRepository) = Arrangement() .withGetMLSClientSuccessful() + .withGetDefaultCipherSuiteSuccessful() .withRotateAllSuccessful() .withKeyPackageLimits(10) .withReplaceKeyPackagesReturning(Either.Right(Unit)) @@ -1758,7 +1762,7 @@ class MLSConversationRepositoryTest { fun withReplaceKeyPackagesReturning(result: Either) = apply { given(keyPackageRepository) .suspendFunction(keyPackageRepository::replaceKeyPackages) - .whenInvokedWith(anything(), anything()) + .whenInvokedWith(anything(), anything(), anything()) .thenReturn(result) } @@ -1776,6 +1780,13 @@ class MLSConversationRepositoryTest { .then { Either.Right(mlsClient) } } + fun withGetDefaultCipherSuiteSuccessful() = apply { + given(mlsClient) + .function(mlsClient::getDefaultCipherSuite) + .whenInvoked() + .then { CIPHER_SUITE.tag.toUShort() } + } + fun withGetExternalSenderKeySuccessful() = apply { given(mlsClient) .suspendFunction(mlsClient::getExternalSenders) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt index b5383fef2ba..17185468264 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt @@ -73,7 +73,11 @@ interface KeyPackageApi { * @param keyPackages list of key packages * */ - suspend fun replaceKeyPackages(clientId: String, keyPackages: List): NetworkResponse + suspend fun replaceKeyPackages( + clientId: String, + keyPackages: List, + cipherSuite: Int + ): NetworkResponse /** * Get the number of available key packages for the self client diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt index 64b15238100..4697c5a1d1c 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v0/authenticated/KeyPackageApiV0.kt @@ -41,7 +41,8 @@ internal open class KeyPackageApiV0 internal constructor() : KeyPackageApi { override suspend fun replaceKeyPackages( clientId: String, - keyPackages: List + keyPackages: List, + cipherSuite: Int ): NetworkResponse = NetworkResponse.Error( APINotSupported("MLS: replaceKeyPackages api is only available on API V5") ) diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt index 844a0419abd..f0bf03393ed 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt @@ -30,6 +30,7 @@ import com.wire.kalium.network.utils.NetworkResponse import com.wire.kalium.network.utils.handleUnsuccessfulResponse import com.wire.kalium.network.utils.wrapFederationResponse import com.wire.kalium.network.utils.wrapKaliumResponse +import com.wire.kalium.util.int.toHexString import io.ktor.client.request.get import io.ktor.client.request.parameter import io.ktor.client.request.post @@ -67,12 +68,14 @@ internal open class KeyPackageApiV5 internal constructor( override suspend fun replaceKeyPackages( clientId: String, - keyPackages: List + keyPackages: List, + cipherSuite: Int ): NetworkResponse = wrapKaliumResponse { kaliumLogger.v("Keypackages Count to replace: ${keyPackages.size}") httpClient.put("$PATH_KEY_PACKAGES/$PATH_SELF/$clientId") { setBody(KeyPackageList(keyPackages)) + parameter(QUERY_CIPHER_SUITES, cipherSuite.toHexString()) } } @@ -86,5 +89,6 @@ internal open class KeyPackageApiV5 internal constructor( const val PATH_COUNT = "count" const val QUERY_SKIP_OWN = "skip_own" const val QUERY_CIPHER_SUITE = "ciphersuite" + const val QUERY_CIPHER_SUITES = "ciphersuites" } } diff --git a/util/src/commonMain/kotlin/com.wire.kalium.util/int/toByteArray.kt b/util/src/commonMain/kotlin/com.wire.kalium.util/int/IntExt.kt similarity index 87% rename from util/src/commonMain/kotlin/com.wire.kalium.util/int/toByteArray.kt rename to util/src/commonMain/kotlin/com.wire.kalium.util/int/IntExt.kt index e3c8798d27e..38da29371cd 100644 --- a/util/src/commonMain/kotlin/com.wire.kalium.util/int/toByteArray.kt +++ b/util/src/commonMain/kotlin/com.wire.kalium.util/int/IntExt.kt @@ -27,3 +27,8 @@ fun Int.toByteArray(): ByteArray { this.toByte() ) } + +@Suppress("MagicNumber") +fun Int.toHexString(minDigits: Int = 4): String { + return "0x" + this.toString(16).padStart(minDigits, '0') +} diff --git a/util/src/commonTest/kotlin/com/wire/kalium/util/string/NumberByteArrayTest.kt b/util/src/commonTest/kotlin/com/wire/kalium/util/IntExtTests.kt similarity index 87% rename from util/src/commonTest/kotlin/com/wire/kalium/util/string/NumberByteArrayTest.kt rename to util/src/commonTest/kotlin/com/wire/kalium/util/IntExtTests.kt index 2b1b5b9f220..1ad2e21cf76 100644 --- a/util/src/commonTest/kotlin/com/wire/kalium/util/string/NumberByteArrayTest.kt +++ b/util/src/commonTest/kotlin/com/wire/kalium/util/IntExtTests.kt @@ -16,14 +16,16 @@ * along with this program. If not, see http://www.gnu.org/licenses/. */ -package com.wire.kalium.util.string +package com.wire.kalium.util import com.wire.kalium.util.int.toByteArray +import com.wire.kalium.util.int.toHexString import com.wire.kalium.util.long.toByteArray +import com.wire.kalium.util.string.toHexString import kotlin.test.Test import kotlin.test.assertEquals -class NumberByteArrayTest { +class IntExtTests { @Test fun givenMaxLongValue_whenConvertingToByteArray_HexStringIsEqualToTheExpected() { @@ -67,4 +69,10 @@ class NumberByteArrayTest { assertEquals("00000002540BE400", result.toHexString().uppercase()) } + @Test + fun givenAnInteger_whenConvertingToHex_HexValueIsAsExpected(){ + val given = 2 + val expected= "0x000$given" + assertEquals(expected, given.toHexString()) + } }