Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mls client initialization [WPB-15149] #3223

Merged
merged 10 commits into from
Jan 10, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ interface MLSFailure : CoreFailure {
data object StaleProposal : MLSFailure
data object StaleCommit : MLSFailure
data object InternalErrors : MLSFailure
data object Disabled : MLSFailure

data class Generic(internal val exception: Exception) : MLSFailure {
val rootCause: Throwable get() = exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.wire.kalium.cryptography.coreCryptoCentral
import com.wire.kalium.logger.KaliumLogLevel
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.FeatureConfigRepository
Expand Down Expand Up @@ -130,6 +131,10 @@ class MLSClientProviderImpl(
}

override suspend fun getOrFetchMLSConfig(): Either<CoreFailure, SupportedCipherSuite> {
if (!userConfigRepository.isMLSEnabled().getOrElse(true)) {
kaliumLogger.w("$TAG: Cannot fetch MLS config, MLS is disabled.")
return MLSFailure.Disabled.left()
}
return userConfigRepository.getSupportedCipherSuite().flatMapLeft<CoreFailure, SupportedCipherSuite> {
featureConfigRepository.getFeatureConfigs().map {
it.mlsModel.supportedCipherSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,11 @@ internal class ConversationMapperImpl(
private fun ConversationResponse.getProtocolInfo(mlsGroupState: GroupState?): ProtocolInfo {
return when (protocol) {
ConvProtocol.MLS -> ProtocolInfo.MLS(
groupId ?: "",
mlsGroupState ?: GroupState.PENDING_JOIN,
epoch ?: 0UL,
groupId = groupId ?: "",
groupState = mlsGroupState ?: GroupState.PENDING_JOIN,
epoch = epoch ?: 0UL,
keyingMaterialLastUpdate = DateTimeUtil.currentInstant(),
ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag)
cipherSuite = ConversationEntity.CipherSuite.fromTag(mlsCipherSuiteTag)
)

ConvProtocol.MIXED -> ProtocolInfo.Mixed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,18 @@ internal class ConversationDataSource internal constructor(
selfUserTeamId: String?,
originatedFromEvent: Boolean
): Either<CoreFailure, Boolean> = wrapStorageRequest {
val isNewConversation = conversationDAO.getConversationBaseInfoByQualifiedID(conversation.id.toDao()) == null
val existingConversation = conversationDAO.getConversationBaseInfoByQualifiedID(conversation.id.toDao())
val isNewConversation = existingConversation?.let { conversationEntity ->
(conversationEntity.protocolInfo as? ConversationEntity.ProtocolInfo.MLSCapable)?.groupState?.let {
it != ConversationEntity.GroupState.ESTABLISHED
} ?: false
} ?: true
if (isNewConversation) {
val mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) }
conversationDAO.insertConversation(
conversationMapper.fromApiModelToDaoModel(
conversation,
mlsGroupState = conversation.groupId?.let { mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent) },
mlsGroupState = mlsGroupState,
selfTeamIdProvider().getOrNull(),
)
)
Expand All @@ -403,14 +409,12 @@ internal class ConversationDataSource internal constructor(
) = wrapStorageRequest {
val conversationEntities = conversations
.map { conversationResponse ->
val mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(idMapper.fromGroupIDEntity(it), originatedFromEvent)
}
conversationMapper.fromApiModelToDaoModel(
conversationResponse,
mlsGroupState = conversationResponse.groupId?.let {
mlsGroupState(
idMapper.fromGroupIDEntity(it),
originatedFromEvent
)
},
mlsGroupState = mlsGroupState,
selfTeamIdProvider().getOrNull(),
)
}
Expand All @@ -432,9 +436,14 @@ internal class ConversationDataSource internal constructor(
}
}

private suspend fun mlsGroupState(groupId: GroupID, originatedFromEvent: Boolean = false): ConversationEntity.GroupState =
hasEstablishedMLSGroup(groupId).fold({
throw IllegalStateException(it.toString()) // TODO find a more fitting exception?
private suspend fun mlsGroupState(
groupId: GroupID,
originatedFromEvent: Boolean = false
): ConversationEntity.GroupState = hasEstablishedMLSGroup(groupId)
.fold({ failure ->
kaliumLogger.withFeatureId(CONVERSATIONS)
.w("Error checking MLS group state, setting to ${ConversationEntity.GroupState.PENDING_JOIN}")
ConversationEntity.GroupState.PENDING_JOIN
}, { exists ->
if (exists) {
ConversationEntity.GroupState.ESTABLISHED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,17 +691,23 @@ internal class MLSConversationDataSource(
})

override suspend fun getClientIdentity(clientId: ClientId) =
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }.flatMap {
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
wrapStorageRequest { conversationDAO.getE2EIConversationClientInfoByClientId(clientId.value) }
.flatMap { conversationClientInfo ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {

mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).firstOrNull()
mlsClient.getDeviceIdentities(
conversationClientInfo.mlsGroupId,
listOf(
CryptoQualifiedClientId(
conversationClientInfo.clientId,
conversationClientInfo.userId.toModel().toCrypto()
)
)
).firstOrNull()
}
}
}
}

override suspend fun getUserIdentity(userId: UserId) =
wrapStorageRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,8 @@ class UserSessionScope internal constructor(
cachedClientIdClearer,
updateSupportedProtocolsAndResolveOneOnOnes,
registerMLSClientUseCase,
syncFeatureConfigsUseCase
syncFeatureConfigsUseCase,
userConfigRepository
)
val conversations: ConversationScope by lazy {
ConversationScope(
Expand Down Expand Up @@ -1896,7 +1897,7 @@ class UserSessionScope internal constructor(

@OptIn(DelicateKaliumApi::class)
private val isAllowedToRegisterMLSClient: IsAllowedToRegisterMLSClientUseCase
get() = IsAllowedToRegisterMLSClientUseCaseImpl(featureSupport, mlsPublicKeysRepository)
get() = IsAllowedToRegisterMLSClientUseCaseImpl(featureSupport, mlsPublicKeysRepository, userConfigRepository)

private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase
get() = SyncFeatureConfigsUseCaseImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.configuration.notification.NotificationTokenRepository
import com.wire.kalium.logic.data.auth.verification.SecondFactorVerificationRepository
import com.wire.kalium.logic.data.client.ClientRepository
Expand Down Expand Up @@ -71,7 +72,8 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
private val cachedClientIdClearer: CachedClientIdClearer,
private val updateSupportedProtocolsAndResolveOneOnOnes: UpdateSupportedProtocolsAndResolveOneOnOnesUseCase,
private val registerMLSClientUseCase: RegisterMLSClientUseCase,
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase
private val syncFeatureConfigsUseCase: SyncFeatureConfigsUseCase,
private val userConfigRepository: UserConfigRepository
) {
@OptIn(DelicateKaliumApi::class)
val register: RegisterClientUseCase
Expand Down Expand Up @@ -102,7 +104,7 @@ class ClientScope @OptIn(DelicateKaliumApi::class) internal constructor(
val deregisterNativePushToken: DeregisterTokenUseCase
get() = DeregisterTokenUseCaseImpl(clientRepository, notificationTokenRepository)
val mlsKeyPackageCountUseCase: MLSKeyPackageCountUseCase
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider)
get() = MLSKeyPackageCountUseCaseImpl(keyPackageRepository, clientIdProvider, keyPackageLimitsProvider, userConfigRepository)
val restartSlowSyncProcessForRecoveryUseCase: RestartSlowSyncProcessForRecoveryUseCase
get() = RestartSlowSyncProcessForRecoveryUseCaseImpl(slowSyncRepository)
val refillKeyPackages: RefillKeyPackagesUseCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

package com.wire.kalium.logic.feature.client

import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.functional.isRight
import com.wire.kalium.util.DelicateKaliumApi

Expand All @@ -39,8 +41,12 @@ interface IsAllowedToRegisterMLSClientUseCase {
internal class IsAllowedToRegisterMLSClientUseCaseImpl(
private val featureSupport: FeatureSupport,
private val mlsPublicKeysRepository: MLSPublicKeysRepository,
private val userConfigRepository: UserConfigRepository
) : IsAllowedToRegisterMLSClientUseCase {

override suspend operator fun invoke(): Boolean =
featureSupport.isMLSSupported && mlsPublicKeysRepository.getKeys().isRight()
override suspend operator fun invoke(): Boolean {
return featureSupport.isMLSSupported
&& userConfigRepository.isMLSEnabled().getOrElse(false)
&& mlsPublicKeysRepository.getKeys().isRight()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ package com.wire.kalium.logic.feature.keypackage

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse

/**
* This use case will return the current number of key packages.
Expand All @@ -37,6 +39,7 @@ internal class MLSKeyPackageCountUseCaseImpl(
private val keyPackageRepository: KeyPackageRepository,
private val currentClientIdProvider: CurrentClientIdProvider,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val userConfigRepository: UserConfigRepository
) : MLSKeyPackageCountUseCase {
override suspend operator fun invoke(fromAPI: Boolean): MLSKeyPackageCountResult =
when (fromAPI) {
Expand All @@ -47,10 +50,15 @@ internal class MLSKeyPackageCountUseCaseImpl(
private suspend fun validKeyPackagesCountFromAPI() = currentClientIdProvider().fold({
MLSKeyPackageCountResult.Failure.FetchClientIdFailure(it)
}, { selfClient ->
keyPackageRepository.getAvailableKeyPackageCount(selfClient).fold(
{
MLSKeyPackageCountResult.Failure.NetworkCallFailure(it)
}, { MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) })
if (userConfigRepository.isMLSEnabled().getOrElse(false)) {
keyPackageRepository.getAvailableKeyPackageCount(selfClient)
.fold(
{ MLSKeyPackageCountResult.Failure.NetworkCallFailure(it) },
{ MLSKeyPackageCountResult.Success(selfClient, it.count, keyPackageLimitsProvider.needsRefill(it.count)) }
)
} else {
MLSKeyPackageCountResult.Failure.NotEnabled
}
})

private suspend fun validKeyPackagesCountFromMLSClient() =
Expand All @@ -70,6 +78,7 @@ sealed class MLSKeyPackageCountResult {
sealed class Failure : MLSKeyPackageCountResult() {
class NetworkCallFailure(val networkFailure: NetworkFailure) : Failure()
class FetchClientIdFailure(val genericFailure: CoreFailure) : Failure()
data object NotEnabled : Failure()
data class Generic(val genericFailure: CoreFailure) : Failure()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ internal object MLSMessageFailureHandler {
is MLSFailure.StaleCommit -> MLSMessageFailureResolution.Ignore
is MLSFailure.MessageEpochTooOld -> MLSMessageFailureResolution.Ignore
is MLSFailure.InternalErrors -> MLSMessageFailureResolution.Ignore
is MLSFailure.Disabled -> MLSMessageFailureResolution.Ignore
else -> MLSMessageFailureResolution.InformUser
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package com.wire.kalium.logic.data.client

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.featureConfig.FeatureConfigTest
import com.wire.kalium.logic.data.featureConfig.MLSModel
Expand All @@ -32,12 +33,16 @@ import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepository
import com.wire.kalium.logic.util.arrangement.repository.FeatureConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.UserConfigRepositoryArrangementImpl
import com.wire.kalium.logic.util.shouldFail
import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.persistence.dbPassphrase.PassphraseStorage
import io.ktor.util.reflect.instanceOf
import io.mockative.Mock
import io.mockative.mock
import io.mockative.once
import io.mockative.verify
import io.mockative.verify
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
Expand All @@ -61,13 +66,18 @@ class MLSClientProviderTest {

val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(StorageFailure.DataNotFound.left())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel =expected).right())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel(mlsModel = expected).right())
withGetMLSEnabledReturning(true.right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected.supportedCipherSuite, it)
}

verify(arrangement.userConfigRepository)
.function(arrangement.userConfigRepository::isMLSEnabled)
.wasInvoked(exactly = once)

verify(arrangement.userConfigRepository)
.suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite)
.wasInvoked(exactly = once)
Expand All @@ -89,12 +99,18 @@ class MLSClientProviderTest {

val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetSupportedCipherSuitesReturning(expected.right())
withGetMLSEnabledReturning(true.right())
withGetFeatureConfigsReturning(FeatureConfigTest.newModel().right())
}

mlsClientProvider.getOrFetchMLSConfig().shouldSucceed {
assertEquals(expected, it)
}

verify(arrangement.userConfigRepository)
.function(arrangement.userConfigRepository::isMLSEnabled)
.wasInvoked(exactly = once)

verify(arrangement.userConfigRepository)
.suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite)
.wasInvoked(exactly = once)
Expand All @@ -104,6 +120,39 @@ class MLSClientProviderTest {
.wasNotInvoked()
}

@Test
fun givenMLSDisabledWhenGetOrFetchMLSConfigIsCalledThenDoNotCallGetSupportedCipherSuiteOrGetFeatureConfigs() = runTest {
// given
val (arrangement, mlsClientProvider) = Arrangement().arrange {
withGetMLSEnabledReturning(false.right())
withGetSupportedCipherSuitesReturning(
SupportedCipherSuite(
supported = listOf(
CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256,
CipherSuite.MLS_256_DHKEMP384_AES256GCM_SHA384_P384
),
default = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
).right()
)
}

// when
val result = mlsClientProvider.getOrFetchMLSConfig()

// then
result.shouldFail { //TODO check
it.instanceOf(CoreFailure.Unknown::class)
}

verify(arrangement.userConfigRepository)
.suspendFunction(arrangement.userConfigRepository::getSupportedCipherSuite)
.wasNotInvoked()

verify(arrangement.featureConfigRepository)
.suspendFunction(arrangement.featureConfigRepository::getFeatureConfigs)
.wasNotInvoked()
}

private class Arrangement : UserConfigRepositoryArrangement by UserConfigRepositoryArrangementImpl(),
FeatureConfigRepositoryArrangement by FeatureConfigRepositoryArrangementImpl() {

Expand Down
Loading
Loading