Skip to content

Commit

Permalink
feat: set the correct cipher suite when claiming key packages [WPB-85…
Browse files Browse the repository at this point in the history
…92] 🍒 (#2746)

* Commit with unresolved merge conflicts

* Commit with unresolved merge conflicts

* Commit with unresolved merge conflicts

* Commit with unresolved merge conflicts

* Commit with unresolved merge conflicts

* fix tests

* detekt

* Trigger CI

Signed-off-by: MohamadJaara <[email protected]>

* Trigger CI

Signed-off-by: MohamadJaara <[email protected]>

* detekt

* test

* test

* BaseProteusClientTest

* fix merge issues

* fix merge issues

* fix merge issues

* fix merge issues

* fix merge issues

* fix test

* detekt

* fix tests

* fix tests

---------

Signed-off-by: MohamadJaara <[email protected]>
Co-authored-by: Mohamad Jaara <[email protected]>
  • Loading branch information
github-actions[bot] and MohamadJaara authored May 22, 2024
1 parent c6a9c30 commit d9132de
Show file tree
Hide file tree
Showing 19 changed files with 361 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ class MLSClientTest : BaseMLSClientTest() {
}

private suspend fun createClient(user: SampleUser): MLSClient {
return createMLSClient(user.qualifiedClientId, allowedCipherSuites = ALLOWED_CIPHER_SUITES, DEFAULT_CIPHER_SUITES)
return createMLSClient(
clientId = user.qualifiedClientId,
allowedCipherSuites = ALLOWED_CIPHER_SUITES,
defaultCipherSuite = DEFAULT_CIPHER_SUITES
)
}

@Test
fun givemMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest {
fun givenMlsClient_whenCallingGetDefaultCipherSuite_ReturnExpectedValue() = runTest {
val mlsClient = createClient(ALICE1)
assertEquals(DEFAULT_CIPHER_SUITES, mlsClient.getDefaultCipherSuite())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.message.MessageContent.MemberChange.FailedToAdd
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.service.ServiceId
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
Expand Down Expand Up @@ -130,20 +131,13 @@ internal class ConversationGroupRepositoryImpl(
}

when (apiResult) {
is Either.Left -> {
val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None
if (canRetryOnce) {
extractValidUsersForRetryableError(apiResult.value, usersList)
.flatMap { (validUsers, failedUsers, failType) ->
// edge case, in case backend goes 🍌 and returns non-matching domains
if (failedUsers.isEmpty()) Either.Left(apiResult.value)

createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType))
}
} else {
Either.Left(apiResult.value)
}
}
is Either.Left -> handleCreateConverstionFailure(
apiResult = apiResult,
usersList = usersList,
name = name,
options = options,
lastUsersAttempt = lastUsersAttempt
)

is Either.Right -> handleGroupConversationCreated(apiResult.value, selfTeamId, usersList, lastUsersAttempt)
}
Expand Down Expand Up @@ -210,6 +204,27 @@ internal class ConversationGroupRepositoryImpl(
}
}

private suspend fun handleCreateConverstionFailure(
apiResult: Either.Left<NetworkFailure>,
usersList: List<UserId>,
name: String?,
options: ConversationOptions,
lastUsersAttempt: LastUsersAttempt
): Either<CoreFailure, Conversation> {
val canRetryOnce = apiResult.value.hasUnreachableDomainsError && lastUsersAttempt is LastUsersAttempt.None
return if (canRetryOnce) {
extractValidUsersForRetryableError(apiResult.value, usersList)
.flatMap { (validUsers, failedUsers, failType) ->
// edge case, in case backend goes 🍌 and returns non-matching domains
if (failedUsers.isEmpty()) Either.Left(apiResult.value)

createGroupConversation(name, validUsers, options, LastUsersAttempt.Failed(failedUsers, failType))
}
} else {
Either.Left(apiResult.value)
}
}

