Skip to content

Commit

Permalink
fix: mls 1on1 race condition [WPB-15395] (#3237) (#3239)
Browse files Browse the repository at this point in the history
* fix: race condition during 1on1 mls creation, more logs

* scope init fix

* test fix

* added ignore on welcome message when conv already exist, reverted wipe on member join

* added warning

Co-authored-by: Jakub Żerko <[email protected]>
  • Loading branch information
github-actions[bot] and Garzas authored Jan 24, 2025
1 parent 7456f22 commit 7f85d16
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
if (!featureSupport.isMLSSupported ||
!clientRepository.hasRegisteredMLSClient().getOrElse(false)
) {
kaliumLogger.d("Skip re-join existing MLS conversation, since MLS is not supported.")
kaliumLogger.d("$TAG: Skip re-join existing MLS conversation, since MLS is not supported.")
Either.Right(Unit)
} else {
conversationRepository.getConversationById(conversationId).fold({
Expand Down Expand Up @@ -115,7 +115,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}
}
} else if (failure.kaliumException.isMlsMissingGroupInfo()) {
kaliumLogger.w("conversation has no group info, ignoring...")
kaliumLogger.w("$TAG: conversation has no group info, ignoring...")
Either.Right(Unit)
} else {
Either.Left(failure)
Expand All @@ -135,6 +135,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
protocol.epoch != 0UL -> {
// TODO(refactor): don't use conversationAPI directly
// we could use mlsConversationRepository to solve this
kaliumLogger.d("$TAG: Joining group by external commit ${conversation.id.toLogString()}")
wrapApiRequest {
conversationApi.fetchGroupInfo(conversation.id.toApi())
}.flatMap { groupInfo ->
Expand Down Expand Up @@ -185,6 +186,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

type == Conversation.Type.SELF -> {
kaliumLogger.d("$TAG: Establish Self MLS Conversation ${conversation.id.toLogString()}")
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
emptyList()
Expand All @@ -203,6 +205,7 @@ internal class JoinExistingMLSConversationUseCaseImpl(
}

type == Conversation.Type.ONE_ON_ONE -> {
kaliumLogger.d("$TAG: Establish 1on1 MLS Conversation ${conversation.id.toLogString()}")
conversationRepository.getConversationMembers(conversation.id).flatMap { members ->
mlsConversationRepository.establishMLSGroup(
protocol.groupId,
Expand All @@ -226,4 +229,8 @@ internal class JoinExistingMLSConversationUseCaseImpl(
else -> Either.Right(Unit)
}
}

companion object {
private const val TAG = "[JoinExistingMLSConversationUseCase]"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ internal class MLSConversationDataSource(
private suspend fun sendCommitBundle(groupID: GroupID, bundle: CommitBundle): Either<CoreFailure, Unit> {
return mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapApiRequest {
kaliumLogger.d("Sending commit bundle for ${groupID.toLogString()}")
mlsMessageApi.sendCommitBundle(mlsCommitBundleMapper.toDTO(bundle))
}.flatMap { response ->
processCommitBundleEvents(response.events)
Expand Down Expand Up @@ -376,6 +377,7 @@ internal class MLSConversationDataSource(
}

private suspend fun processCommitBundleEvents(events: List<EventContentDTO>) {
kaliumLogger.d("Processing commit bundle events")
events.forEach { eventContentDTO ->
val event =
MapperProvider.eventMapper(selfUserId).fromEventContentDTO(
Expand Down Expand Up @@ -454,7 +456,8 @@ internal class MLSConversationDataSource(
retryOnStaleMessage = true,
allowPartialMemberList = false,
cipherSuite = cipherSuite
).map { Unit }
)
.map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
Expand All @@ -464,7 +467,7 @@ internal class MLSConversationDataSource(
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
kaliumLogger.d("adding ${userIdList.count()} users to MLS group")
kaliumLogger.d("adding ${userIdList.count()} users to MLS group ${groupID.toLogString()}")
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList, cipherSuite).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Expand All @@ -485,12 +488,15 @@ internal class MLSConversationDataSource(
// We are creating a group with only our self client which technically
// doesn't need be added with a commit, but our backend API requires one,
// so we create a commit by updating our key material.
kaliumLogger.d("add members to MLS Group: updating keying material for self client")
updateKeyingMaterial(idMapper.toCryptoModel(groupID))
} else {
kaliumLogger.d("add members to MLS Group: executing for groupID ${groupID.toLogString()}")
addMember(idMapper.toCryptoModel(groupID), clientKeyPackageList)
}
}.onSuccess { commitBundle ->
commitBundle?.crlNewDistributionPoints?.let { revocationList ->
kaliumLogger.d("add members to MLS Group: checking revocation list")
checkRevocationList(revocationList)
}
}.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,7 @@ class UserSessionScope internal constructor(
clientIdProvider,
messages.messageSender,
teamRepository,
slowSyncRepository,
userId,
selfConversationIdProvider,
persistMessage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import com.wire.kalium.logic.data.id.SelfTeamIdProvider
import com.wire.kalium.logic.data.message.MessageRepository
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.properties.UserPropertyRepository
import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.team.TeamRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
Expand Down Expand Up @@ -106,6 +107,7 @@ class ConversationScope internal constructor(
private val currentClientIdProvider: CurrentClientIdProvider,
private val messageSender: MessageSender,
private val teamRepository: TeamRepository,
private val slowSyncRepository: SlowSyncRepository,
private val selfUserId: UserId,
private val selfConversationIdProvider: SelfConversationIdProvider,
private val persistMessage: PersistMessageUseCase,
Expand Down Expand Up @@ -163,6 +165,7 @@ class ConversationScope internal constructor(
oneOnOneResolver,
conversationRepository,
deleteEphemeralMessageEndDate,
slowSyncRepository,
kaliumLogger
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.data.conversation.ConversationDetails
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.sync.SlowSyncStatus
import com.wire.kalium.logic.feature.conversation.mls.OneOnOneResolver
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessagesAfterEndDateUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

/**
* Used by the UI to notify Kalium that a conversation is open.
Expand All @@ -45,12 +51,26 @@ internal class NotifyConversationIsOpenUseCaseImpl(
private val oneOnOneResolver: OneOnOneResolver,
private val conversationRepository: ConversationRepository,
private val deleteEphemeralMessageEndDate: DeleteEphemeralMessagesAfterEndDateUseCase,
private val kaliumLogger: KaliumLogger
private val slowSyncRepository: SlowSyncRepository,
private val kaliumLogger: KaliumLogger,
private val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) : NotifyConversationIsOpenUseCase {

override suspend operator fun invoke(conversationId: ConversationId) {
override suspend operator fun invoke(conversationId: ConversationId) = withContext(dispatcher.io) {
val ephemeralCleanupJob = launch {
kaliumLogger.v("$TAG: Starting ephemeral messages deletion in background")
deleteEphemeralMessageEndDate()
}

val slowSyncStatus = slowSyncRepository.slowSyncStatus.first()

if (slowSyncStatus != SlowSyncStatus.Complete) {
kaliumLogger.v("$TAG: Slow sync is not completed yet, skipping further steps")
return@withContext
}

kaliumLogger.v(
"Notifying that conversation with ID: ${conversationId.toLogString()} is open"
"$TAG: Notifying that conversation with ID: ${conversationId.toLogString()} is open"
)
val conversation = conversationRepository.observeConversationDetailsById(conversationId)
.filterIsInstance<Either.Right<ConversationDetails>>()
Expand All @@ -59,15 +79,18 @@ internal class NotifyConversationIsOpenUseCaseImpl(

if (conversation is ConversationDetails.OneOne) {
kaliumLogger.v(
"Reevaluating protocol for 1:1 conversation with ID: ${conversationId.toLogString()}"
"$TAG: Reevaluating protocol for 1:1 conversation with ID: ${conversationId.toLogString()}"
)
oneOnOneResolver.resolveOneOnOneConversationWithUser(
user = conversation.otherUser,
invalidateCurrentKnownProtocols = true
)
}

// Delete Ephemeral Messages that has passed the end date
deleteEphemeralMessageEndDate()
ephemeralCleanupJob.join()
}

companion object {
private const val TAG = "[NotifyConversationIsOpenUseCase]"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

package com.wire.kalium.logic.sync.receiver.conversation

import com.wire.kalium.logger.obfuscateId
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationDetails
Expand All @@ -33,6 +35,7 @@ import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesResult
import com.wire.kalium.logic.feature.keypackage.RefillKeyPackagesUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
Expand Down Expand Up @@ -61,22 +64,34 @@ internal class MLSWelcomeEventHandlerImpl(
mlsClientProvider.getMLSClient()
}
.flatMap { client ->
kaliumLogger.d("$TAG: Processing MLS welcome message")
wrapMLSRequest {
client.processWelcomeMessage(event.message.decodeBase64Bytes())
}
}.flatMap { welcomeBundle ->
welcomeBundle.crlNewDistributionPoints?.let {
kaliumLogger.d("$TAG: checking revocation list")
checkRevocationList(it)
}
kaliumLogger.d("$TAG: Marking conversation as established ${welcomeBundle.groupId.obfuscateId()}")
markConversationAsEstablished(GroupID(welcomeBundle.groupId))
}.flatMap {
kaliumLogger.d("$TAG: Resolving conversation if one-on-one ${event.conversationId.toLogString()}")
resolveConversationIfOneOnOne(event.conversationId)
}
.flatMapLeft {
if (it is MLSFailure.ConversationAlreadyExists) {
kaliumLogger.w("$TAG: Discarding welcome since the conversation already exists")
Either.Right(Unit)
} else {
Either.Left(it)
}
}
.onSuccess {
val didSucceedRefillingKeyPackages = when (val refillResult = refillKeyPackages()) {
is RefillKeyPackagesResult.Failure -> {
val exception = (refillResult.failure as? CoreFailure.Unknown)?.rootCause
kaliumLogger.w("Failed to refill key packages; Failure: ${refillResult.failure}", exception)
kaliumLogger.w("$TAG: Failed to refill key packages; Failure: ${refillResult.failure}", exception)
false
}

Expand Down Expand Up @@ -119,4 +134,8 @@ internal class MLSWelcomeEventHandlerImpl(
}
}

companion object {
private const val TAG = "[MLSWelcomeEventHandler]"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package com.wire.kalium.logic.feature.conversation

import com.wire.kalium.logic.data.sync.SlowSyncRepository
import com.wire.kalium.logic.data.sync.SlowSyncStatus
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessagesAfterEndDateUseCase
import com.wire.kalium.logic.framework.TestConversationDetails
import com.wire.kalium.logic.functional.Either
Expand All @@ -30,8 +32,11 @@ import io.mockative.any
import io.mockative.coEvery
import io.mockative.coVerify
import io.mockative.eq
import io.mockative.every
import io.mockative.mock
import io.mockative.once
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import kotlin.test.Test

Expand Down Expand Up @@ -102,19 +107,29 @@ class NotifyConversationIsOpenUseCaseTest {
@Mock
private val deleteEphemeralMessageEndDate = mock(DeleteEphemeralMessagesAfterEndDateUseCase::class)

@Mock
private val slowSyncRepository = mock(SlowSyncRepository::class)

suspend fun withDeleteEphemeralMessageEndDateSuccess() {
coEvery {
deleteEphemeralMessageEndDate.invoke()
}.returns(Unit)
}

init {
every {
slowSyncRepository.slowSyncStatus
}.returns(MutableStateFlow(SlowSyncStatus.Complete))
}

suspend fun arrange(): Pair<Arrangement, NotifyConversationIsOpenUseCase> = run {
configure()
this@Arrangement to NotifyConversationIsOpenUseCaseImpl(
oneOnOneResolver = oneOnOneResolver,
conversationRepository = conversationRepository,
kaliumLogger = kaliumLogger,
deleteEphemeralMessageEndDate = deleteEphemeralMessageEndDate
deleteEphemeralMessageEndDate = deleteEphemeralMessageEndDate,
slowSyncRepository = slowSyncRepository
)
}
}
Expand Down

0 comments on commit 7f85d16

Please sign in to comment.