Skip to content

Commit

Permalink
✨ Add SignalDecryption plugin to automatically decrypt data (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyanakuang authored Feb 4, 2024
1 parent 0f78e35 commit f305c15
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ fun base64Decode(string: String): ByteArray {
return Base64.getDecoder().decode(string)
}

fun base64mimeEncode(bytes: ByteArray): String {
return Base64.getMimeEncoder().encodeToString(bytes)
}

fun base64mimeDecode(string: String): ByteArray {
return Base64.getMimeDecoder().decode(string)
}

fun encryptData(key: SecretKey, data: ByteArray): ByteArray {
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
val ivBytes = ByteArray(cipher.blockSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.clipevery.net

import com.clipevery.exception.StandardErrorCode
import com.clipevery.net.exception.signalExceptionHandler
import com.clipevery.net.plugin.SignalDecryption
import com.clipevery.routing.syncRouting
import com.clipevery.serializer.IdentityKeySerializer
import com.clipevery.serializer.PreKeyBundleSerializer
Expand All @@ -16,6 +17,7 @@ import io.ktor.server.netty.Netty
import io.ktor.server.netty.NettyApplicationEngine
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
import io.ktor.server.plugins.statuspages.StatusPages
import io.ktor.server.request.contentType
import io.ktor.server.request.httpMethod
import io.ktor.server.request.uri
import io.ktor.server.routing.routing
Expand Down Expand Up @@ -47,10 +49,12 @@ class DesktopClipServer(private val clientHandlerManager :ClientHandlerManager):
failResponse(call, StandardErrorCode.UNKNOWN_ERROR.toErrorCode())
}
signalExceptionHandler()
}
install(SignalDecryption) {

}
intercept(ApplicationCallPipeline.Setup) {
logger.info {"Received request: ${call.request.httpMethod.value} ${call.request.uri}" }
proceed()
logger.info {"Received request: ${call.request.httpMethod.value} ${call.request.uri} ${call.request.contentType()}" }
}
routing {
syncRouting()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.clipevery.net.plugin

import com.clipevery.Dependencies
import com.clipevery.utils.base64mimeDecode
import io.ktor.server.application.ApplicationPlugin
import io.ktor.server.application.createApplicationPlugin
import io.ktor.server.application.hooks.ReceiveRequestBytes
import io.ktor.util.KtorDsl
import io.ktor.utils.io.core.readBytes
import io.ktor.utils.io.writer
import org.signal.libsignal.protocol.SessionCipher
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.message.SignalMessage
import org.signal.libsignal.protocol.state.SignalProtocolStore
import java.nio.ByteBuffer

val SignalDecryption: ApplicationPlugin<SignalDecryptionConfig> = createApplicationPlugin(
"SignalDecryption",
::SignalDecryptionConfig
) {

val signalProtocolStore: SignalProtocolStore = pluginConfig.signalProtocolStore

on(ReceiveRequestBytes) { call, body ->
val headers = call.request.headers
headers["appInstanceId"]?.let { appInstanceId ->
headers["signal"]?.let { signal ->
if (signal == "1") {
return@on application.writer {
val base64Content = body.readRemaining().readBytes()
val originalString = String(base64Content, Charsets.UTF_8)
val base64String = originalString.substring(1, originalString.length - 1)
val encryptedContent = base64mimeDecode(base64String)
val signalProtocolAddress = SignalProtocolAddress(appInstanceId, 1)
val signalMessage = SignalMessage(encryptedContent)
val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)
val decrypt = sessionCipher.decrypt(signalMessage)
channel.writeFully(ByteBuffer.wrap(decrypt))
}.channel
}
}
}
return@on body
}
}

@KtorDsl
class SignalDecryptionConfig {

val signalProtocolStore: SignalProtocolStore = Dependencies.koinApplication.koin.get()
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
package com.clipevery.utils

import com.clipevery.app.AppInfo
import com.clipevery.exception.ErrorCode
import com.clipevery.exception.ErrorType
import com.clipevery.exception.StandardErrorCode
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.ApplicationCall
import io.ktor.server.request.receive
import io.ktor.server.response.header
import io.ktor.server.response.respond
import io.ktor.server.response.respondBytes
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.decodeFromStream
import org.signal.libsignal.protocol.SessionCipher
import org.signal.libsignal.protocol.SignalProtocolAddress
import org.signal.libsignal.protocol.message.SignalMessage
import org.signal.libsignal.protocol.state.SignalProtocolStore


suspend inline fun successResponse(call: ApplicationCall) {
Expand Down Expand Up @@ -48,44 +38,10 @@ suspend inline fun failResponse(call: ApplicationCall, errorCode: ErrorCode, mes
ErrorType.INTERNAL_ERROR -> HttpStatusCode.InternalServerError
ErrorType.USER_ERROR -> HttpStatusCode.UnprocessableEntity
}
val failMessage = FailResponse(code, errorCode.name)
val failMessage = FailResponse(code, message ?: errorCode.name)
failResponse(call, failMessage, status)
}

@ExperimentalSerializationApi
suspend inline fun <reified T : Any> decodeReceive(call: ApplicationCall,
signalProtocolStore: SignalProtocolStore): T {
val bytes = call.receive<ByteArray>()
val appInstanceId = call.request.headers["appInstanceId"]
if (appInstanceId == null) {
failResponse(call, StandardErrorCode.NOT_FOUND_APP_INSTANCE_ID.toErrorCode())
}
val signalProtocolAddress = SignalProtocolAddress(appInstanceId!!, 1)

val signalMessage = SignalMessage(bytes)

val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)

val decrypt = sessionCipher.decrypt(signalMessage)

return Json.decodeFromStream(decrypt.inputStream())
}

suspend inline fun <reified T : Any> encodeResponse(call: ApplicationCall,
appInfo: AppInfo,
signalProtocolStore: SignalProtocolStore,
message: T) {
call.response.header("appInstanceId", appInfo.appInstanceId)

val signalProtocolAddress = SignalProtocolAddress(appInfo.appInstanceId, 1)

val sessionCipher = SessionCipher(signalProtocolStore, signalProtocolAddress)

val encrypt = sessionCipher.encrypt(Json.encodeToString(message).encodeToByteArray())

call.respond(status = HttpStatusCode.OK, message = encrypt)
}

suspend inline fun getAppInstanceId(call: ApplicationCall): String? {
val appInstanceId = call.request.headers["appInstanceId"]
if (appInstanceId == null) {
Expand Down

0 comments on commit f305c15

Please sign in to comment.