Skip to content

Commit

Permalink
feat(e2ei): respect E2EI during login and MLS client creation (WPB-58…
Browse files Browse the repository at this point in the history
…51) (#2633)

Co-authored-by: Mojtaba Chenani <[email protected]>
  • Loading branch information
AndroidBob and mchenani authored Feb 6, 2024
1 parent 1889535 commit 7c6148c
Show file tree
Hide file tree
Showing 34 changed files with 608 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.android.di

import com.wire.kalium.logic.CoreLogic
import com.wire.kalium.logic.data.user.UserId
import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject

class ObserveIfE2EIRequiredDuringLoginUseCaseProvider @AssistedInject constructor(
@KaliumCoreLogic private val coreLogic: CoreLogic,
@Assisted
private val userId: UserId
) {
suspend fun observeIfE2EIIsRequiredDuringLogin() = coreLogic.getSessionScope(userId).observeIfE2EIRequiredDuringLogin()

@AssistedFactory
interface Factory {
fun create(userId: UserId): ObserveIfE2EIRequiredDuringLoginUseCaseProvider
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.asset.DeleteAssetUseCase
import com.wire.kalium.logic.feature.asset.GetAssetSizeLimitUseCase
import com.wire.kalium.logic.feature.asset.GetAvatarAssetUseCase
import com.wire.kalium.logic.feature.client.FinalizeMLSClientAfterE2EIEnrollment
import com.wire.kalium.logic.feature.conversation.GetAllContactsNotInConversationUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.EnrollE2EIUseCase
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCase
Expand Down Expand Up @@ -117,6 +118,11 @@ class UserModule {
fun provideEnrollE2EIUseCase(userScope: UserScope): EnrollE2EIUseCase =
userScope.enrollE2EI

@ViewModelScoped
@Provides
fun provideFinalizeMLSClientAfterE2EIEnrollmentUseCase(userScope: UserScope): FinalizeMLSClientAfterE2EIEnrollment =
userScope.finalizeMLSClientAfterE2EIEnrollment

@ViewModelScoped
@Provides
fun provideObserveTypingIndicatorEnabled(userScope: UserScope): ObserveTypingIndicatorEnabledUseCase =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ class GetE2EICertificateUseCase @Inject constructor(
private lateinit var initialEnrollmentResult: E2EIEnrollmentResult.Initialized
lateinit var enrollmentResultHandler: (Either<E2EIFailure, E2EIEnrollmentResult>) -> Unit

operator fun invoke(context: Context, enrollmentResultHandler: (Either<CoreFailure, E2EIEnrollmentResult>) -> Unit) {
operator fun invoke(
context: Context,
isNewClient: Boolean,
enrollmentResultHandler: (Either<CoreFailure, E2EIEnrollmentResult>) -> Unit
) {
this.enrollmentResultHandler = enrollmentResultHandler
scope.launch {
enrollE2EI.initialEnrollment().fold({
enrollE2EI.initialEnrollment(isNewClientRegistration = isNewClient).fold({
enrollmentResultHandler(Either.Left(it))
}, {
if (it is E2EIEnrollmentResult.Initialized) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import androidx.activity.result.ActivityResultRegistry
import androidx.activity.result.contract.ActivityResultContracts
import com.wire.android.appLogger
import com.wire.android.util.deeplink.DeepLinkProcessor
import com.wire.android.util.removeQueryParams
import kotlinx.serialization.json.JsonObject
import net.openid.appauth.AppAuthConfiguration
import net.openid.appauth.AuthState
Expand All @@ -44,6 +45,7 @@ import net.openid.appauth.browser.VersionedBrowserMatcher
import net.openid.appauth.connectivity.ConnectionBuilder
import org.json.JSONObject
import java.net.HttpURLConnection
import java.net.URI
import java.net.URL
import java.security.MessageDigest
import java.security.SecureRandom
Expand Down Expand Up @@ -119,7 +121,7 @@ class OAuthUseCase(context: Context, private val authUrl: String, private val cl
handleActivityResult(result, resultHandler)
}
AuthorizationServiceConfiguration.fetchFromUrl(
Uri.parse(authUrl.plus(IDP_CONFIGURATION_PATH)),
Uri.parse(URI(authUrl).removeQueryParams().toString().plus(IDP_CONFIGURATION_PATH)),
{ configuration, ex ->
if (ex == null) {
authServiceConfig = configuration!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class MigrateClientsDataUseCase @Inject constructor(
private val scalaUserDBProvider: ScalaUserDatabaseProvider,
private val userDataStoreProvider: UserDataStoreProvider
) {
@Suppress("ReturnCount")
@Suppress("ReturnCount", "ComplexMethod")
suspend operator fun invoke(userId: UserId, isFederated: Boolean): Either<CoreFailure, Unit> =
scalaUserDBProvider.clientDAO(userId.value).flatMap { clientDAO ->
val clientId = clientDAO.clientInfo()?.clientId?.let { ClientId(it) }
Expand Down Expand Up @@ -103,6 +103,19 @@ class MigrateClientsDataUseCase @Inject constructor(
userDataStoreProvider.getOrCreate(userId).setInitialSyncCompleted()
}
}

is RegisterClientResult.E2EICertificateRequired ->
withTimeoutOrNull(SYNC_START_TIMEOUT) {
syncManager.waitUntilStartedOrFailure()
}.let {
it ?: Either.Left(NetworkFailure.NoNetworkConnection(null))
}.flatMap {
syncManager.waitUntilLiveOrFailure()
.onSuccess {
userDataStoreProvider.getOrCreate(userId).setInitialSyncCompleted()
TODO() // TODO: ask question about this!
}
}
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions app/src/main/kotlin/com/wire/android/ui/WireActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import com.wire.android.ui.common.topappbar.CommonTopAppBar
import com.wire.android.ui.common.topappbar.CommonTopAppBarViewModel
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.ConversationScreenDestination
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.E2eiCertificateDetailsScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.ImportMediaScreenDestination
Expand Down Expand Up @@ -166,9 +167,9 @@ class WireActivity : AppCompatActivity() {
val startDestination = when (viewModel.initialAppState) {
InitialAppState.NOT_MIGRATED -> MigrationScreenDestination
InitialAppState.NOT_LOGGED_IN -> WelcomeScreenDestination
InitialAppState.LOGGED_IN -> HomeScreenDestination
}

InitialAppState.ENROLL_E2EI -> E2EIEnrollmentScreenDestination
InitialAppState.LOGGED_IN -> HomeScreenDestination
}
appLogger.i("$TAG composable content")
setComposableContent(startDestination) {
appLogger.i("$TAG splash hide")
Expand Down
27 changes: 26 additions & 1 deletion app/src/main/kotlin/com/wire/android/ui/WireActivityViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.wire.android.appLogger
import com.wire.android.datastore.GlobalDataStore
import com.wire.android.di.AuthServerConfigProvider
import com.wire.android.di.KaliumCoreLogic
import com.wire.android.di.ObserveIfE2EIRequiredDuringLoginUseCaseProvider
import com.wire.android.di.ObserveScreenshotCensoringConfigUseCaseProvider
import com.wire.android.di.ObserveSyncStateUseCaseProvider
import com.wire.android.feature.AccountSwitchUseCase
Expand Down Expand Up @@ -111,6 +112,7 @@ class WireActivityViewModel @Inject constructor(
private val currentScreenManager: CurrentScreenManager,
private val observeScreenshotCensoringConfigUseCaseProviderFactory: ObserveScreenshotCensoringConfigUseCaseProvider.Factory,
private val globalDataStore: GlobalDataStore,
private val observeIfE2EIRequiredDuringLoginUseCaseProviderFactory: ObserveIfE2EIRequiredDuringLoginUseCaseProvider.Factory
) : ViewModel() {

var globalAppState: GlobalAppState by mutableStateOf(GlobalAppState())
Expand Down Expand Up @@ -143,12 +145,16 @@ class WireActivityViewModel @Inject constructor(
private val _observeSyncFlowState: MutableStateFlow<SyncState?> = MutableStateFlow(null)
val observeSyncFlowState: StateFlow<SyncState?> = _observeSyncFlowState

private val _observeE2EIState: MutableStateFlow<Boolean?> = MutableStateFlow(null)
private val observeE2EIState: StateFlow<Boolean?> = _observeE2EIState

init {
observeSyncState()
observeUpdateAppState()
observeNewClientState()
observeScreenshotCensoringConfigState()
observeAppThemeState()
observerE2EIState()
}

private fun observeAppThemeState() {
Expand All @@ -161,6 +167,18 @@ class WireActivityViewModel @Inject constructor(
}
}

fun observerE2EIState() {
viewModelScope.launch(dispatchers.io()) {
observeUserId
.flatMapLatest {
it?.let { observeIfE2EIRequiredDuringLoginUseCaseProviderFactory.create(it).observeIfE2EIIsRequiredDuringLogin() }
?: flowOf(null)
}
.distinctUntilChanged()
.collect { _observeE2EIState.emit(it) }
}
}

private fun observeSyncState() {
viewModelScope.launch(dispatchers.io()) {
observeUserId
Expand Down Expand Up @@ -234,6 +252,7 @@ class WireActivityViewModel @Inject constructor(
get() = when {
shouldMigrate() -> InitialAppState.NOT_MIGRATED
shouldLogIn() -> InitialAppState.NOT_LOGGED_IN
blockedByE2EI() -> InitialAppState.ENROLL_E2EI
else -> InitialAppState.LOGGED_IN
}

Expand Down Expand Up @@ -264,8 +283,10 @@ class WireActivityViewModel @Inject constructor(
// to handle the deepLinks above user needs to be Logged in
// do nothing, already handled by initialAppState
}

result is DeepLinkResult.JoinConversation ->
onConversationInviteDeepLink(result.code, result.key, result.domain, onOpenConversation)

result != null -> onResult(result)
result is DeepLinkResult.Unknown -> appLogger.e("unknown deeplink result $result")
}
Expand Down Expand Up @@ -413,6 +434,10 @@ class WireActivityViewModel @Inject constructor(

fun shouldLogIn(): Boolean = !hasValidCurrentSession()

fun blockedByE2EI(): Boolean {
return observeE2EIState.value == true
}

private fun hasValidCurrentSession(): Boolean = runBlocking {
// TODO: the usage of currentSessionFlow is a temporary solution, it should be replaced with a proper solution
currentSessionFlow().first().let {
Expand Down Expand Up @@ -532,5 +557,5 @@ data class GlobalAppState(
)

enum class InitialAppState {
NOT_MIGRATED, NOT_LOGGED_IN, LOGGED_IN
NOT_MIGRATED, NOT_LOGGED_IN, LOGGED_IN, ENROLL_E2EI
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ class CreateAccountCodeViewModel @Inject constructor(
is RegisterClientResult.Success -> {
onSuccess()
}

is RegisterClientResult.E2EICertificateRequired -> {
// TODO
onSuccess()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ data class Device(
mlsPublicKeys = client.mlsPublicKeys,
e2eiCertificateStatus = e2eiCertificateStatus
)

fun updateFromClient(client: Client): Device = copy(
name = client.displayName(),
clientId = client.id,
registrationTime = client.registrationTime?.toIsoDateTimeString(),
lastActiveInWholeWeeks = client.lastActiveInWholeWeeks(),
isValid = client.isValid,
isVerifiedProteus = client.isVerified,
mlsPublicKeys = client.mlsPublicKeys,
)

fun updateE2EICertificateStatus(e2eiCertificateStatus: CertificateStatus): Device = copy(
e2eiCertificateStatus = e2eiCertificateStatus
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import com.wire.android.ui.common.textfield.clearAutofillTree
import com.wire.android.ui.common.topappbar.NavigationIconType
import com.wire.android.ui.common.topappbar.WireCenterAlignedTopAppBar
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.InitialSyncScreenDestination
import com.wire.android.ui.destinations.RemoveDeviceScreenDestination
Expand All @@ -81,11 +82,14 @@ fun RegisterDeviceScreen(navigator: Navigator) {
is RegisterDeviceFlowState.Success -> {
navigator.navigate(
NavigationCommand(
destination = if (flowState.initialSyncCompleted) HomeScreenDestination else InitialSyncScreenDestination,
destination = if (flowState.isE2EIRequired) E2EIEnrollmentScreenDestination
else if (flowState.initialSyncCompleted) HomeScreenDestination
else InitialSyncScreenDestination,
backStackMode = BackStackMode.CLEAR_WHOLE
)
)
}

is RegisterDeviceFlowState.TooManyDevices -> navigator.navigate(NavigationCommand(RemoveDeviceScreenDestination))
else ->
RegisterDeviceContent(
Expand Down Expand Up @@ -189,6 +193,7 @@ private fun PasswordTextField(state: RegisterDeviceState, onPasswordChange: (Tex
state = when (state.flowState) {
is RegisterDeviceFlowState.Error.InvalidCredentialsError ->
WireTextFieldState.Error(stringResource(id = R.string.remove_device_invalid_password))

else -> WireTextFieldState.Default
},
imeAction = ImeAction.Done,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,26 @@ package com.wire.android.ui.authentication.devices.register

import androidx.compose.ui.text.input.TextFieldValue
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.UserId

data class RegisterDeviceState(
val password: TextFieldValue = TextFieldValue(""),
val continueEnabled: Boolean = false,
val flowState: RegisterDeviceFlowState = RegisterDeviceFlowState.Default
)

sealed class RegisterDeviceFlowState {
object Default : RegisterDeviceFlowState()
object Loading : RegisterDeviceFlowState()
object TooManyDevices : RegisterDeviceFlowState()
data class Success(val initialSyncCompleted: Boolean) : RegisterDeviceFlowState()
data class Success(
val initialSyncCompleted: Boolean,
val isE2EIRequired: Boolean,
val clientId: ClientId,
val userId: UserId? = null
) : RegisterDeviceFlowState()

sealed class Error : RegisterDeviceFlowState() {
object InvalidCredentialsError : Error()
data class GenericError(val coreFailure: CoreFailure) : Error()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,25 @@ class RegisterDeviceViewModel @Inject constructor(
)) {
is RegisterClientResult.Failure.TooManyClients ->
updateFlowState(RegisterDeviceFlowState.TooManyDevices)

is RegisterClientResult.Success ->
updateFlowState(RegisterDeviceFlowState.Success(userDataStore.initialSyncCompleted.first()))
updateFlowState(
RegisterDeviceFlowState.Success(
userDataStore.initialSyncCompleted.first(),
false,
registerDeviceResult.client.id
)
)

is RegisterClientResult.E2EICertificateRequired ->
updateFlowState(
RegisterDeviceFlowState.Success(
userDataStore.initialSyncCompleted.first(),
true,
registerDeviceResult.client.id,
registerDeviceResult.userId
)
)

is RegisterClientResult.Failure.Generic -> state = state.copy(
continueEnabled = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import com.wire.android.ui.common.divider.WireDivider
import com.wire.android.ui.common.rememberTopBarElevationState
import com.wire.android.ui.common.textfield.clearAutofillTree
import com.wire.android.ui.common.visbility.rememberVisibilityState
import com.wire.android.ui.destinations.E2EIEnrollmentScreenDestination
import com.wire.android.ui.destinations.HomeScreenDestination
import com.wire.android.ui.destinations.InitialSyncScreenDestination
import com.wire.android.util.dialogErrorStrings
Expand All @@ -73,9 +74,11 @@ fun RemoveDeviceScreen(navigator: Navigator) {
val state: RemoveDeviceState = viewModel.state
val clearSessionState: ClearSessionState = clearSessionViewModel.state

fun navigateAfterSuccess(initialSyncCompleted: Boolean) = navigator.navigate(
fun navigateAfterSuccess(initialSyncCompleted: Boolean, isE2EIRequired: Boolean) = navigator.navigate(
NavigationCommand(
destination = if (initialSyncCompleted) HomeScreenDestination else InitialSyncScreenDestination,
destination = if (isE2EIRequired) E2EIEnrollmentScreenDestination
else if (initialSyncCompleted) HomeScreenDestination
else InitialSyncScreenDestination,
backStackMode = BackStackMode.CLEAR_WHOLE
)
)
Expand All @@ -84,9 +87,9 @@ fun RemoveDeviceScreen(navigator: Navigator) {
RemoveDeviceContent(
state = state,
clearSessionState = clearSessionState,
onItemClicked = { viewModel.onItemClicked(it) { navigateAfterSuccess(it) } },
onItemClicked = { viewModel.onItemClicked(it, ::navigateAfterSuccess) },
onPasswordChange = viewModel::onPasswordChange,
onRemoveConfirm = { viewModel.onRemoveConfirmed { navigateAfterSuccess(it) } },
onRemoveConfirm = { viewModel.onRemoveConfirmed(::navigateAfterSuccess) },
onDialogDismiss = viewModel::onDialogDismissed,
onErrorDialogDismiss = viewModel::clearDeleteClientError,
onBackButtonClicked = clearSessionViewModel::onBackButtonClicked,
Expand Down
Loading

0 comments on commit 7c6148c

Please sign in to comment.