Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Enabling local storage of signal secret keys #30

Merged
merged 5 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions composeApp/src/commonMain/kotlin/com/clipevery/AppConfig.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
package com.clipevery

import kotlinx.serialization.Serializable

@Serializable
data class AppConfig(val bindingState: Boolean = false)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.clipevery.encrypt

enum class CreateSignalProtocolState {

NEW_GENERATE,
DELETE_GENERATE,
EXISTING,

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.clipevery.encrypt

import java.security.SecureRandom
import java.util.Base64
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec

fun generateAESKey(): SecretKey {
val keyGen = KeyGenerator.getInstance("AES")
keyGen.init(256)
return keyGen.generateKey()
}

fun secretKeyToString(secretKey: SecretKey): String {
val encodedKey = secretKey.encoded
return Base64.getEncoder().encodeToString(encodedKey)
}

fun stringToSecretKey(encodedKey: String): SecretKey {
val decodedKey = Base64.getDecoder().decode(encodedKey)
return SecretKeySpec(decodedKey, 0, decodedKey.size, "AES")
}

fun encryptData(key: SecretKey, data: ByteArray): ByteArray {
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
val ivBytes = ByteArray(cipher.blockSize)
SecureRandom().nextBytes(ivBytes)
val ivSpec = IvParameterSpec(ivBytes)

cipher.init(Cipher.ENCRYPT_MODE, key, ivSpec)
val encrypted = cipher.doFinal(data)
return ivBytes + encrypted
}

fun decryptData(key: SecretKey, encryptedData: ByteArray): ByteArray {
val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding")
val ivBytes = encryptedData.copyOfRange(0, 16)
val actualEncryptedData = encryptedData.copyOfRange(16, encryptedData.size)

val ivSpec = IvParameterSpec(ivBytes)
cipher.init(Cipher.DECRYPT_MODE, key, ivSpec)

return cipher.doFinal(actualEncryptedData)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.clipevery.encrypt

interface SignalProtocolFactory {

fun createSignalProtocol(): SignalProtocolWithState
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package com.clipevery.encrypt

data class SignalProtocolWithState(val signalProtocol: SignalProtocol,
val state: CreateSignalProtocolState)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ package com.clipevery.presist
import kotlin.reflect.KClass

interface OneFilePersist {
fun <T : Any> readAs(clazz: KClass<T>): T?
fun <T : Any> read(clazz: KClass<T>): T?

fun readBytes(): ByteArray?

fun <T> save(config: T)

fun saveBytes(bytes: ByteArray)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import java.util.Properties

val logger = KotlinLogging.logger {}

fun getFactory(): AppInfoFactory {
fun getAppInfoFactory(): AppInfoFactory {
val platform = currentPlatform()
return if (platform.isMacos()) {
MacosAppInfoFactory()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.clipevery.encrypt

import com.clipevery.utils.readJson
import org.whispersystems.libsignal.IdentityKeyPair
import org.whispersystems.libsignal.state.PreKeyRecord
import org.whispersystems.libsignal.state.SignedPreKeyRecord
import org.whispersystems.libsignal.util.KeyHelper
import java.io.ByteArrayOutputStream
import java.io.DataInputStream
import java.io.DataOutputStream

class DesktopSignalProtocol(override val identityKeyPair: IdentityKeyPair,
override val registrationId: Int,
Expand All @@ -18,43 +20,46 @@ class DesktopSignalProtocol(override val identityKeyPair: IdentityKeyPair,
KeyHelper.generateSignedPreKey(KeyHelper.generateIdentityKeyPair(), 5))
}


data class StringEncodeSignalProtocol(val identityKeyPairStr: String,
val registrationIdStr: Int,
val preKeysStr: List<String>,
val signedPreKeyStr: String)


fun readSignalProtocol(data: String): SignalProtocol {
val stringEncodeSignalProtocol = readJson<StringEncodeSignalProtocol>(data)

val identityKeyPair = IdentityKeyPair(asciiStringToBytes(stringEncodeSignalProtocol.identityKeyPairStr))

val registrationId = stringEncodeSignalProtocol.registrationIdStr

val preKeys = stringEncodeSignalProtocol.preKeysStr.map { PreKeyRecord(asciiStringToBytes(it)) }

val signedPreKey = SignedPreKeyRecord(asciiStringToBytes(stringEncodeSignalProtocol.signedPreKeyStr))

return DesktopSignalProtocol(identityKeyPair, registrationId, preKeys, signedPreKey)
}

fun writeSignalProtocol(signalProtocol: SignalProtocol): StringEncodeSignalProtocol {
val identityKeyPairStr = bytesToAsciiString(signalProtocol.identityKeyPair.serialize())

val registrationIdStr = signalProtocol.registrationId

val preKeysStr = signalProtocol.preKeys.map { bytesToAsciiString(it.serialize()) }

val signedPreKeyStr = bytesToAsciiString(signalProtocol.signedPreKey.serialize())

return StringEncodeSignalProtocol(identityKeyPairStr, registrationIdStr, preKeysStr, signedPreKeyStr)
}

fun bytesToAsciiString(bytes: ByteArray): String {
return bytes.joinToString(separator = "") { it.toInt().toChar().toString() }
fun readSignalProtocol(data: ByteArray): SignalProtocol {
val inputStream = DataInputStream(data.inputStream())
val identityKeyPairSize = inputStream.readInt()
val identityKeyPairBytes = inputStream.readNBytes(identityKeyPairSize)
val identityKeyPair = IdentityKeyPair(identityKeyPairBytes)
val registrationId = inputStream.readInt()
val preKeysSize = inputStream.readInt()
val preKeys = buildList {
for (i in 0 until preKeysSize) {
val preKeySize = inputStream.readInt()
val preKeyBytes = inputStream.readNBytes(preKeySize)
add(PreKeyRecord(preKeyBytes))
}
}
val signedPreKeySize = inputStream.readInt()
val signedPreKeyBytes = inputStream.readNBytes(signedPreKeySize)
val signedPreKeyRecord = SignedPreKeyRecord(signedPreKeyBytes)
return DesktopSignalProtocol(identityKeyPair, registrationId, preKeys, signedPreKeyRecord)
}

fun asciiStringToBytes(str: String): ByteArray {
return str.map { it.code.toByte() }.toByteArray()
fun writeSignalProtocol(signalProtocol: SignalProtocol): ByteArray {
val byteStream = ByteArrayOutputStream()
val dataStream = DataOutputStream(byteStream)
val identityKeyPairBytes = signalProtocol.identityKeyPair.serialize()
val identityKeyPairSize = identityKeyPairBytes.size
dataStream.writeInt(identityKeyPairSize)
dataStream.write(identityKeyPairBytes)
dataStream.writeInt(signalProtocol.registrationId)
val preKeys = signalProtocol.preKeys
dataStream.writeInt(preKeys.size)
preKeys.forEach {
val preKeyBytes = it.serialize()
val preKeySize = preKeyBytes.size
dataStream.writeInt(preKeySize)
dataStream.write(preKeyBytes)
}
val signedPreKeyBytes = signalProtocol.signedPreKey.serialize()
val signedPreKeySize = signedPreKeyBytes.size
dataStream.writeInt(signedPreKeySize)
dataStream.write(signedPreKeyBytes)

return byteStream.toByteArray()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package com.clipevery.encrypt

import com.clipevery.AppInfo
import com.clipevery.macos.MacosKeychainHelper
import com.clipevery.path.getPathProvider
import com.clipevery.platform.currentPlatform
import com.clipevery.presist.DesktopOneFilePersist
import com.clipevery.windows.WindowDapiHelper
import io.github.oshai.kotlinlogging.KotlinLogging

val logger = KotlinLogging.logger {}

fun getSignalProtocolFactory(appInfo: AppInfo): SignalProtocolFactory {
val currentPlatform = currentPlatform()
return if (currentPlatform.isMacos()) {
MacosSignalProtocolFactory(appInfo)
} else if (currentPlatform.isWindows()) {
WindowsSignalProtocolFactory()
} else {
throw IllegalStateException("Unknown platform: ${currentPlatform.name}")
}
}

class MacosSignalProtocolFactory(private val appInfo: AppInfo): SignalProtocolFactory {

private val filePersist = DesktopOneFilePersist(getPathProvider().resolveUser("signal.data"))

override fun createSignalProtocol(): SignalProtocolWithState {
val file = filePersist.path.toFile()
var deleteOldSignalProtocol = false
if (file.exists()) {
logger.info { "Found signalProtocol encrypt file" }
val bytes = file.readBytes()
val password = MacosKeychainHelper.getPassword(appInfo.appName, appInfo.userName)

password?.let {
logger.info { "Found password in keychain by ${appInfo.appName} ${appInfo.userName}" }
try {
val secretKey = stringToSecretKey(it)
val decryptData = decryptData(secretKey, bytes)
return SignalProtocolWithState(
readSignalProtocol(decryptData),
CreateSignalProtocolState.EXISTING
)
} catch (e: Exception) {
logger.error(e) { "Failed to decrypt signalProtocol" }
}
}

deleteOldSignalProtocol = true
if (file.delete()) {
logger.info { "Delete signalProtocol encrypt file" }
}

} else {
logger.info { "No found signalProtocol encrypt file" }
}

logger.info { "Creating new SignalProtocol" }
val signalProtocol = DesktopSignalProtocol()
val data = writeSignalProtocol(signalProtocol)
val password = MacosKeychainHelper.getPassword(appInfo.appName, appInfo.userName)

val secretKey = password?.let {
logger.info { "Found password in keychain by ${appInfo.appName} ${appInfo.userName}" }
stringToSecretKey(it)
} ?: run {
logger.info { "Generating new password in keychain by ${appInfo.appName} ${appInfo.userName}" }
val secretKey = generateAESKey()
MacosKeychainHelper.setPassword(appInfo.appName, appInfo.userName, secretKeyToString(secretKey))
secretKey
}

val encryptData = encryptData(secretKey, data)
filePersist.saveBytes(encryptData)
return SignalProtocolWithState(signalProtocol,
if (deleteOldSignalProtocol) CreateSignalProtocolState.DELETE_GENERATE
else CreateSignalProtocolState.NEW_GENERATE)
}
}


class WindowsSignalProtocolFactory : SignalProtocolFactory {

private val filePersist = DesktopOneFilePersist(getPathProvider().resolveUser("signal.data"))

override fun createSignalProtocol(): SignalProtocolWithState {
val file = filePersist.path.toFile()
var deleteOldSignalProtocol = false
if (file.exists()) {
logger.info { "Found signalProtocol encrypt file" }
filePersist.readBytes()?.let {
try {
val decryptData = WindowDapiHelper.decryptData(it)
decryptData?.let { byteArray ->
return SignalProtocolWithState(
readSignalProtocol(byteArray),
CreateSignalProtocolState.EXISTING
)
}
} catch (e: Exception) {
logger.error(e) { "Failed to decrypt signalProtocol" }
}
}
deleteOldSignalProtocol = true
if (file.delete()) {
logger.info { "Delete signalProtocol encrypt file" }
}
} else {
logger.info { "No found signalProtocol encrypt file" }
}

logger.info { "Creating new SignalProtocol" }
val signalProtocol = DesktopSignalProtocol()
val data = writeSignalProtocol(signalProtocol)
val encryptData = WindowDapiHelper.encryptData(data)
filePersist.saveBytes(encryptData!!)
return SignalProtocolWithState(signalProtocol,
if (deleteOldSignalProtocol) CreateSignalProtocolState.DELETE_GENERATE
else CreateSignalProtocolState.NEW_GENERATE)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import java.nio.file.Path
import kotlin.reflect.KClass

@Suppress("UNCHECKED_CAST")
class DesktopOneFilePersist(private val path: Path) : OneFilePersist {
override fun <T: Any> readAs(clazz: KClass<T>): T? {
class DesktopOneFilePersist(val path: Path) : OneFilePersist {
override fun <T: Any> read(clazz: KClass<T>): T? {
val file = path.toFile()
return if (file.exists()) {
val serializer = Json.serializersModule.serializer(clazz.java)
Expand All @@ -17,6 +17,15 @@ class DesktopOneFilePersist(private val path: Path) : OneFilePersist {
}
}

override fun readBytes(): ByteArray? {
val file = path.toFile()
return if (file.exists()) {
file.readBytes()
} else {
null
}
}

override fun <T> save(config: T) {
val kClass = config!!::class
val serializer = Json.serializersModule.serializer(kClass.java)
Expand All @@ -25,4 +34,10 @@ class DesktopOneFilePersist(private val path: Path) : OneFilePersist {
file.parentFile?.mkdirs()
file.writeText(json)
}

override fun saveBytes(bytes: ByteArray) {
val file = path.toFile()
file.parentFile?.mkdirs()
file.writeBytes(bytes)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import com.sun.jna.platform.win32.Crypt32Util.cryptUnprotectData

object WindowDapiHelper {

fun encryptString(data: ByteArray): ByteArray? {
fun encryptData(data: ByteArray): ByteArray? {
return cryptProtectData(data)
}

fun decryptString(encryptedData: ByteArray): ByteArray? {
fun decryptData(encryptedData: ByteArray): ByteArray? {
return cryptUnprotectData(encryptedData)
}
}
Loading