Skip to content

Commit

Permalink
✨ Support for encrypted file transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyanakuang committed Jun 20, 2024
1 parent 4c4a50d commit 1c33ecf
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import org.signal.libsignal.protocol.SessionCipher
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.message.SignalMessage
import org.signal.libsignal.protocol.state.SignalProtocolStore
import java.io.ByteArrayOutputStream

object SignalClientDecryptPlugin : HttpClientPlugin<SignalConfig, SignalClientDecryptPlugin> {

Expand All @@ -39,24 +41,57 @@ object SignalClientDecryptPlugin : HttpClientPlugin<SignalConfig, SignalClientDe
if (signal == "1") {
logger.debug { "signal client decrypt $appInstanceId" }
val byteReadChannel: ByteReadChannel = it.content
val bytes = byteReadChannel.readRemaining().readBytes()
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val signalMessage = SignalMessage(bytes)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val decrypt = sessionCipher.decrypt(signalMessage)

// Create a new ByteReadChannel to contain the decrypted content
val newChannel = ByteReadChannel(decrypt)
val responseData =
HttpResponseData(
it.status,
it.requestTime,
it.headers,
it.version,
newChannel,
it.coroutineContext,
)
proceedWith(DefaultHttpResponse(it.call, responseData))
val contentType = it.call.response.contentType()

if (contentType == ContentType.Application.Json) {
val bytes = byteReadChannel.readRemaining().readBytes()
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val signalMessage = SignalMessage(bytes)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val decrypt = sessionCipher.decrypt(signalMessage)

// Create a new ByteReadChannel to contain the decrypted content
val newChannel = ByteReadChannel(decrypt)
val responseData =
HttpResponseData(
it.status,
it.requestTime,
it.headers,
it.version,
newChannel,
it.coroutineContext,
)
proceedWith(DefaultHttpResponse(it.call, responseData))
} else if (contentType == ContentType.Application.OctetStream) {
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)

val result = ByteArrayOutputStream()
while (!byteReadChannel.isClosedForRead) {
val size = byteReadChannel.readInt()
val byteArray = ByteArray(size)
var bytesRead = 0
while (bytesRead < size) {
val currentRead = byteReadChannel.readAvailable(byteArray, bytesRead, size - bytesRead)
if (currentRead == -1) break
bytesRead += currentRead
}
val signalMessage = SignalMessage(byteArray)
result.write(sessionCipher.decrypt(signalMessage))
}
val newChannel = ByteReadChannel(result.toByteArray())
val responseData =
HttpResponseData(
it.status,
it.requestTime,
it.headers,
it.version,
newChannel,
it.coroutineContext,
)
proceedWith(DefaultHttpResponse(it.call, responseData))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,18 @@ object SignalClientEncryptPlugin : HttpClientPlugin<SignalConfig, SignalClientEn
context.headers["targetAppInstanceId"]?.let { targetAppInstanceId ->
context.headers["signal"]?.let { signal ->
if (signal == "1") {
val originalContent = context.body as? OutgoingContent.ByteArrayContent
originalContent?.let {
logger.debug { "signal client encrypt $targetAppInstanceId" }
val signalProtocolAddress = SignalProtocolAddress(targetAppInstanceId, 1)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val ciphertextMessage = sessionCipher.encrypt(it.bytes())
val encryptedData = ciphertextMessage.serialize()
logger.debug { "signal client encrypt $targetAppInstanceId" }
when (context.body) {
// Current all client requests use the Json protocol
context.body = ByteArrayContent(encryptedData, contentType = ContentType.Application.Json)
proceedWith(context.body)
is OutgoingContent.ByteArrayContent -> {
val originalContent = context.body as OutgoingContent.ByteArrayContent
val signalProtocolAddress = SignalProtocolAddress(targetAppInstanceId, 1)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val ciphertextMessage = sessionCipher.encrypt(originalContent.bytes())
val encryptedData = ciphertextMessage.serialize()
context.body = ByteArrayContent(encryptedData, contentType = ContentType.Application.Json)
proceedWith(context.body)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import io.ktor.server.response.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import org.signal.libsignal.protocol.SessionCipher
import org.signal.libsignal.protocol.SignalProtocolAddress
import java.nio.ByteBuffer

val SIGNAL_SERVER_ENCRYPT_PLUGIN: ApplicationPlugin<SignalConfig> =
createApplicationPlugin(
Expand All @@ -26,9 +30,9 @@ val SIGNAL_SERVER_ENCRYPT_PLUGIN: ApplicationPlugin<SignalConfig> =
headers["appInstanceId"]?.let { appInstanceId ->
headers["signal"]?.let {
logger.debug { "signal server encrypt $appInstanceId" }
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
transformBodyTo(body) { bytes ->
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val ciphertextMessage = sessionCipher.encrypt(bytes)
ciphertextMessage.serialize()
}
Expand All @@ -39,6 +43,9 @@ val SIGNAL_SERVER_ENCRYPT_PLUGIN: ApplicationPlugin<SignalConfig> =

object EncryptResponse :
Hook<suspend EncryptResponse.Context.(ApplicationCall, OutgoingContent) -> Unit> {

val ioCoroutineDispatcher = CoroutineScope(Dispatchers.IO)

class Context(private val context: PipelineContext<Any, ApplicationCall>) {
suspend fun transformBodyTo(
body: OutgoingContent,
Expand All @@ -56,12 +63,49 @@ object EncryptResponse :
}

is OutgoingContent.WriteChannelContent -> {
val byteChannel = ByteChannel(true)
body.writeTo(byteChannel)
byteChannel.flush()
byteChannel.close()
val bytes = byteChannel.readRemaining().readBytes()
context.subject = ByteArrayContent(encrypt(bytes), contentType = body.contentType, status = body.status)
val producer: suspend ByteWriteChannel.() -> Unit = {
val encryptChannel: ByteWriteChannel = this
val originChannel = ByteChannel(true)
val byteBuffer = ByteBuffer.allocateDirect(81920)

val deferred =
ioCoroutineDispatcher.async {
while (true) {
byteBuffer.clear()
val size = originChannel.readAvailable(byteBuffer)
if (size < 0) break
if (size == 0) continue
byteBuffer.flip()
val byteArray = ByteArray(byteBuffer.remaining())
byteBuffer.get(byteArray)
var offset = 0
do {
val transformedBytes = encrypt(byteArray)
encryptChannel.writeInt(transformedBytes.size)
val availableSize =
encryptChannel.writeAvailable(
transformedBytes,
offset,
transformedBytes.size - offset,
)
offset += availableSize
} while (size > offset)
}
}

body.writeTo(originChannel)
originChannel.close()

deferred.await()
}

val content =
ChannelWriterContent(
body = producer,
contentType = body.contentType,
status = body.status,
)
context.subject = content
}

else -> {
Expand Down

0 comments on commit 1c33ecf

Please sign in to comment.