diff --git a/.github/workflows/gradle-ios-tests.yml b/.github/workflows/gradle-ios-tests.yml index 18b0e775adc..e5f2d180a21 100644 --- a/.github/workflows/gradle-ios-tests.yml +++ b/.github/workflows/gradle-ios-tests.yml @@ -1,9 +1,9 @@ name: "iOS Tests" on: - merge_group: - pull_request: - types: [ opened, synchronize ] # Don't rerun on `edited` to save time +# merge_group: +# pull_request: +# types: [ opened, synchronize ] # Don't rerun on `edited` to save time concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number }} @@ -14,7 +14,7 @@ jobs: uses: ./.github/workflows/codestyle.yml gradle-run-tests: needs: [detekt] - runs-on: macos-12 + runs-on: macos-latest steps: - name: Checkout diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt index b4ced5084a0..fb147a76018 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt @@ -198,6 +198,7 @@ interface MLSFailure : CoreFailure { data object StaleProposal : MLSFailure data object StaleCommit : MLSFailure data object InternalErrors : MLSFailure + data object Disabled : MLSFailure data class Generic(internal val exception: Exception) : MLSFailure { val rootCause: Throwable get() = exception diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt index ce242232fb1..3109d07ac78 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/client/MLSClientProvider.kt @@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral import com.wire.kalium.logger.KaliumLogLevel import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.E2EIFailure +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository @@ -130,6 +131,10 @@ class MLSClientProviderImpl( } override suspend fun getOrFetchMLSConfig(): Either { + if (!userConfigRepository.isMLSEnabled().getOrElse(true)) { + kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.") + return MLSFailure.Disabled.left() + } return userConfigRepository.getSupportedCipherSuite().flatMapLeft { featureConfigRepository.getFeatureConfigs().map { it.mlsModel.supportedCipherSuite diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt index 4c47e7fa4c4..000d60e7e37 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationMapper.kt @@ -427,11 +427,11 @@ internal class ConversationMapperImpl( private fun ConversationResponse.getProtocolInfo(mlsGroupState: GroupState?): ProtocolInfo { return when (protocol) { ConvProtocol.MLS -> ProtocolInfo.MLS( - groupId ?: "", - mlsGroupState ?: GroupState.PENDING_JOIN, - epoch ?: 0UL, + groupId = groupId ?: "", + groupState = mlsGroupState ?: GroupState.PENDING_JOIN, + epoch = epoch ?: 0UL, keyingMaterialLastUpdate = DateTimeUtil.currentInstant(), - ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag) + cipherSuite = ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag) ) ConvProtocol.MIXED -> ProtocolInfo.Mixed( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt index b5686b8475b..905be49fe24 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepository.kt @@ -379,12 +379,18 @@ internal class ConversationDataSource internal constructor( selfUserTeamId: String?, originatedFromEvent: Boolean ): Either = wrapStorageRequest { - val isNewConversation = conversationDAO.getConversationBaseInfoByQualifiedID(conversation.id.toDao()) == null + val existingConversation = conversationDAO.getConversationBaseInfoByQualifiedID(conversation.id.toDao()) + val isNewConversation = existingConversation?.let { conversationEntity -> + (conversationEntity.protocolInfo as? ConversationEntity.ProtocolInfo.MLSCapable)?.groupState?.let { + it != ConversationEntity.GroupState.ESTABLISHED + } ?: false + } ?: true if (isNewConversation) { + val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) } conversationDAO.insertConversation( conversationMapper.fromApiModelToDaoModel( conversation, - mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }, + mlsGroupState = mlsGroupState, selfTeamIdProvider().getOrNull(), ) ) @@ -403,14 +409,12 @@ internal class ConversationDataSource internal constructor( ) = wrapStorageRequest { val conversationEntities = conversations .map { conversationResponse -> + val mlsGroupState = conversationResponse.groupId?.let { + mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) + } conversationMapper.fromApiModelToDaoModel( conversationResponse, - mlsGroupState = conversationResponse.groupId?.let { - mlsGroupState( - idMapper.fromGroupIDEntity(it), - originatedFromEvent - ) - }, + mlsGroupState = mlsGroupState, selfTeamIdProvider().getOrNull(), ) } @@ -432,9 +436,14 @@ internal class ConversationDataSource internal constructor( } } - private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState = - hasEstablishedMLSGroup(groupId).fold({ - throw IllegalStateException(it.toString()) // TODO find a more fitting exception? + private suspend fun mlsGroupState( + groupId: GroupID, + originatedFromEvent: Boolean = false + ): ConversationEntity.GroupState = hasEstablishedMLSGroup(groupId) + .fold({ failure -> + kaliumLogger.withFeatureId(CONVERSATIONS) + .w("Error checking MLS group state, setting to ${ConversationEntity.GroupState.PENDING_JOIN}") + ConversationEntity.GroupState.PENDING_JOIN }, { exists -> if (exists) { ConversationEntity.GroupState.ESTABLISHED diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt index a36555eef9c..69fabfe459f 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/conversation/MLSConversationRepository.kt @@ -691,17 +691,23 @@ internal class MLSConversationDataSource( }) override suspend fun getClientIdentity(clientId: ClientId) = - wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap { - mlsClientProvider.getMLSClient().flatMap { mlsClient -> - wrapMLSRequest { + wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) } + .flatMap { conversationClientInfo -> + mlsClientProvider.getMLSClient().flatMap { mlsClient -> + wrapMLSRequest { - mlsClient.getDeviceIdentities( - it.mlsGroupId, - listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto())) - ).firstOrNull() + mlsClient.getDeviceIdentities( + conversationClientInfo.mlsGroupId, + listOf( + CryptoQualifiedClientId( + conversationClientInfo.clientId, + conversationClientInfo.userId.toModel().toCrypto() + ) + ) + ).firstOrNull() + } } } - } override suspend fun getUserIdentity(userId: UserId) = wrapStorageRequest { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 376b4fb5b11..d42449f319a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1674,7 +1674,8 @@ class UserSessionScope internal constructor( cachedClientIdClearer, updateSupportedProtocolsAndResolveOneOnOnes, registerMLSClientUseCase, - syncFeatureConfigsUseCase + syncFeatureConfigsUseCase, + userConfigRepository ) val conversations: ConversationScope by lazy { ConversationScope( @@ -1896,7 +1897,7 @@ class UserSessionScope internal constructor( @OptIn(DelicateKaliumApi::class) private val isAllowedToRegisterMLSClient: IsAllowedToRegisterMLSClientUseCase - get() = IsAllowedToRegisterMLSClientUseCaseImpl(featureSupport, mlsPublicKeysRepository) + get() = IsAllowedToRegisterMLSClientUseCaseImpl(featureSupport, mlsPublicKeysRepository, userConfigRepository) private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase get() = SyncFeatureConfigsUseCaseImpl( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt index d8e292e2708..25d9793c5f9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/ClientScope.kt @@ -18,6 +18,7 @@ package com.wire.kalium.logic.feature.client +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository import com.wire.kalium.logic.data.client.ClientRepository @@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( private val cachedClientIdClearer: CachedClientIdClearer, private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase, private val registerMLSClientUseCase: RegisterMLSClientUseCase, - private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase + private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase, + private val userConfigRepository: UserConfigRepository ) { @OptIn(DelicateKaliumApi::class) val register: RegisterClientUseCase @@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor( val deregisterNativePushToken: DeregisterTokenUseCase get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository) val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase - get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider) + get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository) val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository) val refillKeyPackages: RefillKeyPackagesUseCase diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt index 838619af9cb..e41f8f92111 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/client/IsAllowedToRegisterMLSClientUseCase.kt @@ -18,8 +18,10 @@ package com.wire.kalium.logic.feature.client +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.functional.getOrElse import com.wire.kalium.logic.functional.isRight import com.wire.kalium.util.DelicateKaliumApi @@ -39,8 +41,12 @@ interface IsAllowedToRegisterMLSClientUseCase { internal class IsAllowedToRegisterMLSClientUseCaseImpl( private val featureSupport: FeatureSupport, private val mlsPublicKeysRepository: MLSPublicKeysRepository, + private val userConfigRepository: UserConfigRepository ) : IsAllowedToRegisterMLSClientUseCase { - override suspend operator fun invoke(): Boolean = - featureSupport.isMLSSupported && mlsPublicKeysRepository.getKeys().isRight() + override suspend operator fun invoke(): Boolean { + return featureSupport.isMLSSupported + && userConfigRepository.isMLSEnabled().getOrElse(false) + && mlsPublicKeysRepository.getKeys().isRight() + } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt index 46647c3ed64..ab4528d6f9a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCase.kt @@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.NetworkFailure +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider import com.wire.kalium.logic.data.keypackage.KeyPackageRepository import com.wire.kalium.logic.data.id.CurrentClientIdProvider import com.wire.kalium.logic.functional.fold +import com.wire.kalium.logic.functional.getOrElse /** * This use case will return the current number of key packages. @@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl( private val keyPackageRepository: KeyPackageRepository, private val currentClientIdProvider: CurrentClientIdProvider, private val keyPackageLimitsProvider: KeyPackageLimitsProvider, + private val userConfigRepository: UserConfigRepository ) : MLSKeyPackageCountUseCase { override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult = when (fromAPI) { @@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl( private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({ MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it) }, { selfClient -> - keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold( - { - MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) - }, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }) + if (userConfigRepository.isMLSEnabled().getOrElse(false)) { + keyPackageRepository.getAvailableKeyPackageCount(selfClient) + .fold( + { MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) }, + { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) } + ) + } else { + MLSKeyPackageCountResult.Failure.NotEnabled + } }) private suspend fun validKeyPackagesCountFromMLSClient() = @@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult { sealed class Failure : MLSKeyPackageCountResult() { class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure() class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure() + data object NotEnabled : Failure() data class Generic(val genericFailure: CoreFailure) : Failure() } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt index 88f29b61b65..f1f5998a491 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSMessageFailureHandler.kt @@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler { is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore + is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore else -> MLSMessageFailureResolution.InformUser } } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt index b42b9bba356..8a333157316 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/client/MLSClientProviderTest.kt @@ -17,6 +17,7 @@ */ package com.wire.kalium.logic.data.client +import com.wire.kalium.logic.CoreFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest import com.wire.kalium.logic.data.featureConfig.MLSModel @@ -32,12 +33,16 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl +import com.wire.kalium.logic.util.shouldFail import com.wire.kalium.logic.util.shouldSucceed import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage +import io.ktor.util.reflect.instanceOf import io.mockative.Mock import io.mockative.mock import io.mockative.once import io.mockative.verify +import io.mockative.verify +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals @@ -61,13 +66,18 @@ class MLSClientProviderTest { val (arrangement, mlsClientProvider) = Arrangement().arrange { withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left()) - withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel =expected).right()) + withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right()) + withGetMLSEnabledReturning(true.right()) } mlsClientProvider.getOrFetchMLSConfig().shouldSucceed { assertEquals(expected.supportedCipherSuite, it) } + verify(arrangement.userConfigRepository) + .function(arrangement.userConfigRepository::isMLSEnabled) + .wasInvoked(exactly = once) + verify(arrangement.userConfigRepository) .suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite) .wasInvoked(exactly = once) @@ -89,12 +99,18 @@ class MLSClientProviderTest { val (arrangement, mlsClientProvider) = Arrangement().arrange { withGetSupportedCipherSuitesReturning(expected.right()) + withGetMLSEnabledReturning(true.right()) + withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right()) } mlsClientProvider.getOrFetchMLSConfig().shouldSucceed { assertEquals(expected, it) } + verify(arrangement.userConfigRepository) + .function(arrangement.userConfigRepository::isMLSEnabled) + .wasInvoked(exactly = once) + verify(arrangement.userConfigRepository) .suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite) .wasInvoked(exactly = once) @@ -104,6 +120,39 @@ class MLSClientProviderTest { .wasNotInvoked() } + @Test + fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest { + // given + val (arrangement, mlsClientProvider) = Arrangement().arrange { + withGetMLSEnabledReturning(false.right()) + withGetSupportedCipherSuitesReturning( + SupportedCipherSuite( + supported = listOf( + CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256, + CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384 + ), + default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 + ).right() + ) + } + + // when + val result = mlsClientProvider.getOrFetchMLSConfig() + + // then + result.shouldFail { //TODO check + it.instanceOf(CoreFailure.Unknown::class) + } + + verify(arrangement.userConfigRepository) + .suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite) + .wasNotInvoked() + + verify(arrangement.featureConfigRepository) + .suspendFunction(arrangement.featureConfigRepository::getFeatureConfigs) + .wasNotInvoked() + } + private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(), FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt index 5b69a2db008..4f52a88da63 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/data/conversation/ConversationRepositoryTest.kt @@ -20,6 +20,7 @@ package com.wire.kalium.logic.data.conversation import app.cash.turbine.test import com.wire.kalium.cryptography.MLSClient +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.event.Event @@ -708,16 +709,16 @@ class ConversationRepositoryTest { val shouldFetchFromArchivedConversations = false val messagePreviewEntity = MESSAGE_PREVIEW_ENTITY.copy(conversationId = conversationIdEntity) - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - id = conversationIdEntity, - type = ConversationEntity.Type.GROUP, - ) + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + id = conversationIdEntity, + type = ConversationEntity.Type.GROUP, + ) - val unreadMessagesCount = 5 - val conversationUnreadEventEntity = ConversationUnreadEventEntity( - conversationIdEntity, - mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) - ) + val unreadMessagesCount = 5 + val conversationUnreadEventEntity = ConversationUnreadEventEntity( + conversationIdEntity, + mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) + ) val (_, conversationRepository) = Arrangement() .withConversations(listOf(conversationEntity)) @@ -729,8 +730,8 @@ class ConversationRepositoryTest { conversationRepository.observeConversationListDetails(shouldFetchFromArchivedConversations).test { val result = awaitItem() - assertContains(result.map { it.conversation.id }, conversationId) - val conversation = result.first { it.conversation.id == conversationId } + assertContains(result.map { it.conversation.id }, conversationId) + val conversation = result.first { it.conversation.id == conversationId } assertIs(conversation) assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) @@ -739,9 +740,9 @@ class ConversationRepositoryTest { conversation.lastMessage ) - awaitComplete() - } + awaitComplete() } + } @Test fun givenArchivedConversationHasNewMessages_whenGettingConversationDetails_ThenCorrectlyGetUnreadMessageCountAndNullLastMessage() = @@ -789,21 +790,21 @@ class ConversationRepositoryTest { val conversationEntity = TestConversation.VIEW_ENTITY.copy( type = ConversationEntity.Type.GROUP, ) - val (_, conversationRepository) = Arrangement() - .withExpectedObservableConversation(conversationEntity) - .arrange() + val (_, conversationRepository) = Arrangement() + .withExpectedObservableConversation(conversationEntity) + .arrange() - // when - conversationRepository.observeConversationDetailsById(TestConversation.ID).test { - // then - val conversationDetail = awaitItem() + // when + conversationRepository.observeConversationDetailsById(TestConversation.ID).test { + // then + val conversationDetail = awaitItem() - assertIs>(conversationDetail) - assertTrue { conversationDetail.value.lastMessage == null } + assertIs>(conversationDetail) + assertTrue { conversationDetail.value.lastMessage == null } - awaitComplete() - } + awaitComplete() } + } @Test fun givenAOneToOneConversationHasNotNewMessages_whenGettingConversationDetails_ThenReturnZeroUnreadMessageCount() = @@ -837,36 +838,36 @@ class ConversationRepositoryTest { val conversationId = QualifiedID("some_value", "some_domain") val shouldFetchFromArchivedConversations = false - val conversationEntity = TestConversation.VIEW_ENTITY.copy( - id = conversationIdEntity, type = ConversationEntity.Type.ONE_ON_ONE, - otherUserId = QualifiedIDEntity("otherUser", "domain") - ) + val conversationEntity = TestConversation.VIEW_ENTITY.copy( + id = conversationIdEntity, type = ConversationEntity.Type.ONE_ON_ONE, + otherUserId = QualifiedIDEntity("otherUser", "domain") + ) - val unreadMessagesCount = 5 - val conversationUnreadEventEntity = ConversationUnreadEventEntity( - conversationIdEntity, - mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) - ) + val unreadMessagesCount = 5 + val conversationUnreadEventEntity = ConversationUnreadEventEntity( + conversationIdEntity, + mapOf(UnreadEventTypeEntity.MESSAGE to unreadMessagesCount) + ) + + val (_, conversationRepository) = Arrangement() + .withConversations(listOf(conversationEntity)) + .withLastMessages(listOf()) + .withConversationUnreadEvents(listOf(conversationUnreadEventEntity)) + .arrange() - val (_, conversationRepository) = Arrangement() - .withConversations(listOf(conversationEntity)) - .withLastMessages(listOf()) - .withConversationUnreadEvents(listOf(conversationUnreadEventEntity)) - .arrange() - // when conversationRepository.observeConversationListDetails(shouldFetchFromArchivedConversations).test { val result = awaitItem() - assertContains(result.map { it.conversation.id }, conversationId) - val conversation = result.first { it.conversation.id == conversationId } + assertContains(result.map { it.conversation.id }, conversationId) + val conversation = result.first { it.conversation.id == conversationId } - assertIs(conversation) - assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) + assertIs(conversation) + assertEquals(conversation.unreadEventCount[UnreadEventType.MESSAGE], unreadMessagesCount) - awaitComplete() - } + awaitComplete() } + } @Test fun givenAConversationDaoFailed_whenUpdatingTheConversationReadDate_thenShouldNotSucceed() = runTest { @@ -1796,6 +1797,13 @@ class ConversationRepositoryTest { .thenReturn(updated) } + suspend fun withDisabledMlsClientProvider() = apply { + given(mlsClientProvider) + .suspendFunction(mlsClientProvider::getMLSClient) + .whenInvokedWith(any()) + .thenReturn(Either.Left(MLSFailure.Disabled)) + } + fun arrange() = this to conversationRepository } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt index 7d21769a8d3..74ac2a59f6b 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/keypackage/MLSKeyPackageCountUseCaseTest.kt @@ -31,6 +31,9 @@ import com.wire.kalium.logic.feature.keypackage.MLSKeyPackageCountUseCaseTest.Ar import com.wire.kalium.logic.framework.TestClient import com.wire.kalium.logic.functional.Either import com.wire.kalium.network.api.base.authenticated.keypackage.KeyPackageCountDTO +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl import io.mockative.Mock import io.mockative.anything import io.mockative.classOf @@ -39,20 +42,21 @@ import io.mockative.given import io.mockative.mock import io.mockative.once import io.mockative.verify -import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertIs -@OptIn(ExperimentalCoroutinesApi::class) class MLSKeyPackageCountUseCaseTest { @Test fun givenClientIdIsNotRegistered_ThenReturnGenericError() = runTest { val (arrangement, keyPackageCountUseCase) = Arrangement() .withClientId(Either.Left(CLIENT_FETCH_ERROR)) - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -71,7 +75,9 @@ class MLSKeyPackageCountUseCaseTest { .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) .withClientId(Either.Right(TestClient.CLIENT_ID)) .withKeyPackageLimitSucceed() - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -88,7 +94,9 @@ class MLSKeyPackageCountUseCaseTest { val (arrangement, keyPackageCountUseCase) = Arrangement() .withAvailableKeyPackageCountReturn(Either.Left(NETWORK_FAILURE)) .withClientId(Either.Right(TestClient.CLIENT_ID)) - .arrange() + .arrange{ + withGetMLSEnabledReturning(true.right()) + } val actual = keyPackageCountUseCase() @@ -100,7 +108,30 @@ class MLSKeyPackageCountUseCaseTest { assertEquals(actual.networkFailure, NETWORK_FAILURE) } - private class Arrangement { + @Test + fun givenClientID_whenCallingGetMLSEnabledReturnFalse_ThenReturnKeyPackageCountNotEnabledFailure() = runTest { + val (arrangement, keyPackageCountUseCase) = Arrangement() + .withAvailableKeyPackageCountReturn(Either.Right(KEY_PACKAGE_COUNT_DTO)) + .withClientId(Either.Right(TestClient.CLIENT_ID)) + .arrange{ + withGetMLSEnabledReturning(false.right()) + } + + val actual = keyPackageCountUseCase() + + verify(arrangement.userConfigRepository) + .function(arrangement.userConfigRepository::isMLSEnabled) + .wasInvoked(once) + + verify(arrangement.keyPackageRepository) + .suspendFunction(arrangement.keyPackageRepository::getAvailableKeyPackageCount) + .with(eq(TestClient.CLIENT_ID)) + .wasNotInvoked() + + assertIs(actual) + } + + private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl() { @Mock val keyPackageRepository = mock(classOf()) @@ -129,9 +160,11 @@ class MLSKeyPackageCountUseCaseTest { .then { result } } - fun arrange() = this to MLSKeyPackageCountUseCaseImpl( - keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider - ) + fun arrange(block: suspend Arrangement.() -> Unit) = apply { runBlocking { block() } }.let { + this to MLSKeyPackageCountUseCaseImpl( + keyPackageRepository, currentClientIdProvider, keyPackageLimitsProvider, userConfigRepository + ) + } companion object { val NETWORK_FAILURE = NetworkFailure.NoNetworkConnection(null) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt index 5b1cbe45133..53818e71c14 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/mls/OneOnOneMigratorArrangement.kt @@ -63,3 +63,4 @@ class OneOnOneMigratorArrangementImpl : OneOnOneMigratorArrangement { .thenReturn(result) } } +//, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt index 5fb5428a080..5f6e264b906 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt @@ -36,6 +36,7 @@ internal interface UserConfigRepositoryArrangement { fun withSetDefaultProtocolSuccessful() fun withGetDefaultProtocolReturning(result: Either) fun withSetMLSEnabledSuccessful() + fun withGetMLSEnabledReturning(result: Either) fun withSetMigrationConfigurationSuccessful() fun withGetMigrationConfigurationReturning(result: Either) fun withSetSupportedCipherSuite(result: Either) @@ -81,6 +82,13 @@ internal class UserConfigRepositoryArrangementImpl : UserConfigRepositoryArrange .thenReturn(Either.Right(Unit)) } + override fun withGetMLSEnabledReturning(result: Either) { + given(userConfigRepository) + .function(userConfigRepository::isMLSEnabled) + .whenInvoked() + .thenReturn(result) + } + override fun withSetMigrationConfigurationSuccessful() { given(userConfigRepository) .suspendFunction(userConfigRepository::setMigrationConfiguration) diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt index c1ef9973477..6f68b683f53 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/conversation/ConversationEntity.kt @@ -75,7 +75,7 @@ data class ConversationEntity( companion object { fun fromTag(tag: Int?): CipherSuite = - if (tag != null) values().first { type -> type.cipherSuiteTag == tag } else UNKNOWN + if (tag != null) entries.first { type -> type.cipherSuiteTag == tag } else UNKNOWN } }