diff --git a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt index 3c06302039e..ce89b8e8931 100644 --- a/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt +++ b/cryptography/src/commonTest/kotlin/com/wire/kalium/cryptography/MLSClientTest.kt @@ -33,11 +33,15 @@ class MLSClientTest : BaseMLSClientTest() { } private suspend fun createClient(user: SampleUser): MLSClient { - return createMLSClient(user.qualifiedClientId, allowedCipherSuites = ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES) + return createMLSClient( + clientId = user.qualifiedClientId, + allowedCipherSuites = ALLOWED_CIPHER_SUITES, + defaultCipherSuite = DEFAULT_CIPHER_SUITES + ) } @Test - fun givemMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest { + fun givenMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest { val mlsClient = createClient(ALICE1) assertEquals(DEFAULT_CIPHER_SUITES, mlsClient.getDefaultCipherSuite()) } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt index e02fa479b99..0cb58e990fa 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepository.kt @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.toApi import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.message.MessageContent.MemberChange.FailedToAdd +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.service.ServiceId import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository @@ -130,20 +131,13 @@ internal class ConversationGroupRepositoryImpl( } when (apiResult) { - is Either.Left -> { - val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None - if (canRetryOnce) { - extractValidUsersForRetryableError(apiResult.value, usersList) - .flatMap { (validUsers, failedUsers, failType) -> - // edge case, in case backend goes 🍌 and returns non-matching domains - if (failedUsers.isEmpty()) Either.Left(apiResult.value) - - createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType)) - } - } else { - Either.Left(apiResult.value) - } - } + is Either.Left -> handleCreateConverstionFailure( + apiResult = apiResult, + usersList = usersList, + name = name, + options = options, + lastUsersAttempt = lastUsersAttempt + ) is Either.Right -> handleGroupConversationCreated(apiResult.value, selfTeamId, usersList, lastUsersAttempt) } @@ -210,6 +204,27 @@ internal class ConversationGroupRepositoryImpl( } } + private suspend fun handleCreateConverstionFailure( + apiResult: Either.Left, + usersList: List, + name: String?, + options: ConversationOptions, + lastUsersAttempt: LastUsersAttempt + ): Either { + val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None + return if (canRetryOnce) { + extractValidUsersForRetryableError(apiResult.value, usersList) + .flatMap { (validUsers, failedUsers, failType) -> + // edge case, in case backend goes 🍌 and returns non-matching domains + if (failedUsers.isEmpty()) Either.Left(apiResult.value) + + createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType)) + } + } else { + Either.Left(apiResult.value) + } + } + override suspend fun addMembers( userIdList: List, conversationId: ConversationId @@ -224,11 +239,21 @@ internal class ConversationGroupRepositoryImpl( tryAddMembersToCloudAndStorage(userIdList, conversationId, LastUsersAttempt.None) .flatMap { // best effort approach for migrated conversations, no retries - mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), userIdList) + mlsConversationRepository.addMemberToMLSGroup( + GroupID(protocol.groupId), + userIdList, + CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag) + ) } is ConversationEntity.ProtocolInfo.MLS -> { - tryAddMembersToMLSGroup(conversationId, protocol.groupId, userIdList, LastUsersAttempt.None) + tryAddMembersToMLSGroup( + conversationId, + protocol.groupId, + userIdList, + LastUsersAttempt.None, + cipherSuite = CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag) + ) } } } @@ -237,14 +262,22 @@ internal class ConversationGroupRepositoryImpl( * Handle the error cases and retry for claimPackages offline and out of packages. * Handle error case and retry for sendingCommit unreachable or missing legal hold consent. */ + @Suppress("LongMethod") private suspend fun tryAddMembersToMLSGroup( conversationId: ConversationId, groupId: String, userIdList: List, lastUsersAttempt: LastUsersAttempt, + cipherSuite: CipherSuite, remainingAttempts: Int = 2 ): Either { - return when (val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup(GroupID(groupId), userIdList)) { + return when ( + val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup( + GroupID(groupId), + userIdList, + cipherSuite + ) + ) { is Either.Right -> handleMLSMembersNotAdded(conversationId, lastUsersAttempt) is Either.Left -> { addingMemberResult.value.handleMLSMembersFailed( @@ -253,17 +286,20 @@ internal class ConversationGroupRepositoryImpl( userIdList = userIdList, lastUsersAttempt = lastUsersAttempt, remainingAttempts = remainingAttempts, + cipherSuite = cipherSuite ) } } } + @Suppress("LongMethod") private suspend fun CoreFailure.handleMLSMembersFailed( conversationId: ConversationId, groupId: String, userIdList: List, lastUsersAttempt: LastUsersAttempt, remainingAttempts: Int, + cipherSuite: CipherSuite ): Either { return when { // claiming key packages offline or out of packages @@ -277,7 +313,8 @@ internal class ConversationGroupRepositoryImpl( failedUsers = lastUsersAttempt.failedUsers + failedUsers, failType = FailedToAdd.Type.Federation, ), - remainingAttempts = remainingAttempts - 1 + remainingAttempts = remainingAttempts - 1, + cipherSuite = cipherSuite ) } @@ -292,7 +329,8 @@ internal class ConversationGroupRepositoryImpl( failedUsers = lastUsersAttempt.failedUsers + failedUsers, failType = FailedToAdd.Type.Federation, ), - remainingAttempts = remainingAttempts - 1 + remainingAttempts = remainingAttempts - 1, + cipherSuite = cipherSuite ) } @@ -308,7 +346,8 @@ internal class ConversationGroupRepositoryImpl( failedUsers = lastUsersAttempt.failedUsers + failedUsers, failType = FailedToAdd.Type.LegalHold, ), - remainingAttempts = remainingAttempts - 1 + remainingAttempts = remainingAttempts - 1, + cipherSuite = cipherSuite ) } } @@ -479,7 +518,11 @@ internal class ConversationGroupRepositoryImpl( is ConversationEntity.ProtocolInfo.MLSCapable -> { joinExistingMLSConversation(conversationId).flatMap { - mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), listOf(selfUserId)) + mlsConversationRepository.addMemberToMLSGroup( + GroupID(protocol.groupId), + listOf(selfUserId), + CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag) + ) } } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index 7408105994b..6dab9c2cb9f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -141,7 +141,12 @@ interface MLSConversationRepository { suspend fun establishMLSSubConversationGroup(groupID: GroupID, parentId: ConversationId): Either suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either - suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List): Either + suspend fun addMemberToMLSGroup( + groupID: GroupID, + userIdList: List, + cipherSuite: CipherSuite + ): Either + suspend fun removeMembersFromMLSGroup(groupID: GroupID, userIdList: List): Either suspend fun removeClientsFromMLSGroup(groupID: GroupID, clientIdList: List): Either suspend fun leaveGroup(groupID: GroupID): Either @@ -202,7 +207,7 @@ private fun CoreFailure.getStrategy( // TODO: refactor this repository as it's doing too much. // A Repository should be a dummy class that get and set some values -@Suppress("TooManyFunctions", "LongParameterList") +@Suppress("TooManyFunctions", "LongParameterList", "LargeClass") internal class MLSConversationDataSource( private val selfUserId: UserId, private val keyPackageRepository: KeyPackageRepository, @@ -448,23 +453,29 @@ internal class MLSConversationDataSource( conversationDAO.getProposalTimers().map { it.map(conversationMapper::fromDaoModel) }.flatten() ) - override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List): Either = + override suspend fun addMemberToMLSGroup( + groupID: GroupID, + userIdList: List, + cipherSuite: CipherSuite + ): Either = internalAddMemberToMLSGroup( groupID = groupID, userIdList = userIdList, retryOnStaleMessage = true, - allowPartialMemberList = false + allowPartialMemberList = false, + cipherSuite = cipherSuite ).map { Unit } private suspend fun internalAddMemberToMLSGroup( groupID: GroupID, userIdList: List, retryOnStaleMessage: Boolean, + cipherSuite: CipherSuite, allowPartialMemberList: Boolean = false, ): Either = withContext(serialDispatcher) { commitPendingProposals(groupID).flatMap { produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) { - keyPackageRepository.claimKeyPackages(userIdList).flatMap { result -> + keyPackageRepository.claimKeyPackages(userIdList, cipherSuite).flatMap { result -> if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) { Either.Left(CoreFailure.MissingKeyPackages(result.usersWithoutKeyPackagesAvailable)) } else { @@ -606,7 +617,8 @@ internal class MLSConversationDataSource( groupID = groupID, userIdList = members, retryOnStaleMessage = false, - allowPartialMemberList = allowPartialMemberList + allowPartialMemberList = allowPartialMemberList, + cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()) ).onFailure { wrapMLSRequest { mlsClient.wipeConversation(groupID.toCrypto()) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt index a8e13f0b4a8..eefe69da366 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepository.kt @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.conversation.mls.KeyPackageClaimResult import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.id.toApi +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -50,7 +51,10 @@ interface KeyPackageRepository { * available. If the operation fails, it will be [Either.Left] with a [CoreFailure] object indicating the reason for the failure. * If **no** KeyPackages are available, [CoreFailure.MissingKeyPackages] will be the cause. */ - suspend fun claimKeyPackages(userIds: List): Either + suspend fun claimKeyPackages( + userIds: List, + cipherSuite: CipherSuite + ): Either suspend fun uploadNewKeyPackages(clientId: ClientId, amount: Int = 100): Either @@ -61,7 +65,6 @@ interface KeyPackageRepository { suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either suspend fun validKeyPackageCount(clientId: ClientId): Either - } class KeyPackageDataSource( @@ -71,13 +74,22 @@ class KeyPackageDataSource( private val selfUserId: UserId, ) : KeyPackageRepository { - override suspend fun claimKeyPackages(userIds: List): Either = + override suspend fun claimKeyPackages( + userIds: List, + cipherSuite: CipherSuite + ): Either = currentClientIdProvider().flatMap { selfClientId -> val failedUsers = mutableSetOf() val claimedKeyPackages = mutableListOf() userIds.forEach { userId -> wrapApiRequest { - keyPackageApi.claimKeyPackages(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), selfClientId.value)) + keyPackageApi.claimKeyPackages( + KeyPackageApi.Param.SkipOwnClient( + userId.toApi(), + selfClientId.value, + cipherSuite = cipherSuite.tag + ) + ) }.fold({ failedUsers.add(userId) }) { if (it.keyPackages.isEmpty() && userId != selfUserId) { failedUsers.add(userId) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt index d97abea02b8..0577055b6b3 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrator.kt @@ -138,7 +138,11 @@ internal class MLSMigratorImpl( mlsConversationRepository.establishMLSGroup(protocolInfo.groupId, emptyList()) .flatMap { conversationRepository.getConversationMembers(conversationId).flatMap { members -> - mlsConversationRepository.addMemberToMLSGroup(protocolInfo.groupId, members) + mlsConversationRepository.addMemberToMLSGroup( + protocolInfo.groupId, + members, + protocolInfo.cipherSuite + ) } } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/client/E2EIClientProviderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/client/E2EIClientProviderTest.kt index 19ecb449588..a117e0ee68b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/client/E2EIClientProviderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/client/E2EIClientProviderTest.kt @@ -172,7 +172,7 @@ class E2EIClientProviderTest { return this to e2eiClientProvider } - suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) { + override suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) { coEvery { mlsClientProvider.getOrFetchMLSConfig() }.returns(result.right()) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt index 23ea9e22075..acb7c4cf74f 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationGroupRepositoryTest.kt @@ -32,6 +32,7 @@ import com.wire.kalium.logic.data.id.toDao import com.wire.kalium.logic.data.id.toModel import com.wire.kalium.logic.data.legalhold.ListUsersLegalHoldConsent import com.wire.kalium.logic.data.message.MessageContent +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.service.ServiceId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.framework.TestConversation @@ -73,9 +74,9 @@ import com.wire.kalium.persistence.dao.conversation.ConversationViewEntity import io.ktor.http.HttpStatusCode import io.mockative.Mock import io.mockative.any -import io.mockative.eq import io.mockative.coEvery import io.mockative.coVerify +import io.mockative.eq import io.mockative.matches import io.mockative.mock import io.mockative.once @@ -506,7 +507,11 @@ class ConversationGroupRepositoryTest { }.wasNotInvoked() coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(eq(GROUP_ID), eq(listOf(TestConversation.USER_1))) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + eq(GROUP_ID), + eq(listOf(TestConversation.USER_1)), + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) } @@ -532,7 +537,11 @@ class ConversationGroupRepositoryTest { }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(eq(GROUP_ID), eq(listOf(TestConversation.USER_1))) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + eq(GROUP_ID), + eq(listOf(TestConversation.USER_1)), + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) } @@ -758,7 +767,11 @@ class ConversationGroupRepositoryTest { }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(eq(GroupID(MLS_PROTOCOL_INFO.groupId)), eq(listOf(TestUser.SELF.id))) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + eq(GroupID(MLS_PROTOCOL_INFO.groupId)), + eq(listOf(TestUser.SELF.id)), + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) } @@ -801,7 +814,8 @@ class ConversationGroupRepositoryTest { coVerify { arrangement.mlsConversationRepository.addMemberToMLSGroup( eq(GroupID(MIXED_PROTOCOL_INFO.groupId)), - eq(listOf(TestUser.SELF.id)) + eq(listOf(TestUser.SELF.id)), + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) ) }.wasInvoked(exactly = once) } @@ -1065,8 +1079,8 @@ class ConversationGroupRepositoryTest { coVerify { arrangement.conversationApi.addMember( matches { - it.users.size == expectedValidUsersCount && it.users.first().domain != failedDomain - }, any() + it.users.size == expectedValidUsersCount && it.users.first().domain != failedDomain + }, any() ) }.wasInvoked(exactly = once) @@ -1253,15 +1267,21 @@ class ConversationGroupRepositoryTest { val expectedFullUserIdsForRequestCount = 2 val expectedValidUsersWithKeyPackagesCount = 1 coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == expectedFullUserIdsForRequestCount - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == expectedFullUserIdsForRequestCount + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == expectedValidUsersWithKeyPackagesCount && it.first() == TestConversation.USER_1 - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == expectedValidUsersWithKeyPackagesCount && it.first() == TestConversation.USER_1 + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { @@ -1296,15 +1316,21 @@ class ConversationGroupRepositoryTest { val expectedFullUserIdsForRequestCount = 2 val expectedValidUsersWithKeyPackagesCount = 1 coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == expectedFullUserIdsForRequestCount - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == expectedFullUserIdsForRequestCount + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == expectedValidUsersWithKeyPackagesCount && it.first() == TestConversation.USER_1 - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == expectedValidUsersWithKeyPackagesCount && it.first() == TestConversation.USER_1 + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { @@ -1337,21 +1363,30 @@ class ConversationGroupRepositoryTest { // then val initialCountUsers = expectedInitialUsers.size coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == initialCountUsers - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == initialCountUsers + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == initialCountUsers - 1 // removed 1 failed users with key packages - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == initialCountUsers - 1 // removed 1 failed users with key packages + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(any(), matches { - it.size == initialCountUsers - 2 // removed 1 failed user with commit bundle federated error - }) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + any(), matches { + it.size == initialCountUsers - 2 // removed 1 failed user with commit bundle federated error + }, + eq(CipherSuite.fromTag(CIPHER_SUITE.cipherSuiteTag)) + ) }.wasInvoked(exactly = once) coVerify { @@ -1540,7 +1575,7 @@ class ConversationGroupRepositoryTest { * Mocks a sequence of [NetworkResponse]s for [ConversationApi.createNewConversation]. */ suspend fun withCreateNewConversationAPIResponses(result: Array>): Arrangement = apply { - coEvery{conversationApi.createNewConversation(any())} + coEvery { conversationApi.createNewConversation(any()) } .thenReturnSequentially(*result) } @@ -1605,80 +1640,80 @@ class ConversationGroupRepositoryTest { coEvery { conversationApi.addMember(any(), any()) }.returns( - NetworkResponse.Success( - TestConversation.ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + TestConversation.ADD_MEMBER_TO_CONVERSATION_SUCCESSFUL_RESPONSE, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withAddServiceAPISucceedChanged() = apply { coEvery { conversationApi.addService(any(), any()) }.returns( - NetworkResponse.Success( - TestConversation.ADD_SERVICE_TO_CONVERSATION_SUCCESSFUL_RESPONSE, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + TestConversation.ADD_SERVICE_TO_CONVERSATION_SUCCESSFUL_RESPONSE, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withAddMemberAPISucceedUnchanged() = apply { coEvery { conversationApi.addMember(any(), any()) }.returns( - NetworkResponse.Success( - ConversationMemberAddedResponse.Unchanged, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + ConversationMemberAddedResponse.Unchanged, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withAddMemberAPIFailed() = apply { coEvery { conversationApi.addMember(any(), any()) }.returns( - NetworkResponse.Error( - KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) - ) + NetworkResponse.Error( + KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) ) + ) } suspend fun withDeleteMemberAPISucceedChanged() = apply { coEvery { conversationApi.removeMember(any(), any()) }.returns( - NetworkResponse.Success( - TestConversation.REMOVE_MEMBER_FROM_CONVERSATION_SUCCESSFUL_RESPONSE, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + TestConversation.REMOVE_MEMBER_FROM_CONVERSATION_SUCCESSFUL_RESPONSE, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withDeleteMemberAPISucceedUnchanged() = apply { coEvery { conversationApi.removeMember(any(), any()) }.returns( - NetworkResponse.Success( - ConversationMemberRemovedResponse.Unchanged, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + ConversationMemberRemovedResponse.Unchanged, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withDeleteMemberAPIFailed() = apply { coEvery { conversationApi.removeMember(any(), any()) }.returns( - NetworkResponse.Error( - KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) - ) + NetworkResponse.Error( + KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) ) + ) } suspend fun withFetchUsersIfUnknownByIdsSuccessful() = apply { @@ -1707,7 +1742,7 @@ class ConversationGroupRepositoryTest { suspend fun withSuccessfulAddMemberToMLSGroup() = apply { coEvery { - mlsConversationRepository.addMemberToMLSGroup(any(), any()) + mlsConversationRepository.addMemberToMLSGroup(any(), any(), any()) }.returns(Either.Right(Unit)) } @@ -1716,7 +1751,7 @@ class ConversationGroupRepositoryTest { */ suspend fun withAddingMemberToMlsGroupResults(vararg results: Either) = apply { coEvery { - mlsConversationRepository.addMemberToMLSGroup(any(), any()) + mlsConversationRepository.addMemberToMLSGroup(any(), any(), any()) }.thenReturnSequentially(*results) } @@ -1732,22 +1767,22 @@ class ConversationGroupRepositoryTest { coEvery { conversationApi.generateGuestRoomLink(any(), any()) }.returns( - NetworkResponse.Success( - result, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + result, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withFailedCallToGenerateGuestRoomLinkApi() = apply { coEvery { conversationApi.generateGuestRoomLink(any(), any()) }.returns( - NetworkResponse.Error( - KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) - ) + NetworkResponse.Error( + KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) ) + ) } suspend fun withDeleteGuestLink() = apply { @@ -1760,22 +1795,22 @@ class ConversationGroupRepositoryTest { coEvery { conversationApi.revokeGuestRoomLink(any()) }.returns( - NetworkResponse.Success( - Unit, - mapOf(), - HttpStatusCode.OK.value - ) + NetworkResponse.Success( + Unit, + mapOf(), + HttpStatusCode.OK.value ) + ) } suspend fun withFailedCallToRevokeGuestRoomLinkApi() = apply { coEvery { conversationApi.revokeGuestRoomLink(any()) }.returns( - NetworkResponse.Error( - KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) - ) + NetworkResponse.Error( + KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) ) + ) } suspend fun withSuccessfulFetchOfGuestRoomLink( @@ -1805,22 +1840,22 @@ class ConversationGroupRepositoryTest { coEvery { conversationApi.updateMessageTimer(any(), any()) }.returns( - NetworkResponse.Success( - event, - emptyMap(), - HttpStatusCode.NoContent.value - ) + NetworkResponse.Success( + event, + emptyMap(), + HttpStatusCode.NoContent.value ) + ) } suspend fun withUpdateMessageTimerAPIFailed() = apply { coEvery { conversationApi.updateMessageTimer(any(), any()) }.returns( - NetworkResponse.Error( - KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) - ) + NetworkResponse.Error( + KaliumException.ServerError(ErrorResponse(500, "error_message", "error_label")) ) + ) } suspend fun withSuccessfulHandleMessageTimerUpdateEvent() = apply { @@ -1864,7 +1899,8 @@ class ConversationGroupRepositoryTest { fun arrange() = this to conversationGroupRepository } - companion object { + private companion object { + val CIPHER_SUITE = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 private const val RAW_GROUP_ID = "mlsGroupId" val GROUP_ID = GroupID(RAW_GROUP_ID) val PROTEUS_PROTOCOL_INFO = ConversationEntity.ProtocolInfo.Proteus @@ -1874,7 +1910,7 @@ class ConversationGroupRepositoryTest { groupState = ConversationEntity.GroupState.ESTABLISHED, 0UL, Instant.parse("2021-03-30T15:36:00.000Z"), - cipherSuite = ConversationEntity.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + cipherSuite = CIPHER_SUITE ) val MIXED_PROTOCOL_INFO = ConversationEntity.ProtocolInfo .Mixed( diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt index afdce3f788f..e7ba9d93a8d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepositoryTest.kt @@ -34,6 +34,7 @@ import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.E2EIFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider +import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CIPHER_SUITE import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.COMMIT_BUNDLE import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY @@ -303,7 +304,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) coVerify { arrangement.checkRevocationList.invoke(any()) @@ -387,7 +388,12 @@ class MLSConversationRepositoryTest { result.shouldSucceed() coVerify { - arrangement.keyPackageRepository.claimKeyPackages(matches { it.containsAll(listOf(TestConversation.USER_1)) }) + arrangement.keyPackageRepository.claimKeyPackages( + matches { + it.containsAll(listOf(TestConversation.USER_1)) + }, + eq(CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519) + ) }.wasInvoked(once) } @@ -421,7 +427,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldSucceed() coVerify { @@ -447,7 +453,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful(events = listOf(Arrangement.MEMBER_JOIN_EVENT)) .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldSucceed() coVerify { @@ -468,7 +474,7 @@ class MLSConversationRepositoryTest { .withSendCommitBundleSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldSucceed() coVerify { @@ -487,7 +493,7 @@ class MLSConversationRepositoryTest { .withWaitUntilLiveSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldSucceed() coVerify { @@ -515,7 +521,7 @@ class MLSConversationRepositoryTest { .withWaitUntilLiveSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldSucceed() coVerify { @@ -542,7 +548,7 @@ class MLSConversationRepositoryTest { .withClearProposalTimerSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldFail() coVerify { @@ -562,7 +568,7 @@ class MLSConversationRepositoryTest { .withWaitUntilLiveSuccessful() .arrange() - val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1)) + val result = mlsConversationRepository.addMemberToMLSGroup(Arrangement.GROUP_ID, listOf(TestConversation.USER_ID1), CIPHER_SUITE) result.shouldFail() coVerify { @@ -1414,6 +1420,8 @@ class MLSConversationRepositoryTest { @Test fun givenSuccessfulResponses_whenCallingEstablishMLSSubConversationGroup_thenGroupIsCreatedAndCommitBundleIsSentAndAccepted() = runTest { + val defaultCipherSuite = CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + val (arrangement, mlsConversationRepository) = Arrangement(testKaliumDispatcher) .withCommitPendingProposalsReturningNothing() .withClaimKeyPackagesSuccessful() @@ -1423,6 +1431,7 @@ class MLSConversationRepositoryTest { .withKeyForCipherSuite() .withUpdateKeyingMaterialSuccessful() .withSendCommitBundleSuccessful() + .withGetDefaultCipherSuite(defaultCipherSuite) .arrange() val result = mlsConversationRepository.establishMLSSubConversationGroup(Arrangement.GROUP_ID, TestConversation.ID) @@ -1586,7 +1595,7 @@ class MLSConversationRepositoryTest { usersWithoutKeyPackages: Set = setOf() ) = apply { coEvery { - keyPackageRepository.claimKeyPackages(any()) + keyPackageRepository.claimKeyPackages(any(), any()) }.returns(Either.Right(KeyPackageClaimResult(keyPackages, usersWithoutKeyPackages))) } @@ -1775,6 +1784,7 @@ class MLSConversationRepositoryTest { } companion object { + val CIPHER_SUITE = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 val TEST_FAILURE = Either.Left(CoreFailure.Unknown(Throwable("an error"))) const val EPOCH = 5UL const val RAW_GROUP_ID = "groupId" diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/NewConversationMembersRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/NewConversationMembersRepositoryTest.kt index bc26642d9e1..8e568085c21 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/NewConversationMembersRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/NewConversationMembersRepositoryTest.kt @@ -99,13 +99,14 @@ class NewConversationMembersRepositoryTest { } private companion object { + const val GROUP_NAME = "Group Name" val CONVERSATION_RESPONSE = ConversationResponse( "creator", ConversationMembersResponse( ConversationMemberDTO.Self(TestUser.SELF.id.toApi(), "wire_member"), listOf(ConversationMemberDTO.Other(TestUser.OTHER.id.toApi(), "wire_member")) ), - ConversationGroupRepositoryTest.GROUP_NAME, + GROUP_NAME, TestConversation.NETWORK_ID, null, 0UL, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt index 1d0187c7080..6aae5ccd319 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/keypackage/KeyPackageRepositoryTest.kt @@ -25,6 +25,8 @@ import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.data.id.PlainId import com.wire.kalium.logic.data.id.toApi +import com.wire.kalium.logic.data.keypackage.KeyPackageRepositoryTest.Arrangement.Companion.CIPHER_SUITE +import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.util.shouldFail @@ -39,9 +41,9 @@ import com.wire.kalium.network.utils.NetworkResponse import io.ktor.util.encodeBase64 import io.mockative.Mock import io.mockative.any -import io.mockative.eq import io.mockative.coEvery import io.mockative.coVerify +import io.mockative.eq import io.mockative.mock import io.mockative.once import kotlinx.coroutines.test.runTest @@ -87,7 +89,7 @@ class KeyPackageRepositoryTest { .withClaimKeyPackagesSuccessful(Arrangement.USER_ID) .arrange() - val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.USER_ID)) + val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.USER_ID), CIPHER_SUITE) result.shouldSucceed { keyPackageResult -> assertEquals(listOf(Arrangement.CLAIMED_KEY_PACKAGES.keyPackages[0]), keyPackageResult.successfullyFetchedKeyPackages) @@ -104,7 +106,7 @@ class KeyPackageRepositoryTest { .withClaimKeyPackagesSuccessfulWithEmptyResponse(userWithout) .arrange() - val result = keyPackageRepository.claimKeyPackages(listOf(userWith, userWithout)) + val result = keyPackageRepository.claimKeyPackages(listOf(userWith, userWithout), CIPHER_SUITE) result.shouldSucceed { keyPackageResult -> assertEquals( @@ -132,7 +134,7 @@ class KeyPackageRepositoryTest { } .arrange() - val result = keyPackageRepository.claimKeyPackages(usersWithout.toList()) + val result = keyPackageRepository.claimKeyPackages(usersWithout.toList(), CIPHER_SUITE) result.shouldFail { failure -> assertIs(failure) @@ -147,7 +149,7 @@ class KeyPackageRepositoryTest { .withClaimKeyPackagesSuccessfulWithEmptyResponse(Arrangement.USER_ID) .arrange() - val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.USER_ID)) + val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.USER_ID), CIPHER_SUITE) result.shouldFail { failure -> assertEquals(CoreFailure.MissingKeyPackages(setOf(Arrangement.USER_ID)), failure) @@ -162,7 +164,7 @@ class KeyPackageRepositoryTest { .withClaimKeyPackagesSuccessfulWithEmptyResponse(Arrangement.SELF_USER_ID) .arrange() - val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.SELF_USER_ID)) + val result = keyPackageRepository.claimKeyPackages(listOf(Arrangement.SELF_USER_ID), CIPHER_SUITE) result.shouldSucceed { keyPackages -> assertEquals(emptyList(), keyPackages.successfullyFetchedKeyPackages) @@ -212,13 +214,23 @@ class KeyPackageRepositoryTest { suspend fun withClaimKeyPackagesSuccessful(userId: UserId) = apply { coEvery { - keyPackageApi.claimKeyPackages(eq(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), SELF_CLIENT_ID.value))) + keyPackageApi.claimKeyPackages( + eq(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), SELF_CLIENT_ID.value, CIPHER_SUITE.tag)) + ) }.returns(NetworkResponse.Success(CLAIMED_KEY_PACKAGES, mapOf(), 200)) } suspend fun withClaimKeyPackagesSuccessfulWithEmptyResponse(userId: UserId) = apply { coEvery { - keyPackageApi.claimKeyPackages(eq(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), SELF_CLIENT_ID.value))) + keyPackageApi.claimKeyPackages( + eq( + KeyPackageApi.Param.SkipOwnClient( + userId.toApi(), + SELF_CLIENT_ID.value, + CIPHER_SUITE.tag + ) + ) + ) }.returns(NetworkResponse.Success(EMPTY_CLAIMED_KEY_PACKAGES, mapOf(), 200)) } @@ -226,6 +238,7 @@ class KeyPackageRepositoryTest { internal companion object { const val KEY_PACKAGE_COUNT = 100 + val CIPHER_SUITE = CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 val KEY_PACKAGE_COUNT_DTO = KeyPackageCountDTO(KEY_PACKAGE_COUNT) val SELF_CLIENT_ID: ClientId = PlainId("client_self") val OTHER_CLIENT_ID: ClientId = PlainId("client_other") diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt index d851e9e5a0a..8dcd95904ec 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigratorTest.kt @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.message.SystemMessageInserter import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository +import com.wire.kalium.logic.feature.mlsmigration.MLSMigratorTest.Arrangement.Companion.CIPHER_SUITE import com.wire.kalium.logic.framework.TestConversation import com.wire.kalium.logic.framework.TestTeam import com.wire.kalium.logic.framework.TestUser @@ -82,7 +83,11 @@ class MLSMigratorTest { } coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), eq(Arrangement.MEMBERS)) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), + eq(Arrangement.MEMBERS), + eq(CIPHER_SUITE) + ) } } @@ -115,7 +120,11 @@ class MLSMigratorTest { } coVerify { - arrangement.mlsConversationRepository.addMemberToMLSGroup(eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), eq(Arrangement.MEMBERS)) + arrangement.mlsConversationRepository.addMemberToMLSGroup( + eq(Arrangement.MIXED_PROTOCOL_INFO.groupId), + eq(Arrangement.MEMBERS), + eq(CIPHER_SUITE) + ) } coVerify { @@ -246,6 +255,7 @@ class MLSMigratorTest { conversationRepository.fetchConversation(any()) }.returns(Either.Right(Unit)) } + suspend fun withUpdateProtocolReturns(result: Either = Either.Right(true)) = apply { coEvery { conversationRepository.updateProtocolRemotely(any(), any()) @@ -266,7 +276,7 @@ class MLSMigratorTest { suspend fun withAddMembersSucceeds() = apply { coEvery { - mlsConversationRepository.addMemberToMLSGroup(any(), any()) + mlsConversationRepository.addMemberToMLSGroup(any(), any(), any()) }.returns(Either.Right(Unit)) } @@ -297,6 +307,7 @@ class MLSMigratorTest { } companion object { + val CIPHER_SUITE = CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 val MLS_STALE_MESSAGE_ERROR = KaliumException.InvalidRequestError( ErrorResponse(409, "", "mls-stale-message") ) @@ -307,14 +318,14 @@ class MLSMigratorTest { Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, 0UL, Instant.parse("2021-03-30T15:36:00.000Z"), - cipherSuite = CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + cipherSuite = CIPHER_SUITE ) val MLS_PROTOCOL_INFO = Conversation.ProtocolInfo.MLS( TestConversation.GROUP_ID, Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN, 0UL, Instant.parse("2021-03-30T15:36:00.000Z"), - cipherSuite = CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 + cipherSuite = CIPHER_SUITE ) } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt index 9d61f3a4d0e..94eba24bd57 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/provider/E2EIClientProviderArrangement.kt @@ -22,9 +22,11 @@ import com.wire.kalium.cryptography.E2EIClient import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.data.mls.SupportedCipherSuite import com.wire.kalium.logic.data.user.SelfUser import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.right import io.mockative.Mock import io.mockative.any import io.mockative.coEvery @@ -62,6 +64,8 @@ interface E2EIClientProviderArrangement { suspend fun withE2EIEnabled(isEnabled: Boolean) suspend fun withSelfUser(selfUser: SelfUser?) + + suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) } class E2EIClientProviderArrangementImpl : E2EIClientProviderArrangement { @@ -113,4 +117,10 @@ class E2EIClientProviderArrangementImpl : E2EIClientProviderArrangement { }.returns(selfUser) } + override suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) { + coEvery { + mlsClientProvider.getOrFetchMLSConfig() + }.returns(result.right()) + } + } diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt index b90ae85b826..b5383fef2ba 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/base/authenticated/keypackage/KeyPackageApi.kt @@ -23,18 +23,31 @@ import com.wire.kalium.network.utils.NetworkResponse interface KeyPackageApi { - sealed class Param(open val user: UserId) { + sealed class Param { + + abstract val user: UserId + abstract val cipherSuite: Int + abstract val selfClientId: String? /** * @param user user ID to claim key packages from. * @param selfClientId to skip selfClient key package. */ - data class SkipOwnClient(override val user: UserId, val selfClientId: String) : Param(user) + data class SkipOwnClient( + override val user: UserId, + override val selfClientId: String, + override val cipherSuite: Int + ) : Param() /** * @param user user ID to claim key packages from. */ - data class IncludeOwnClient(override val user: UserId) : Param(user) + data class IncludeOwnClient( + override val user: UserId, + override val cipherSuite: Int, + ) : Param() { + override val selfClientId: String? = null + } } /** diff --git a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt index a7cc33d2c30..844a0419abd 100644 --- a/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt +++ b/network/src/commonMain/kotlin/com/wire/kalium/network/api/v5/authenticated/KeyPackageApiV5.kt @@ -47,9 +47,10 @@ internal open class KeyPackageApiV5 internal constructor( wrapFederationResponse(response, delegatedHandler = { handleUnsuccessfulResponse(response) }) }) { httpClient.post("$PATH_KEY_PACKAGES/$PATH_CLAIM/${param.user.domain}/${param.user.value}") { - if (param is KeyPackageApi.Param.SkipOwnClient) { - parameter(QUERY_SKIP_OWN, param.selfClientId) + param.selfClientId?.let { + parameter(QUERY_SKIP_OWN, it) } + parameter(QUERY_CIPHER_SUITE, param.cipherSuite) } } @@ -84,5 +85,6 @@ internal open class KeyPackageApiV5 internal constructor( const val PATH_SELF = "self" const val PATH_COUNT = "count" const val QUERY_SKIP_OWN = "skip_own" + const val QUERY_CIPHER_SUITE = "ciphersuite" } } diff --git a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt index 73037752e15..8e4b7ff8ef8 100644 --- a/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt +++ b/network/src/commonTest/kotlin/com/wire/kalium/api/v5/KeyPackageApiV5Test.kt @@ -74,11 +74,12 @@ internal class KeyPackageApiV5Test : ApiTest() { assertion = { assertPost() assertPathEqual(KEY_PACKAGE_CLAIM_PATH) + assertQueryParameter("cipherSuite", "$cipherSuite") } ) val keyPackageApi: KeyPackageApi = KeyPackageApiV5(networkClient) - val response = keyPackageApi.claimKeyPackages(KeyPackageApi.Param.IncludeOwnClient(VALID_USER_ID)) + val response = keyPackageApi.claimKeyPackages(KeyPackageApi.Param.IncludeOwnClient(VALID_USER_ID, cipherSuite)) assertTrue(response.isSuccessful()) assertEquals(response.value, VALID_CLAIM_KEY_PACKAGES_RESPONSE.serializableData) } @@ -92,5 +93,6 @@ internal class KeyPackageApiV5Test : ApiTest() { const val KEY_PACKAGE_COUNT_PATH = "/mls/key-packages/self/$VALID_CLIENT_ID/count" const val KEY_PACKAGE_UPLOAD_PATH = "/mls/key-packages/self/$VALID_CLIENT_ID" val KEY_PACKAGE_CLAIM_PATH = "/mls/key-packages/claim/${VALID_USER_ID.domain}/${VALID_USER_ID.value}" + val cipherSuite = 2 } } diff --git a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq index 834663577df..a2d302359b7 100644 --- a/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq +++ b/persistence/src/commonMain/db_user/com/wire/kalium/persistence/Conversations.sq @@ -109,6 +109,11 @@ UPDATE Conversation SET mls_group_state = ? WHERE mls_group_id = ?; +updateMlsGroupStateAndCipherSuite: +UPDATE Conversation +SET mls_group_state = :mls_group_state, mls_cipher_suite = :mls_cipher_suite +WHERE mls_group_id = :mls_group_id; + updateConversationNotificationsDateWithTheLastMessage: UPDATE Conversation SET last_notified_date = ( diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt index 70b2868135f..04184c5d23b 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAO.kt @@ -36,6 +36,11 @@ interface ConversationDAO { suspend fun insertConversations(conversationEntities: List) suspend fun updateConversation(conversationEntity: ConversationEntity) suspend fun updateConversationGroupState(groupState: ConversationEntity.GroupState, groupId: String) + suspend fun updateMlsGroupStateAndCipherSuite( + groupState: ConversationEntity.GroupState, + cipherSuite: ConversationEntity.CipherSuite, + groupId: String + ) suspend fun updateConversationModifiedDate(qualifiedID: QualifiedIDEntity, date: Instant) suspend fun updateConversationNotificationDate(qualifiedID: QualifiedIDEntity) suspend fun updateConversationReadDate(conversationID: QualifiedIDEntity, date: Instant) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt index 6e96747ae7b..f3cd7d22ae8 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationDAOImpl.kt @@ -144,6 +144,14 @@ internal class ConversationDAOImpl internal constructor( conversationQueries.updateConversationGroupState(groupState, groupId) } + override suspend fun updateMlsGroupStateAndCipherSuite( + groupState: ConversationEntity.GroupState, + cipherSuite: ConversationEntity.CipherSuite, + groupId: String + ) = withContext(coroutineContext) { + conversationQueries.updateMlsGroupStateAndCipherSuite(groupState, cipherSuite, groupId) + } + override suspend fun updateConversationModifiedDate(qualifiedID: QualifiedIDEntity, date: Instant) = withContext(coroutineContext) { conversationQueries.updateConversationModifiedDate(date, qualifiedID) } diff --git a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt index ca31990a62b..6de2f15a360 100644 --- a/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt +++ b/persistence/src/commonTest/kotlin/com/wire/kalium/persistence/dao/ConversationDAOTest.kt @@ -266,6 +266,23 @@ class ConversationDAOTest : BaseDatabaseTest() { ) } + @Test + fun givenExistingConversation_ThenConversationGroupStateCanBeUpdatedToEstablished() = runTest { + conversationDAO.insertConversation(conversationEntity2) + conversationDAO.updateMlsGroupStateAndCipherSuite( + ConversationEntity.GroupState.PENDING_WELCOME_MESSAGE, + ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521, + (conversationEntity2.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupId, + + ) + val result = conversationDAO.getConversationByQualifiedID(conversationEntity2.id) + assertEquals( + (result?.protocolInfo as ConversationEntity.ProtocolInfo.MLS).groupState, ConversationEntity.GroupState.PENDING_WELCOME_MESSAGE + ) + assertEquals( + (result?.protocolInfo as ConversationEntity.ProtocolInfo.MLS).cipherSuite, ConversationEntity.CipherSuite.MLS_256_DHKEMP521_AES256GCM_SHA512_P521 + ) + } @Test fun givenExistingConversation_ThenConversationIsUpdatedOnInsert() = runTest { conversationDAO.insertConversation(conversationEntity1)