Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: CRL proxy [WPB-8793] #2800

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,22 @@ import kotlinx.datetime.Instant
data class E2EISettings(
val isRequired: Boolean,
val discoverUrl: String?,
val gracePeriodEnd: Instant?
val gracePeriodEnd: Instant?,
val shouldUseProxy: Boolean,
val crlProxy: String?,
) {

fun toEntity() = E2EISettingsEntity(
isRequired, discoverUrl, gracePeriodEnd?.toEpochMilliseconds()
isRequired, discoverUrl, gracePeriodEnd?.toEpochMilliseconds(), shouldUseProxy, crlProxy
)

companion object {
fun fromEntity(entity: E2EISettingsEntity) = E2EISettings(
entity.status,
entity.discoverUrl,
entity.gracePeriodEndMs?.let { Instant.fromEpochMilliseconds(it) },
entity.shouldUseProxy == true,
mchenani marked this conversation as resolved.
Show resolved Hide resolved
entity.crlProxy
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
package com.wire.kalium.logic.data.e2ei

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.getOrNull
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.network.api.base.unbound.acme.ACMEApi
import com.wire.kalium.persistence.config.CRLUrlExpirationList
Expand All @@ -39,36 +42,37 @@ internal interface CertificateRevocationListRepository {

internal class CertificateRevocationListRepositoryDataSource(
private val acmeApi: ACMEApi,
private val metadataDAO: MetadataDAO
private val metadataDAO: MetadataDAO,
private val userConfigRepository: UserConfigRepository
) : CertificateRevocationListRepository {
override suspend fun getCRLs(): CRLUrlExpirationList? =
metadataDAO.getSerializable(CRL_LIST_KEY, CRLUrlExpirationList.serializer())

override suspend fun addOrUpdateCRL(url: String, timestamp: ULong) {
val newCRLUrls = metadataDAO.getSerializable(CRL_LIST_KEY, CRLUrlExpirationList.serializer())
?.let { crlExpirationList ->
val crlWithExpiration = crlExpirationList.cRLWithExpirationList.find {
it.url == url
}
crlWithExpiration?.let { item ->
crlExpirationList.cRLWithExpirationList.map { current ->
if (current.url == url) {
return@map item.copy(expiration = timestamp)
} else {
return@map current
}
?.let { crlExpirationList ->
val crlWithExpiration = crlExpirationList.cRLWithExpirationList.find {
it.url == url
}
crlWithExpiration?.let { item ->
crlExpirationList.cRLWithExpirationList.map { current ->
if (current.url == url) {
return@map item.copy(expiration = timestamp)
} else {
return@map current
}
} ?: run {
// add new CRL
crlExpirationList.cRLWithExpirationList.plus(
CRLWithExpiration(url, timestamp)
)
}

} ?: run {
// add new CRL
listOf(CRLWithExpiration(url, timestamp))
}
// add new CRL
crlExpirationList.cRLWithExpirationList.plus(
CRLWithExpiration(url, timestamp)
)
}

} ?: run {
// add new CRL
listOf(CRLWithExpiration(url, timestamp))
}
metadataDAO.putSerializable(
CRL_LIST_KEY,
CRLUrlExpirationList(newCRLUrls),
Expand All @@ -78,7 +82,11 @@ internal class CertificateRevocationListRepositoryDataSource(

override suspend fun getClientDomainCRL(url: String): Either<CoreFailure, ByteArray> =
wrapApiRequest {
acmeApi.getClientDomainCRL(url)
val proxyUrl = userConfigRepository.getE2EISettings()
.map { if (!it.shouldUseProxy || it.crlProxy.isNullOrBlank()) null else it.crlProxy }
.getOrNull()

acmeApi.getClientDomainCRL(url, proxyUrl)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,18 @@ class FeatureConfigMapperImpl : FeatureConfigMapper {
E2EIModel(
E2EIConfigModel(
data.config.url,
data.config.verificationExpirationSeconds
data.config.verificationExpirationSeconds,
data.config.shouldUseProxy == true,
data.config.crlProxy
),
fromDTO(data.status)
)
} ?: E2EIModel(
E2EIConfigModel(
null,
0
0,
false,
null
),
Status.DISABLED
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,7 @@ data class E2EIModel(

data class E2EIConfigModel(
val discoverUrl: String?,
val verificationExpirationSeconds: Long
val verificationExpirationSeconds: Long,
val shouldUseProxy: Boolean,
val crlProxy: String?,
)
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,8 @@ class UserSessionScope internal constructor(
private val certificateRevocationListRepository: CertificateRevocationListRepository
get() = CertificateRevocationListRepositoryDataSource(
acmeApi = globalScope.unboundNetworkContainer.acmeApi,
metadataDAO = userStorage.database.metadataDAO
metadataDAO = userStorage.database.metadataDAO,
userConfigRepository = userConfigRepository
)

private val proteusPreKeyRefiller: ProteusPreKeyRefiller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ class E2EIConfigHandler(private val userConfigRepository: UserConfigRepository)
val newSettings = E2EISettings(
isRequired = e2eiConfig.status == Status.ENABLED,
discoverUrl = e2eiConfig.config.discoverUrl,
gracePeriodEnd = gracePeriodEnd
gracePeriodEnd = gracePeriodEnd,
shouldUseProxy = e2eiConfig.config.shouldUseProxy,
crlProxy = e2eiConfig.config.crlProxy
)

if (currentSettings?.isRequired == newSettings.isRequired && currentSettings?.discoverUrl == newSettings.discoverUrl) {
if (currentSettings?.isRequired == newSettings.isRequired && currentSettings.discoverUrl == newSettings.discoverUrl) {
// that settings were already handled,
// no need to re-write as it will reset gracePeriod
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@
*/
package com.wire.kalium.logic.data.e2ei

import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.E2EISettings
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepositoryDataSource.Companion.CRL_LIST_KEY
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.right
import com.wire.kalium.network.api.base.unbound.acme.ACMEApi
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.persistence.config.CRLUrlExpirationList
import com.wire.kalium.persistence.config.CRLWithExpiration
import com.wire.kalium.persistence.dao.MetadataDAO
import io.ktor.utils.io.core.toByteArray
import io.mockative.Mock
import io.mockative.any
import io.mockative.classOf
import io.mockative.eq
import io.mockative.given
import io.mockative.mock
import io.mockative.once
Expand Down Expand Up @@ -106,6 +115,54 @@ class CertificateRevocationListRepositoryTest {
}.wasInvoked(once)
}

@Test
fun givenCRLUrlProxyRequired_whenClientDomainCRLRequested_thenProxyIsApplied() = runTest {
val (arrangement, crlRepository) = Arrangement()
.withClientDomainCRL()
.withE2EISettings(E2EI_SETTINGS.copy(shouldUseProxy = true, crlProxy = DUMMY_URL).right())
.arrange()

crlRepository.getClientDomainCRL(DUMMY_URL2)

verify(arrangement.userConfigRepository).coroutine { getE2EISettings() }.wasInvoked(once)

verify(arrangement.acmeApi).coroutine {
getClientDomainCRL(DUMMY_URL2, DUMMY_URL)
}.wasInvoked(once)
}

@Test
fun givenCRLUrlProxyRequiredButEmpty_whenClientDomainCRLRequested_thenProxyIsNotApplied() = runTest {
val (arrangement, crlRepository) = Arrangement()
.withClientDomainCRL()
.withE2EISettings(E2EI_SETTINGS.copy(shouldUseProxy = true, crlProxy = "").right())
.arrange()

crlRepository.getClientDomainCRL(DUMMY_URL2)

verify(arrangement.userConfigRepository).coroutine { getE2EISettings() }.wasInvoked(once)

verify(arrangement.acmeApi).coroutine {
getClientDomainCRL(DUMMY_URL2, null)
}.wasInvoked(once)
}

@Test
fun givenCRLUrlProxyNotRequired_whenClientDomainCRLRequested_thenProxyIsNotApplied() = runTest {
val (arrangement, crlRepository) = Arrangement()
.withClientDomainCRL()
.withE2EISettings(E2EI_SETTINGS.copy(shouldUseProxy = false, crlProxy = DUMMY_URL).right())
.arrange()

crlRepository.getClientDomainCRL(DUMMY_URL2)

verify(arrangement.userConfigRepository).coroutine { getE2EISettings() }.wasInvoked(once)

verify(arrangement.acmeApi).coroutine {
getClientDomainCRL(DUMMY_URL2, null)
}.wasInvoked(once)
}

private class Arrangement {

@Mock
Expand All @@ -114,7 +171,10 @@ class CertificateRevocationListRepositoryTest {
@Mock
val metadataDAO = mock(classOf<MetadataDAO>())

fun arrange() = this to CertificateRevocationListRepositoryDataSource(acmeApi, metadataDAO)
@Mock
val userConfigRepository = mock(classOf<UserConfigRepository>())

fun arrange() = this to CertificateRevocationListRepositoryDataSource(acmeApi, metadataDAO, userConfigRepository)

suspend fun withEmptyList() = apply {
given(metadataDAO).coroutine {
Expand Down Expand Up @@ -142,12 +202,35 @@ class CertificateRevocationListRepositoryTest {
)
}.thenReturn(CRLUrlExpirationList(listOf(CRLWithExpiration(DUMMY_URL, TIMESTAMP))))
}

suspend fun withE2EISettings(result: Either<StorageFailure, E2EISettings> = E2EI_SETTINGS.right()) = apply {
given(userConfigRepository).function(userConfigRepository::getE2EISettings)
.whenInvoked()
.thenReturn(result)
}

suspend fun withClientDomainCRL() = apply {
given(acmeApi).suspendFunction(acmeApi::getClientDomainCRL)
.whenInvokedWith(any(), any<String?>())
.thenReturn(NetworkResponse.Success("some_response".toByteArray(), mapOf(), 200))
}.apply {
given(acmeApi).suspendFunction(acmeApi::getClientDomainCRL)
.whenInvokedWith(any(), eq(null))
.thenReturn(NetworkResponse.Success("some_response".toByteArray(), mapOf(), 200))
}
}

companion object {
private const val DUMMY_URL = "https://dummy.url"
private const val DUMMY_URL2 = "https://dummy-2.url"
private val TIMESTAMP = 1234567890.toULong()
private val TIMESTAMP2 = 5453222.toULong()
private val E2EI_SETTINGS = E2EISettings(
isRequired = true,
discoverUrl = "discoverUrl",
gracePeriodEnd = null,
shouldUseProxy = false,
crlProxy = null
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ class E2EIRepositoryTest {
@Test
fun givenE2EIIsDisabled_whenCallingDiscoveryUrl_thenItFailWithDisabled() {
val (arrangement, e2eiRepository) = Arrangement()
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(false, null, Instant.DISTANT_FUTURE)))
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(false, null, Instant.DISTANT_FUTURE, false, null)))
.arrange()

e2eiRepository.discoveryUrl().shouldFail {
Expand All @@ -1039,7 +1039,7 @@ class E2EIRepositoryTest {
@Test
fun givenE2EIIsEnabledAndDiscoveryUrlIsNull_whenCallingDiscoveryUrl_thenItFailWithMissingDiscoveryUrl() {
val (arrangement, e2eiRepository) = Arrangement()
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(true, null, Instant.DISTANT_FUTURE)))
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(true, null, Instant.DISTANT_FUTURE, false, null)))
.arrange()

e2eiRepository.discoveryUrl().shouldFail {
Expand All @@ -1054,7 +1054,7 @@ class E2EIRepositoryTest {
@Test
fun givenE2EIIsEnabledAndDiscoveryUrlIsNotNull_whenCallingDiscoveryUrl_thenItSucceed() {
val (arrangement, e2eiRepository) = Arrangement()
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(true, RANDOM_URL, Instant.DISTANT_FUTURE)))
.withGettingE2EISettingsReturns(Either.Right(E2EISettings(true, RANDOM_URL, Instant.DISTANT_FUTURE, false, null)))
.arrange()

e2eiRepository.discoveryUrl().shouldSucceed {
Expand Down Expand Up @@ -1445,7 +1445,7 @@ class E2EIRepositoryTest {
val HEADERS = mapOf(NONCE_HEADER_KEY to RANDOM_NONCE.value, LOCATION_HEADER_KEY to RANDOM_URL)

val E2EI_TEAM_SETTINGS = E2EISettings(
true, RANDOM_URL, DateTimeUtil.currentInstant()
true, RANDOM_URL, DateTimeUtil.currentInstant(), false, null
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class FeatureConfigMapperTest {
), FeatureFlagStatusDTO.ENABLED
),
FeatureConfigData.E2EI(
E2EIConfigDTO("url", 1_000_000L),
E2EIConfigDTO("url", null, false, 1_000_000L),
FeatureFlagStatusDTO.ENABLED
),
FeatureConfigData.MLSMigration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class FeatureConfigRepositoryTest {
supportedCipherSuite = null
),
E2EIModel(
E2EIConfigModel("url", 1000000L),
E2EIConfigModel("url", 1000000L, false, null),
Status.ENABLED
),
MLSMigrationModel(
Expand Down Expand Up @@ -165,7 +165,7 @@ class FeatureConfigRepositoryTest {
), FeatureFlagStatusDTO.ENABLED
),
FeatureConfigData.E2EI(
E2EIConfigDTO("url", 1000000L),
E2EIConfigDTO("url", null, false, 1000000L),
FeatureFlagStatusDTO.ENABLED
),
FeatureConfigData.MLSMigration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ object FeatureConfigTest {
status = Status.ENABLED,
supportedCipherSuite = null
),
e2EIModel: E2EIModel = E2EIModel(E2EIConfigModel("url", 10000L), Status.ENABLED),
e2EIModel: E2EIModel = E2EIModel(E2EIConfigModel("url", 10000L, false, null), Status.ENABLED),
mlsMigrationModel: MLSMigrationModel? = MLSMigrationModel(
Instant.DISTANT_FUTURE,
Instant.DISTANT_FUTURE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ class ObserveE2EIRequiredUseCaseTest {
}

companion object {
private val MLS_E2EI_SETTING = E2EISettings(true, "some_url", null)
private val MLS_E2EI_SETTING = E2EISettings(true, "some_url", null, false, null)
private val VALID_CERTIFICATE = E2eiCertificate(
userHandle = "userHandle",
serialNumber = "serialNumber",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class RegisterMLSClientUseCaseTest {
const val REFILL_AMOUNT = 100
val RANDOM_URL = "https://random.rn"
val E2EI_TEAM_SETTINGS = E2EISettings(
true, RANDOM_URL, DateTimeUtil.currentInstant()
true, RANDOM_URL, DateTimeUtil.currentInstant(), false, null
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ class SyncFeatureConfigsUseCaseTest {

@Test
fun givenE2EIIsDisabled_whenSyncing_thenItShouldBeStoredAsDisabled() = runTest {
val e2EIModel = E2EIModel(E2EIConfigModel("url", 10_000L), Status.DISABLED)
val e2EIModel = E2EIModel(E2EIConfigModel("url", 10_000L, false, null), Status.DISABLED)
val expectedGracePeriodEnd = DateTimeUtil.currentInstant().plus(10_000.toDuration(DurationUnit.SECONDS))
val (arrangement, syncFeatureConfigsUseCase) = Arrangement()
.withRemoteFeatureConfigsSucceeding(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ data class SelfDeletingMessagesConfigDTO(
data class E2EIConfigDTO(
@SerialName("acmeDiscoveryUrl")
val url: String?,
@SerialName("crlProxy")
val crlProxy: String?,
@SerialName("useProxyOnMobile")
val shouldUseProxy: Boolean?,
@SerialName("verificationExpiration")
val verificationExpirationSeconds: Long
)
Expand Down
Loading
Loading