override suspend fun addMembers(
userIdList: List<UserId>,
conversationId: ConversationId
Expand All @@ -224,11 +239,21 @@ internal class ConversationGroupRepositoryImpl(
tryAddMembersToCloudAndStorage(userIdList, conversationId, LastUsersAttempt.None)
.flatMap {
// best effort approach for migrated conversations, no retries
mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), userIdList)
mlsConversationRepository.addMemberToMLSGroup(
GroupID(protocol.groupId),
userIdList,
CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}

is ConversationEntity.ProtocolInfo.MLS -> {
tryAddMembersToMLSGroup(conversationId, protocol.groupId, userIdList, LastUsersAttempt.None)
tryAddMembersToMLSGroup(
conversationId,
protocol.groupId,
userIdList,
LastUsersAttempt.None,
cipherSuite = CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}
}
}
Expand All @@ -237,14 +262,22 @@ internal class ConversationGroupRepositoryImpl(
* Handle the error cases and retry for claimPackages offline and out of packages.
* Handle error case and retry for sendingCommit unreachable or missing legal hold consent.
*/
@Suppress("LongMethod")
private suspend fun tryAddMembersToMLSGroup(
conversationId: ConversationId,
groupId: String,
userIdList: List<UserId>,
lastUsersAttempt: LastUsersAttempt,
cipherSuite: CipherSuite,
remainingAttempts: Int = 2
): Either<CoreFailure, Unit> {
return when (val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup(GroupID(groupId), userIdList)) {
return when (
val addingMemberResult = mlsConversationRepository.addMemberToMLSGroup(
GroupID(groupId),
userIdList,
cipherSuite
)
) {
is Either.Right -> handleMLSMembersNotAdded(conversationId, lastUsersAttempt)
is Either.Left -> {
addingMemberResult.value.handleMLSMembersFailed(
Expand All @@ -253,17 +286,20 @@ internal class ConversationGroupRepositoryImpl(
userIdList = userIdList,
lastUsersAttempt = lastUsersAttempt,
remainingAttempts = remainingAttempts,
cipherSuite = cipherSuite
)
}
}
}

@Suppress("LongMethod")
private suspend fun CoreFailure.handleMLSMembersFailed(
conversationId: ConversationId,
groupId: String,
userIdList: List<UserId>,
lastUsersAttempt: LastUsersAttempt,
remainingAttempts: Int,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> {
return when {
// claiming key packages offline or out of packages
Expand All @@ -277,7 +313,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.Federation,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}

Expand All @@ -292,7 +329,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.Federation,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}

Expand All @@ -308,7 +346,8 @@ internal class ConversationGroupRepositoryImpl(
failedUsers = lastUsersAttempt.failedUsers + failedUsers,
failType = FailedToAdd.Type.LegalHold,
),
remainingAttempts = remainingAttempts - 1
remainingAttempts = remainingAttempts - 1,
cipherSuite = cipherSuite
)
}
}
Expand Down Expand Up @@ -479,7 +518,11 @@ internal class ConversationGroupRepositoryImpl(

is ConversationEntity.ProtocolInfo.MLSCapable -> {
joinExistingMLSConversation(conversationId).flatMap {
mlsConversationRepository.addMemberToMLSGroup(GroupID(protocol.groupId), listOf(selfUserId))
mlsConversationRepository.addMemberToMLSGroup(
GroupID(protocol.groupId),
listOf(selfUserId),
CipherSuite.fromTag(protocol.cipherSuite.cipherSuiteTag)
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ interface MLSConversationRepository {

suspend fun establishMLSSubConversationGroup(groupID: GroupID, parentId: ConversationId): Either<CoreFailure, Unit>
suspend fun hasEstablishedMLSGroup(groupID: GroupID): Either<CoreFailure, Boolean>
suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit>
suspend fun addMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit>

suspend fun removeMembersFromMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit>
suspend fun removeClientsFromMLSGroup(groupID: GroupID, clientIdList: List<QualifiedClientID>): Either<CoreFailure, Unit>
suspend fun leaveGroup(groupID: GroupID): Either<CoreFailure, Unit>
Expand Down Expand Up @@ -202,7 +207,7 @@ private fun CoreFailure.getStrategy(

// TODO: refactor this repository as it's doing too much.
// A Repository should be a dummy class that get and set some values
@Suppress("TooManyFunctions", "LongParameterList")
@Suppress("TooManyFunctions", "LongParameterList", "LargeClass")
internal class MLSConversationDataSource(
private val selfUserId: UserId,
private val keyPackageRepository: KeyPackageRepository,
Expand Down Expand Up @@ -448,23 +453,29 @@ internal class MLSConversationDataSource(
conversationDAO.getProposalTimers().map { it.map(conversationMapper::fromDaoModel) }.flatten()
)

override suspend fun addMemberToMLSGroup(groupID: GroupID, userIdList: List<UserId>): Either<CoreFailure, Unit> =
override suspend fun addMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, Unit> =
internalAddMemberToMLSGroup(
groupID = groupID,
userIdList = userIdList,
retryOnStaleMessage = true,
allowPartialMemberList = false
allowPartialMemberList = false,
cipherSuite = cipherSuite
).map { Unit }

private suspend fun internalAddMemberToMLSGroup(
groupID: GroupID,
userIdList: List<UserId>,
retryOnStaleMessage: Boolean,
cipherSuite: CipherSuite,
allowPartialMemberList: Boolean = false,
): Either<CoreFailure, MLSAdditionResult> = withContext(serialDispatcher) {
commitPendingProposals(groupID).flatMap {
produceAndSendCommitWithRetryAndResult(groupID, retryOnStaleMessage = retryOnStaleMessage) {
keyPackageRepository.claimKeyPackages(userIdList).flatMap { result ->
keyPackageRepository.claimKeyPackages(userIdList, cipherSuite).flatMap { result ->
if (result.usersWithoutKeyPackagesAvailable.isNotEmpty() && !allowPartialMemberList) {
Either.Left(CoreFailure.MissingKeyPackages(result.usersWithoutKeyPackagesAvailable))
} else {
Expand Down Expand Up @@ -606,7 +617,8 @@ internal class MLSConversationDataSource(
groupID = groupID,
userIdList = members,
retryOnStaleMessage = false,
allowPartialMemberList = allowPartialMemberList
allowPartialMemberList = allowPartialMemberList,
cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
).onFailure {
wrapMLSRequest {
mlsClient.wipeConversation(groupID.toCrypto())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.mls.KeyPackageClaimResult
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -50,7 +51,10 @@ interface KeyPackageRepository {
* available. If the operation fails, it will be [Either.Left] with a [CoreFailure] object indicating the reason for the failure.
* If **no** KeyPackages are available, [CoreFailure.MissingKeyPackages] will be the cause.
*/
suspend fun claimKeyPackages(userIds: List<UserId>): Either<CoreFailure, KeyPackageClaimResult>
suspend fun claimKeyPackages(
userIds: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, KeyPackageClaimResult>

suspend fun uploadNewKeyPackages(clientId: ClientId, amount: Int = 100): Either<CoreFailure, Unit>

Expand All @@ -61,7 +65,6 @@ interface KeyPackageRepository {
suspend fun getAvailableKeyPackageCount(clientId: ClientId): Either<NetworkFailure, KeyPackageCountDTO>

suspend fun validKeyPackageCount(clientId: ClientId): Either<CoreFailure, Int>

}

class KeyPackageDataSource(
Expand All @@ -71,13 +74,22 @@ class KeyPackageDataSource(
private val selfUserId: UserId,
) : KeyPackageRepository {

override suspend fun claimKeyPackages(userIds: List<UserId>): Either<CoreFailure, KeyPackageClaimResult> =
override suspend fun claimKeyPackages(
userIds: List<UserId>,
cipherSuite: CipherSuite
): Either<CoreFailure, KeyPackageClaimResult> =
currentClientIdProvider().flatMap { selfClientId ->
val failedUsers = mutableSetOf<UserId>()
val claimedKeyPackages = mutableListOf<KeyPackageDTO>()
userIds.forEach { userId ->
wrapApiRequest {
keyPackageApi.claimKeyPackages(KeyPackageApi.Param.SkipOwnClient(userId.toApi(), selfClientId.value))
keyPackageApi.claimKeyPackages(
KeyPackageApi.Param.SkipOwnClient(
userId.toApi(),
selfClientId.value,
cipherSuite = cipherSuite.tag
)
)
}.fold({ failedUsers.add(userId) }) {
if (it.keyPackages.isEmpty() && userId != selfUserId) {
failedUsers.add(userId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ internal class MLSMigratorImpl(
mlsConversationRepository.establishMLSGroup(protocolInfo.groupId, emptyList())
.flatMap {
conversationRepository.getConversationMembers(conversationId).flatMap { members ->
mlsConversationRepository.addMemberToMLSGroup(protocolInfo.groupId, members)
mlsConversationRepository.addMemberToMLSGroup(
protocolInfo.groupId,
members,
protocolInfo.cipherSuite
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class E2EIClientProviderTest {
return this to e2eiClientProvider
}

suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) {
override suspend fun withGetOrFetchMLSConfig(result: SupportedCipherSuite) {
coEvery { mlsClientProvider.getOrFetchMLSConfig() }.returns(result.right())
}
}
Expand Down
Loading

0 comments on commit d9132de

Please sign in to comment.