From 1f3cd1b6e4cae090524867c9ce165b75e944dce5 Mon Sep 17 00:00:00 2001 From: Jonathan Lennox Date: Mon, 21 Oct 2024 15:28:09 -0400 Subject: [PATCH] Revert "Remove blocking inside SocketPool." This reverts commit d2e792a0d6bbbeaac709e6fbeb64d6d2acaa4eae. --- .../kotlin/org/ice4j/socket/SocketPool.kt | 55 +++++++++++++++---- .../kotlin/org/ice4j/socket/SocketPoolTest.kt | 26 +++++---- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/src/main/kotlin/org/ice4j/socket/SocketPool.kt b/src/main/kotlin/org/ice4j/socket/SocketPool.kt index 7a8a4922..9e5f181d 100644 --- a/src/main/kotlin/org/ice4j/socket/SocketPool.kt +++ b/src/main/kotlin/org/ice4j/socket/SocketPool.kt @@ -15,11 +15,12 @@ */ package org.ice4j.socket +import java.io.Closeable import java.net.DatagramSocket import java.net.DatagramSocketImpl import java.net.SocketAddress import java.nio.channels.DatagramChannel -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.Semaphore /** A pool of datagram sockets all bound on the same port. * @@ -55,6 +56,8 @@ class SocketPool( Runtime.getRuntime().availableProcessors() } + private val semaphore = Semaphore(numSockets) + private val sockets = buildList { val multipleSockets = numSockets > 1 var bindAddr = address @@ -71,7 +74,7 @@ class SocketPool( } } - private val sockIndex = AtomicLong(0) + private val availableSockets = ArrayDeque(sockets) /** The socket on which packets will be received. */ val receiveSocket: DatagramSocket @@ -80,18 +83,48 @@ class SocketPool( // sockets, spreading load? get() = sockets.last() - /** Gets a socket on which packets can be sent, chosen from among all the available send sockets. */ - val sendSocket: DatagramSocket - get() { - if (numSockets == 1) { - return sockets.first() + interface SocketHolder : Closeable { + val socket: DatagramSocket + } + + /** Socket holder with autocloseable semantics, to ensure the socket is returned. */ + private inner class MySocketHolder : SocketHolder { + override val socket = acquireSendSocket() + private var closed = false + override fun close() { + if (!closed) { + returnSendSocket(socket) + closed = true } - return sockets[nextIndex()] } + } - private fun nextIndex(): Int { - val nextIdx = sockIndex.getAndIncrement() - return nextIdx.rem(numSockets).toInt() + /** Trivial socket holder that gives out a single unique socket. */ + private inner class TrivialSocketHolder : SocketHolder { + override val socket = sockets.first() + override fun close() {} + } + + /** Gets a send socket holder. Should be used with Kotlin [use] or Java try-with-resources. May block. */ + fun getSendSocket(): SocketHolder { + if (numSockets == 1) { + return TrivialSocketHolder() + } + return MySocketHolder() + } + + private fun acquireSendSocket(): DatagramSocket { + semaphore.acquire() + synchronized(availableSockets) { + return availableSockets.removeFirst() + } + } + + private fun returnSendSocket(socket: DatagramSocket) { + synchronized(availableSockets) { + availableSockets.addLast(socket) + } + semaphore.release() } fun close() { diff --git a/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt index abc4566f..729821ca 100644 --- a/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt +++ b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt @@ -35,14 +35,13 @@ class SocketPoolTest : ShouldSpec() { context("Getting multiple send sockets from a pool") { val numSockets = 4 val pool = SocketPool(loopbackAny, numSockets) - val sockets = mutableListOf() + val holders = mutableListOf() should("be possible") { repeat(numSockets) { - sockets.add(pool.sendSocket) + holders.add(pool.getSendSocket()) } } - // All sockets should be distinct - sockets.toSet().size shouldBe sockets.size + holders.forEach { it.close() } pool.close() } @@ -50,15 +49,16 @@ class SocketPoolTest : ShouldSpec() { val numSockets = 4 val pool = SocketPool(loopbackAny, numSockets) val local = pool.receiveSocket.localSocketAddress - val sockets = mutableListOf() + val holders = mutableListOf() repeat(numSockets) { - sockets.add(pool.sendSocket) + holders.add(pool.getSendSocket()) } - sockets.forEachIndexed { i, it -> + holders.forEachIndexed { i, it -> val buf = i.toString().toByteArray() val packet = DatagramPacket(buf, buf.size, local) - it.send(packet) + it.socket.send(packet) } + holders.forEach { it.close() } should("be received") { for (i in 0 until numSockets) { @@ -70,7 +70,6 @@ class SocketPoolTest : ShouldSpec() { packet.socketAddress shouldBe local } } - pool.close() } context("The number of send sockets") { @@ -81,7 +80,9 @@ class SocketPoolTest : ShouldSpec() { repeat(2 * numSockets) { // This should cycle through all the available send sockets - sockets.add(pool.sendSocket) + pool.getSendSocket().use { holder -> + sockets.add(holder.socket) + } } should("be correct") { @@ -124,8 +125,9 @@ class SocketPoolTest : ShouldSpec() { private fun sendToSocket(count: Int) { for (i in 0 until count) { - val socket = pool.sendSocket - socket.send(DatagramPacket(buf, BUFFER_SIZE, destAddr)) + pool.getSendSocket().use { + it.socket.send(DatagramPacket(buf, BUFFER_SIZE, destAddr)) + } } }