diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 7e48ab6d557..6dcaea89a30 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1184,7 +1184,7 @@ class UserSessionScope internal constructor( internal val mlsMigrationManager: MLSMigrationManager = MLSMigrationManagerImpl( kaliumConfigs, - featureSupport, + isMLSEnabled, incrementalSyncRepository, lazy { clientRepository }, lazy { users.timestampKeyRepository }, @@ -1632,7 +1632,8 @@ class UserSessionScope internal constructor( private val oneOnOneProtocolSelector: OneOnOneProtocolSelector get() = OneOnOneProtocolSelectorImpl( - userRepository + userRepository, + userConfigRepository ) private val acmeCertificatesSyncWorker: ACMECertificatesSyncWorker by lazy { diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt index 44cf79afca8..0a546e6b4d1 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManager.kt @@ -25,7 +25,7 @@ import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncStatus import com.wire.kalium.logic.feature.TimestampKeyRepository import com.wire.kalium.logic.feature.TimestampKeys -import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.feature.user.IsMLSEnabledUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap @@ -50,7 +50,7 @@ internal interface MLSMigrationManager @Suppress("LongParameterList") internal class MLSMigrationManagerImpl( private val kaliumConfigs: KaliumConfigs, - private val featureSupport: FeatureSupport, + private val isMLSEnabledUseCase: IsMLSEnabledUseCase, private val incrementalSyncRepository: IncrementalSyncRepository, private val clientRepository: Lazy, private val timestampKeyRepository: Lazy, @@ -73,7 +73,7 @@ internal class MLSMigrationManagerImpl( incrementalSyncRepository.incrementalSyncState.collect { syncState -> ensureActive() if (syncState is IncrementalSyncStatus.Live && - featureSupport.isMLSSupported && + isMLSEnabledUseCase() && clientRepository.value.hasRegisteredMLSClient().getOrElse(false) ) { updateMigration() diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt index 2e8a3ab82d0..3350536d6b9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorker.kt @@ -42,7 +42,7 @@ internal class MLSMigrationWorkerImpl( override suspend fun runMigration() = syncMigrationConfigurations().flatMap { userConfigRepository.getMigrationConfiguration().getOrNull()?.let { configuration -> - if (configuration.hasMigrationStarted()) { + if (configuration.status.toBoolean() && configuration.hasMigrationStarted()) { kaliumLogger.i("Running proteus to MLS migration") mlsMigrator.migrateProteusConversations().flatMap { if (configuration.hasMigrationEnded()) { @@ -57,7 +57,6 @@ internal class MLSMigrationWorkerImpl( } } ?: Either.Right(Unit) } - private suspend fun syncMigrationConfigurations(): Either = featureConfigRepository.getFeatureConfigs().flatMap { configurations -> mlsConfigHandler.handle(configurations.mlsModel, duringSlowSync = false) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt index eb878a7cc09..d0608b8abdd 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelector.kt @@ -18,18 +18,21 @@ package com.wire.kalium.logic.feature.protocol import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.fold internal interface OneOnOneProtocolSelector { suspend fun getProtocolForUser(userId: UserId): Either } internal class OneOnOneProtocolSelectorImpl( - private val userRepository: UserRepository + private val userRepository: UserRepository, + private val userConfigRepository: UserConfigRepository ) : OneOnOneProtocolSelector { override suspend fun getProtocolForUser(userId: UserId): Either = userRepository.userById(userId).flatMap { otherUser -> @@ -40,8 +43,11 @@ internal class OneOnOneProtocolSelectorImpl( val selfUserProtocols = selfUser.supportedProtocols.orEmpty() val otherUserProtocols = otherUser.supportedProtocols.orEmpty() - - val commonProtocols = selfUserProtocols.intersect(otherUserProtocols) + val commonProtocols = userConfigRepository.getDefaultProtocol().fold({ + selfUserProtocols.intersect(otherUserProtocols) + }, { + selfUserProtocols.intersect(listOf(it).toSet()).intersect(otherUserProtocols) + }) return when { commonProtocols.contains(SupportedProtocol.MLS) -> Either.Right(SupportedProtocol.MLS) diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt index 35bcdb0ca00..0b7395bde90 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationManagerTest.kt @@ -23,7 +23,7 @@ import com.wire.kalium.logic.data.sync.IncrementalSyncRepository import com.wire.kalium.logic.data.sync.IncrementalSyncStatus import com.wire.kalium.logic.feature.TimestampKeyRepository import com.wire.kalium.logic.feature.TimestampKeys -import com.wire.kalium.logic.featureFlags.FeatureSupport +import com.wire.kalium.logic.feature.user.IsMLSEnabledUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.test_util.TestKaliumDispatcher @@ -121,7 +121,7 @@ class MLSMigrationManagerTest { val clientRepository = mock(ClientRepository::class) @Mock - val featureSupport = mock(FeatureSupport::class) + val isMLSEnabledUseCase = mock(IsMLSEnabledUseCase::class) @Mock val timestampKeyRepository = mock(TimestampKeyRepository::class) @@ -149,8 +149,9 @@ class MLSMigrationManagerTest { fun withIsMLSSupported(supported: Boolean) = apply { every { - featureSupport.isMLSSupported + isMLSEnabledUseCase() }.returns(supported) + } suspend fun withHasRegisteredMLSClient(result: Boolean) = apply { @@ -161,7 +162,7 @@ class MLSMigrationManagerTest { fun arrange() = this to MLSMigrationManagerImpl( kaliumConfigs, - featureSupport, + isMLSEnabledUseCase, incrementalSyncRepository, lazy { clientRepository }, lazy { timestampKeyRepository }, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt new file mode 100644 index 00000000000..3a6c9754066 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/mlsmigration/MLSMigrationWorkerTest.kt @@ -0,0 +1,282 @@ +/* + * Wire + * Copyright (C) 2024 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.mlsmigration + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.configuration.UserConfigRepository +import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository +import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest +import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel +import com.wire.kalium.logic.data.featureConfig.MLSModel +import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.mls.SupportedCipherSuite +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.feature.featureConfig.handler.MLSConfigHandler +import com.wire.kalium.logic.feature.featureConfig.handler.MLSMigrationConfigHandler +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.MIGRATION_CONFIG +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.NOT_FOUND_FAILURE +import com.wire.kalium.logic.feature.mlsmigration.MLSMigrationWorkerTest.Arrangement.Companion.TEST_FAILURE +import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsAndResolveOneOnOnesUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.left +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.shouldFail +import com.wire.kalium.logic.util.shouldSucceed +import io.mockative.Mock +import io.mockative.any +import io.mockative.classOf +import io.mockative.coEvery +import io.mockative.coVerify +import io.mockative.every +import io.mockative.mock +import io.mockative.once +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.Test + +class MLSMigrationWorkerTest { + @Test + fun givenGettingMigrationConfigurationFails_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns(NOT_FOUND_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenMigrationIsDisabled_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement() + .withGetMLSMigrationConfigurationsReturns(MIGRATION_CONFIG.copy(status = Status.DISABLED).right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenMigrationIsEnabledButNotStarted_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenMigrationIsDisabledButStarted_whenRunningMigration_workerReturnsNoFailure() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, status = Status.DISABLED).right() + ).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenMigrationIsEnabledAndStartedAndProteusMigrationFails_whenRunningMigration_thenWorkerShouldFail() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseAllProteusConversations() }.wasNotInvoked() + coVerify { arrangement.mlsMigrator.finaliseProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasNotEnded_whenRunningMigration_thenWorkerShouldSucceed() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseProteusConversations(Unit.right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseAllProteusConversations() }.wasNotInvoked() + coVerify { arrangement.mlsMigrator.finaliseProteusConversations() }.wasInvoked(once) + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasNotEndedAndFinaliseProteusConversationsFails_whenRunningMigration_thenWorkerShouldFail() = + runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_FUTURE, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseProteusConversations(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseAllProteusConversations() }.wasNotInvoked() + coVerify { arrangement.mlsMigrator.finaliseProteusConversations() }.wasInvoked(once) + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasEnded_whenRunningMigration_thenWorkerShouldSucceed() = runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_PAST, status = Status.ENABLED).right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseAllProteusConversations(Unit.right()).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldSucceed() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseAllProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseProteusConversations() }.wasNotInvoked() + } + + @Test + fun givenProteusMigrationSucceedAndMigrationHasEndedAndFinaliseAllProteusConversationsFails_whenRunningMigration_thenWorkerShouldFail() = + runTest { + // given + val (arrangement, mlsMigrationWorker) = Arrangement().withGetMLSMigrationConfigurationsReturns( + MIGRATION_CONFIG.copy(startTime = Instant.DISTANT_PAST, endTime = Instant.DISTANT_PAST, status = Status.ENABLED) + .right() + ).withMigrateProteusConversationsReturn(Unit.right()).withFinaliseAllProteusConversations(TEST_FAILURE).arrange() + + // when + val result = mlsMigrationWorker.runMigration() + + // then + result.shouldFail() + + coVerify { arrangement.userConfigRepository.getMigrationConfiguration() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.migrateProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseAllProteusConversations() }.wasInvoked(once) + coVerify { arrangement.mlsMigrator.finaliseProteusConversations() }.wasNotInvoked() + } + + private class Arrangement { + @Mock + val userConfigRepository: UserConfigRepository = mock(classOf()) + + @Mock + val featureConfigRepository: FeatureConfigRepository = mock(classOf()) + + @Mock + val updateSupportedProtocolsAndResolveOneOnOnes = mock(classOf()) + + @Mock + val mlsMigrator: MLSMigrator = mock(classOf()) + + val mlsConfigHandler = MLSConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes) + + val mlsMigrationConfigHandler = MLSMigrationConfigHandler(userConfigRepository, updateSupportedProtocolsAndResolveOneOnOnes) + + suspend fun withGetMLSMigrationConfigurationsReturns(result: Either) = apply { + coEvery { userConfigRepository.getMigrationConfiguration() }.returns(result) + } + + suspend fun withMigrateProteusConversationsReturn(result: Either) = apply { + coEvery { mlsMigrator.migrateProteusConversations() }.returns(result) + } + + suspend fun withFinaliseAllProteusConversations(result: Either) = apply { + coEvery { mlsMigrator.finaliseAllProteusConversations() }.returns(result) + } + + suspend fun withFinaliseProteusConversations(result: Either) = apply { + coEvery { mlsMigrator.finaliseProteusConversations() }.returns(result) + } + + init { + runBlocking { + coEvery { featureConfigRepository.getFeatureConfigs() }.returns(FeatureConfigTest.newModel().right()) + every { userConfigRepository.setMLSEnabled(any()) }.returns(Unit.right()) + coEvery { userConfigRepository.getSupportedProtocols() }.returns(NOT_FOUND_FAILURE) + every { userConfigRepository.setDefaultProtocol(any()) }.returns(Unit.right()) + coEvery { userConfigRepository.setSupportedProtocols(any>()) }.returns(Unit.right()) + coEvery { userConfigRepository.setSupportedCipherSuite(any()) }.returns(Unit.right()) + coEvery { userConfigRepository.setMigrationConfiguration(any()) }.returns(Unit.right()) + } + } + + fun arrange() = this to MLSMigrationWorkerImpl( + userConfigRepository, featureConfigRepository, mlsConfigHandler, mlsMigrationConfigHandler, mlsMigrator + ) + + companion object { + val TEST_FAILURE = CoreFailure.Unknown(Throwable("Testing!")).left() + val NOT_FOUND_FAILURE = StorageFailure.DataNotFound.left() + val MLS_CONFIG = MLSModel( + defaultProtocol = SupportedProtocol.MLS, + supportedProtocols = setOf(SupportedProtocol.PROTEUS), + status = Status.ENABLED, + supportedCipherSuite = null + ) + + val MIGRATION_CONFIG = MLSMigrationModel( + startTime = null, endTime = null, status = Status.ENABLED + ) + } + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt index d546676932a..134e4af7086 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/protocol/OneOnOneProtocolSelectorTest.kt @@ -22,6 +22,10 @@ import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.framework.TestUser import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.left +import com.wire.kalium.logic.functional.right +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement +import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangementImpl import com.wire.kalium.logic.util.shouldFail @@ -39,9 +43,11 @@ class OneOnOneProtocolSelectorTest { @Test fun givenSelfUserIsNull_thenShouldReturnFailure() = runTest { + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withUserByIdReturning(Either.Right(TestUser.OTHER)) withSelfUserReturning(null) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -56,7 +62,8 @@ class OneOnOneProtocolSelectorTest { val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF) - withUserByIdReturning(Either.Left(failure)) + withUserByIdReturning(failure.left()) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -71,6 +78,7 @@ class OneOnOneProtocolSelectorTest { val (arrangement, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF) withUserByIdReturning(Either.Left(failure)) + withGetDefaultProtocolReturning(failure.left()) } val otherUserId = TestUser.USER_ID oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -82,10 +90,12 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportProteusAndMLS_thenShouldPreferMLS() = runTest { + val failure = StorageFailure.DataNotFound val supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) val (arrangement, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = supportedProtocols)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = supportedProtocols))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -97,9 +107,11 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportProteusAndOnlyOneSupportsMLS_thenShouldPreferProteus() = runTest { val bothProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS) + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = bothProtocols)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -110,10 +122,12 @@ class OneOnOneProtocolSelectorTest { @Test fun givenBothUsersSupportMLS_thenShouldPreferMLS() = runTest { + val failure = StorageFailure.DataNotFound val mlsSet = setOf(SupportedProtocol.MLS) val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = mlsSet)) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = mlsSet))) + withGetDefaultProtocolReturning(failure.left()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -124,9 +138,53 @@ class OneOnOneProtocolSelectorTest { @Test fun givenUsersHaveNoProtocolInCommon_thenShouldReturnNoCommonProtocol() = runTest { + val failure = StorageFailure.DataNotFound val (_, oneOnOneProtocolSelector) = arrange { withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(failure.left()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenUsersHaveProtocolInCommonButDiffersWithDefaultProtocol_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenSelfUserSupportsDefaultProtocolButOtherUserDoesnt_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldFail { + assertIs(it) + } + } + + @Test + fun givenSelfUserDoesntSupportsDefaultProtocolButOtherUserDoes_thenShouldReturnNoCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS)))) + withGetDefaultProtocolReturning(SupportedProtocol.PROTEUS.right()) } oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) @@ -135,16 +193,30 @@ class OneOnOneProtocolSelectorTest { } } + @Test + fun givenUsersHaveProtocolInCommonIncludingDefaultProtocol_thenShouldReturnDefaultProtocolAsCommonProtocol() = runTest { + val (_, oneOnOneProtocolSelector) = arrange { + withSelfUserReturning(TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))) + withUserByIdReturning(Either.Right(TestUser.OTHER.copy(supportedProtocols = setOf(SupportedProtocol.MLS)))) + withGetDefaultProtocolReturning(SupportedProtocol.MLS.right()) + } + + oneOnOneProtocolSelector.getProtocolForUser(TestUser.USER_ID) + .shouldSucceed() { + assertEquals(SupportedProtocol.MLS, it) + } + } + private class Arrangement(private val configure: suspend Arrangement.() -> Unit) : - UserRepositoryArrangement by UserRepositoryArrangementImpl() { - fun arrange(): Pair = run { - runBlocking { configure() } - this@Arrangement to OneOnOneProtocolSelectorImpl(userRepository) + UserRepositoryArrangement by UserRepositoryArrangementImpl(), + UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl() { + suspend fun arrange(): Pair = run { + configure() + this@Arrangement to OneOnOneProtocolSelectorImpl(userRepository, userConfigRepository) } } private companion object { - fun arrange(configure: suspend Arrangement.() -> Unit) = Arrangement(configure).arrange() + fun arrange(configure: suspend Arrangement.() -> Unit) = runBlocking { Arrangement(configure).arrange() } } - } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt index 1a6f7dc27f1..6d09e3884b8 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/arrangement/repository/UserConfigRepositoryArrangement.kt @@ -35,6 +35,7 @@ internal interface UserConfigRepositoryArrangement { suspend fun withGetSupportedProtocolsReturning(result: Either>) suspend fun withSetSupportedProtocolsSuccessful() fun withSetDefaultProtocolSuccessful() + fun withGetDefaultProtocolReturning(result: Either) fun withSetMLSEnabledSuccessful() suspend fun withSetMigrationConfigurationSuccessful() suspend fun withGetMigrationConfigurationReturning(result: Either) @@ -64,6 +65,10 @@ internal class UserConfigRepositoryArrangementImpl : UserConfigRepositoryArrange }.returns(Either.Right(Unit)) } + override fun withGetDefaultProtocolReturning(result: Either) { + every { userConfigRepository.getDefaultProtocol() }.returns(result) + } + override fun withSetMLSEnabledSuccessful() { every { userConfigRepository.setMLSEnabled(any())