Skip to content

Commit

Permalink
Add a way to enforce DTLS session expiration (#60)
Browse files Browse the repository at this point in the history
* Add a way to enforce DTLS session expiration

* Fix linting

* Address comments

* Apply suggestion
  • Loading branch information
akolosov-n authored Aug 14, 2024
1 parent b1558de commit 823b4b1
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,7 +25,7 @@ import java.net.InetSocketAddress
class DatagramPacketWithContext(
data: ByteBuf,
recipient: InetSocketAddress?,
sender: InetSocketAddress,
sender: InetSocketAddress?,
val sessionContext: DtlsSessionContext
) : DatagramPacket(data, recipient, sender) {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,6 +90,14 @@ class DtlsChannelHandler @JvmOverloads constructor(

override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
when (msg) {
is DatagramPacketWithContext -> {
write(msg, promise, ctx)
if (msg.sessionContext.sessionExpirationHint) {
promise.toCompletableFuture().thenAccept {
dtlsServer.closeSession(msg.recipient())
}
}
}
is DatagramPacket -> write(msg, promise, ctx)
is SessionAuthenticationContext -> {
msg.map.forEach { (key, value) ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,14 +30,21 @@ class EchoHandler : ChannelInboundHandlerAdapter() {

val sessionContext = DatagramPacketWithContext.contextFrom(msg)
val authContext = (sessionContext.authenticationContext["AUTH"] ?: "")
val dgramContent = dgram.content().toByteArray()
val goToSleep = dgramContent.toString(Charset.defaultCharset()).endsWith(":sleep")

val reply = ctx.alloc().buffer(dgram.content().readableBytes() + 20)
val reply = ctx.alloc().buffer(dgramContent.size + 20)
reply.writeBytes(echoPrefix)
reply.writeCharSequence(authContext, Charset.defaultCharset())
reply.writeBytes(dgram.content())
reply.writeBytes(dgramContent)

dgram.release()

ctx.writeAndFlush(DatagramPacket(reply, dgram.sender()))
ctx.writeAndFlush(
DatagramPacketWithContext(
reply,
dgram.sender(),
null,
sessionContext.copy(sessionExpirationHint = goToSleep)
)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import io.netty.channel.socket.DatagramChannel
import io.netty.channel.socket.DatagramPacket
import io.netty.util.concurrent.DefaultThreadFactory
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.awaitility.kotlin.await
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
Expand Down Expand Up @@ -224,4 +225,29 @@ class NettyTest {
assertEquals(0, dtlsServer.numberOfSessions)
client.close()
}

@Test
fun `server should store session if hinted to do so`() {
val client = NettyTransportAdapter.connect(clientConf, srvAddress).mapToString()

// when normal packet is sent
assertTrue(client.send("hi").await())
assertEquals("ECHO:hi", client.receive(5.seconds).await())

// then session should not be stored
assertEquals(1, dtlsServer.numberOfSessions)
assertEquals(0, sessionStore.size())

// when a packet with session expiration hint is sent
assertTrue(client.send("hi:sleep").await())
assertEquals("ECHO:hi:sleep", client.receive(5.seconds).await())

// then session must be stored
await.atMost(5.seconds).untilAsserted {
assertEquals(0, dtlsServer.numberOfSessions)
assertEquals(1, sessionStore.size())
}

client.close()
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -118,6 +118,10 @@ class DtlsServer(
}
}

fun closeSession(addr: InetSocketAddress) {
sessions.remove(addr)?.storeAndClose()
}

fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray): Boolean {
return try {
if (sessBuf == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,8 +95,16 @@ class DtlsServerTransport private constructor(
override fun send(packet: Packet<ByteBuffer>): CompletableFuture<Boolean> = executor.supply {
val encPacket = dtlsServer.encrypt(packet.buffer, packet.peerAddress)?.let(packet::map)

when (encPacket) {
null -> completedFuture(false)
when {
encPacket == null -> completedFuture(false)
packet.sessionContext.sessionExpirationHint -> {
transport.send(encPacket).thenApply { isSuccess ->
if (isSuccess) {
dtlsServer.closeSession(packet.peerAddress)
}
isSuccess
}
}
else -> transport.send(encPacket)
}
}.thenCompose(Function.identity())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* Copyright (c) 2022-2024 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,7 +24,8 @@ data class DtlsSessionContext @JvmOverloads constructor(
val authenticationContext: AuthenticationContext = emptyMap(),
val peerCertificateSubject: String? = null,
val cid: ByteArray? = null,
val sessionStartTimestamp: Instant? = null
val sessionStartTimestamp: Instant? = null,
val sessionExpirationHint: Boolean = false
) {
companion object {
@JvmField
Expand All @@ -44,6 +45,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
if (!cid.contentEquals(other.cid)) return false
} else if (other.cid != null) return false
if (sessionStartTimestamp != other.sessionStartTimestamp) return false
if (sessionExpirationHint != other.sessionExpirationHint) return false

return true
}
Expand All @@ -53,6 +55,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
result = 31 * result + (peerCertificateSubject?.hashCode() ?: 0)
result = 31 * result + (cid?.contentHashCode() ?: 0)
result = 31 * result + (sessionStartTimestamp?.hashCode() ?: 0)
result = 31 * result + (sessionExpirationHint.hashCode())
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,29 @@ class DtlsServerTransportTest {
client.close()
}

@Test
fun `server should store session if hinted to do so`() {
// given
server = DtlsServerTransport.create(conf, sessionStore = sessionStore)
val serverReceived = server.receive(1.seconds)
val client = DtlsTransmitter.connect(server, clientConfig).await().mapToString()

client.send("dupa")
server.send(Packet("dupa".toByteBuffer(), serverReceived.await().peerAddress))
assertEquals("dupa", client.receive(1.seconds).await())

client.send("sleep")
server.send(Packet("sleep".toByteBuffer(), serverReceived.await().peerAddress, sessionContext = DtlsSessionContext(sessionExpirationHint = true)))
assertEquals("sleep", client.receive(1.seconds).await())

await.atMost(5.seconds).untilAsserted {
assertEquals(1, sessionStore.size())
assertEquals(0, server.numberOfSessions())
}

client.close()
}

private fun <T> Transport<T>.dropReceive(drop: (Int) -> Boolean): Transport<T> {
val underlying = this
var i = 0
Expand Down

0 comments on commit 823b4b1

Please sign in to comment.