Skip to content

Commit

Permalink
fix: Check file extension instead of mimeType [WPB-10605] (#2950)
Browse files Browse the repository at this point in the history
* fix: Check file extension instead of mimeType

* Review updates
  • Loading branch information
borichellow authored Aug 16, 2024
1 parent 375ca80 commit 9db15f2
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ import com.wire.kalium.logic.feature.applock.AppLockTeamFeatureConfigObserver
import com.wire.kalium.logic.feature.applock.AppLockTeamFeatureConfigObserverImpl
import com.wire.kalium.logic.feature.applock.MarkTeamAppLockStatusAsNotifiedUseCase
import com.wire.kalium.logic.feature.applock.MarkTeamAppLockStatusAsNotifiedUseCaseImpl
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCaseImpl
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCaseImpl
import com.wire.kalium.logic.feature.auth.AuthenticationScope
import com.wire.kalium.logic.feature.auth.AuthenticationScopeProvider
import com.wire.kalium.logic.feature.auth.ClearUserDataUseCase
Expand Down Expand Up @@ -1796,7 +1796,7 @@ class UserSessionScope internal constructor(

private val clearUserData: ClearUserDataUseCase get() = ClearUserDataUseCaseImpl(userStorage)

private val validateAssetMimeType: ValidateAssetMimeTypeUseCase get() = ValidateAssetMimeTypeUseCaseImpl()
private val validateAssetMimeType: ValidateAssetFileTypeUseCase get() = ValidateAssetFileTypeUseCaseImpl()

val logout: LogoutUseCase
get() = LogoutUseCaseImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
private val selfDeleteTimer: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val scope: CoroutineScope,
private val observeFileSharingStatus: ObserveFileSharingStatusUseCase,
private val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase,
private val validateAssetFileUseCase: ValidateAssetFileTypeUseCase,
private val dispatcher: KaliumDispatcher,
) : ScheduleNewAssetMessageUseCase {

Expand All @@ -133,7 +133,7 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
FileSharingStatus.Value.EnabledAll -> { /* no-op*/
}

is FileSharingStatus.Value.EnabledSome -> if (!validateAssetMimeTypeUseCase(assetMimeType, it.state.allowedType)) {
is FileSharingStatus.Value.EnabledSome -> if (!validateAssetFileUseCase(assetName, it.state.allowedType)) {
kaliumLogger.e("The asset message trying to be processed has invalid content data")
return ScheduleNewAssetMessageResult.Failure.RestrictedFileType
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.asset

/**
* Returns true if the file extension is present in file name and is allowed and false otherwise.
* @param fileName the file name (with extension) to validate.
* @param allowedExtension the list of allowed extension.
*/
interface ValidateAssetFileTypeUseCase {
operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean
}

internal class ValidateAssetFileTypeUseCaseImpl : ValidateAssetFileTypeUseCase {
override operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean {
if (fileName == null) return false

val split = fileName.split(".")
return if (split.size < 2) {
false
} else {
val allowedExtensionLowerCase = allowedExtension.map { it.lowercase() }
val extensions = split.subList(1, split.size).map { it.lowercase() }
extensions.all { it.isNotEmpty() && allowedExtensionLowerCase.contains(it) }
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ import com.wire.kalium.logic.feature.asset.UpdateAssetMessageDownloadStatusUseCa
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageDownloadStatusUseCaseImpl
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageUploadStatusUseCase
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageUploadStatusUseCaseImpl
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCaseImpl
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCaseImpl
import com.wire.kalium.logic.feature.message.composite.SendButtonActionMessageUseCase
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl
Expand Down Expand Up @@ -144,8 +144,8 @@ class MessageScope internal constructor(
protoContentMapper = protoContentMapper
)

private val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase
get() = ValidateAssetMimeTypeUseCaseImpl()
private val validateAssetMimeTypeUseCase: ValidateAssetFileTypeUseCase
get() = ValidateAssetFileTypeUseCaseImpl()

private val messageContentEncoder = MessageContentEncoder()
private val messageSendingInterceptor: MessageSendingInterceptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.MessageRepository
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCase
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
import com.wire.kalium.logic.kaliumLogger
Expand All @@ -38,7 +38,7 @@ internal class AssetMessageHandlerImpl(
private val messageRepository: MessageRepository,
private val persistMessage: PersistMessageUseCase,
private val userConfigRepository: UserConfigRepository,
private val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase
private val validateAssetMimeTypeUseCase: ValidateAssetFileTypeUseCase
) : AssetMessageHandler {

override suspend fun handle(message: Message.Regular) {
Expand All @@ -53,7 +53,7 @@ internal class AssetMessageHandlerImpl(
FileSharingStatus.Value.EnabledAll -> true

is FileSharingStatus.Value.EnabledSome -> validateAssetMimeTypeUseCase(
messageContent.value.mimeType,
messageContent.value.name,
it.state.allowedType
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -628,15 +628,15 @@ class ScheduleNewAssetMessageUseCaseTest {

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("text/plain"), eq(listOf("png")))
.with(eq("some-asset.txt"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

@Test
fun givenAssetMimeTypeRestrictedAndFileAllowed_whenSending_thenReturnSendTheFile() = runTest(testDispatcher.default) {
// Given
val assetToSend = mockedLongAssetData()
val assetName = "some-asset.txt"
val assetName = "some-asset.png"
val inputDataPath = fakeKaliumFileSystem.providePersistentAssetPath(assetName)
val expectedAssetId = dummyUploadedAssetId
val expectedAssetSha256 = SHA256Key("some-asset-sha-256".toByteArray())
Expand Down Expand Up @@ -669,7 +669,7 @@ class ScheduleNewAssetMessageUseCaseTest {

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("image/png"), eq(listOf("png")))
.with(eq("some-asset.png"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

Expand Down Expand Up @@ -706,7 +706,7 @@ class ScheduleNewAssetMessageUseCaseTest {
private val messageRepository: MessageRepository = mock(MessageRepository::class)

@Mock
val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase = mock(ValidateAssetMimeTypeUseCase::class)
val validateAssetMimeTypeUseCase: ValidateAssetFileTypeUseCase = mock(ValidateAssetFileTypeUseCase::class)

@Mock
val observerFileSharingStatusUseCase: ObserveFileSharingStatusUseCase = mock(ObserveFileSharingStatusUseCase::class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Wire
* Copyright (C) 2024 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.feature.asset

import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class ValidateAssetFileTypeUseCaseTest {

@Test
fun givenRegularFileNameWithAllowedExtension_whenInvoke_thenBeApproved() = runTest {
val (_, validate) = arrange {}

val result = validate("name.txt", listOf("txt", "jpg"))

assertTrue(result)
}

@Test
fun givenRegularFileNameWithNOTAllowedExtension_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate("name.php", listOf("txt", "jpg"))

assertFalse(result)
}

@Test
fun givenRegularFileNameWithoutExtension_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate("name", listOf("txt", "jpg"))

assertFalse(result)
}

@Test
fun givenNullFileName_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate(null, listOf("txt", "jpg"))

assertFalse(result)
}

@Test
fun givenRegularFileNameWithFewExtensions_whenInvoke_thenEachExtensionIsChecked() = runTest {
val (_, validate) = arrange {}

val result1 = validate("name.php.txt", listOf("txt", "jpg"))
val result2 = validate("name.txt.php", listOf("txt", "jpg"))
val result3 = validate("name..txt.jpg", listOf("txt", "jpg"))
val result4 = validate("name.txt.php.txt.jpg", listOf("txt", "jpg"))

assertFalse(result1)
assertFalse(result2)
assertFalse(result3)
assertFalse(result4)
}

private fun arrange(block: Arrangement.() -> Unit) = Arrangement(block).arrange()

private class Arrangement(
private val block: Arrangement.() -> Unit
) {
fun arrange() = block().run {
this@Arrangement to ValidateAssetFileTypeUseCaseImpl()
}
}
}
Loading

0 comments on commit 9db15f2

Please sign in to comment.