diff --git a/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java b/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java index 88d7837a..fe9ca413 100644 --- a/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java +++ b/src/main/java/org/ice4j/ice/harvest/AbstractUdpListener.java @@ -19,7 +19,6 @@ import org.ice4j.*; import org.ice4j.attribute.*; -import org.ice4j.ice.*; import org.ice4j.message.*; import org.ice4j.socket.*; import org.ice4j.util.*; @@ -37,7 +36,7 @@ import static org.ice4j.ice.harvest.HarvestConfig.config; /** - * A class which holds a {@link DatagramSocket} and runs a thread + * A class which holds a {@link SocketPool} and runs a thread * ({@link #thread}) which perpetually reads from it. * * When a datagram from an unknown source is received, it is parsed as a STUN @@ -196,13 +195,18 @@ static String getUfrag(byte[] buf, int off, int len) */ protected final TransportAddress localAddress; + /** + * The pool of sockets available for writing. + */ + private final SocketPool socketPool; + /** * The "main" socket that this harvester reads from. */ - private final DatagramSocket socket; + private final DatagramSocket receiveSocket; /** - * The thread reading from {@link #socket}. + * The thread reading from {@link #receiveSocket}. */ private final Thread thread; @@ -236,12 +240,14 @@ protected AbstractUdpListener(TransportAddress localAddress) ); } - socket = new DatagramSocket( tempAddress ); + socketPool = new SocketPool( tempAddress, config.udpSocketPoolSize() ); + + receiveSocket = socketPool.getReceiveSocket(); Integer receiveBufferSize = config.udpReceiveBufferSize(); if (receiveBufferSize != null) { - socket.setReceiveBufferSize(receiveBufferSize); + receiveSocket.setReceiveBufferSize(receiveBufferSize); } /* Update the port number if needed. */ @@ -249,7 +255,7 @@ protected AbstractUdpListener(TransportAddress localAddress) { tempAddress = new TransportAddress( tempAddress.getAddress(), - socket.getLocalPort(), + receiveSocket.getLocalPort(), tempAddress.getTransport() ); } @@ -257,11 +263,12 @@ protected AbstractUdpListener(TransportAddress localAddress) String logMessage = "Initialized AbstractUdpListener with address " + this.localAddress; - logMessage += ". Receive buffer size " + socket.getReceiveBufferSize(); + logMessage += ". Receive buffer size " + receiveSocket.getReceiveBufferSize(); if (receiveBufferSize != null) { logMessage += " (asked for " + receiveBufferSize + ")"; } + logMessage += "; socket pool size " + socketPool.getNumSockets(); logger.info(logMessage); thread = new Thread(() -> @@ -292,11 +299,11 @@ public TransportAddress getLocalAddress() public void close() { close = true; - socket.close(); // causes socket#receive to stop blocking. + socketPool.close(); // causes socket#receive to stop blocking. } /** - * Perpetually reads datagrams from {@link #socket} and handles them + * Perpetually reads datagrams from {@link #receiveSocket} and handles them * accordingly. * * It is important that this blocks are little as possible (except on @@ -326,7 +333,7 @@ private void runInHarvesterThread() try { - socket.receive(pkt); + receiveSocket.receive(pkt); } catch (IOException ioe) { @@ -376,13 +383,13 @@ private void runInHarvesterThread() { candidateSocket.close(); } - socket.close(); + socketPool.close(); } /** * Read packets from the socket and forward them via the push API. Note that the memory model here is different * than the other case. Specifically, we: - * 1. Receive from {@link #socket} into a fixed buffer + * 1. Receive from {@link #receiveSocket} into a fixed buffer * 2. Obtain a buffer of the required size using {@link BufferPool#getBuffer} * 3. Copy the data into the buffer and either * 3.1 Call the associated {@link BufferHandler} if the packet is payload @@ -410,7 +417,7 @@ private void runInHarvesterThreadPush() try { - socket.receive(pkt); + receiveSocket.receive(pkt); receivedTime = clock.instant(); } catch (IOException ioe) @@ -467,7 +474,7 @@ private void runInHarvesterThreadPush() { candidateSocket.close(); } - socket.close(); + socketPool.close(); } private Buffer bufferFromPacket(DatagramPacket p, Instant receivedTime) @@ -478,7 +485,7 @@ private Buffer bufferFromPacket(DatagramPacket p, Instant receivedTime) System.arraycopy(p.getData(), p.getOffset(), buffer.getBuffer(), off, p.getLength()); buffer.setOffset(off); buffer.setLength(p.getLength()); - buffer.setLocalAddress(socket.getLocalSocketAddress()); + buffer.setLocalAddress(receiveSocket.getLocalSocketAddress()); buffer.setRemoteAddress(p.getSocketAddress()); buffer.setReceivedTime(receivedTime); @@ -808,14 +815,14 @@ public void receive(DatagramPacket p) /** * {@inheritDoc} * - * Delegates to the actual socket of the harvester. + * Delegates to the socket pool. */ @Override public void send(DatagramPacket p) throws IOException { p.setSocketAddress(remoteAddress); - socket.send(p); + socketPool.send(p); } } } diff --git a/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt b/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt index 5c554f18..6ac60132 100644 --- a/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt +++ b/src/main/kotlin/org/ice4j/ice/harvest/HarvestConfig.kt @@ -41,6 +41,12 @@ class HarvestConfig { } fun udpReceiveBufferSize() = udpReceiveBufferSize + val udpSocketPoolSize: Int by config { + "ice4j.harvest.udp.socket-pool-size".from(configSource) + } + + fun udpSocketPoolSize() = udpSocketPoolSize + val useIpv6: Boolean by config { "org.ice4j.ipv6.DISABLED".from(configSource) .transformedBy { !it } diff --git a/src/main/kotlin/org/ice4j/socket/SocketPool.kt b/src/main/kotlin/org/ice4j/socket/SocketPool.kt new file mode 100644 index 00000000..624e5b23 --- /dev/null +++ b/src/main/kotlin/org/ice4j/socket/SocketPool.kt @@ -0,0 +1,114 @@ +/* + * Copyright @ 2020 - Present, 8x8 Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.ice4j.socket + +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.DatagramSocketImpl +import java.net.SocketAddress +import java.nio.channels.DatagramChannel + +/** A pool of datagram sockets all bound on the same port. + * + * This is necessary to allow multiple threads to send packets simultaneously from the same source address, + * in JDK 15 and later, because the [DatagramChannel]-based implementation of [DatagramSocketImpl] introduced + * in that version locks the socket during a call to [DatagramSocket.send]. + * + * (The old [DatagramSocketImpl] implementation can be used by setting the system property + * `jdk.net.usePlainDatagramSocketImpl` in JDK versions 15 through 17, but was removed in versions 18 and later.) + * + * This feature may also be useful on older JDK versions on non-Linux operating systems, such as macOS, + * which block simultaneous writes through the same UDP socket at the operating system level. + * + * The sockets are opened such that packets will be _received_ on exactly one socket. + */ +class SocketPool( + /** The address to which to bind the pool of sockets. */ + address: SocketAddress, + /** The number of sockets to create for the pool. If this is set to zero (the default), the number + * will be set automatically to an appropriate value. + */ + requestedNumSockets: Int = 0 +) { + init { + require(requestedNumSockets >= 0) { "RequestedNumSockets must be >= 0" } + } + + internal class SocketAndIndex( + val socket: DatagramSocket, + var count: Int = 0 + ) + + val numSockets: Int = + if (requestedNumSockets != 0) { + requestedNumSockets + } else { + // TODO: set this to 1 in situations where pools aren't needed? + Runtime.getRuntime().availableProcessors() + } + + private val sockets = buildList { + val multipleSockets = numSockets > 1 + var bindAddr = address + for (i in 0 until numSockets) { + val sock = DatagramSocket(null) + if (multipleSockets) { + sock.reuseAddress = true + } + sock.bind(bindAddr) + if (i == 0 && multipleSockets) { + bindAddr = sock.localSocketAddress + } + add(SocketAndIndex(sock, 0)) + } + } + + /** The socket on which packets will be received. */ + val receiveSocket: DatagramSocket + // On all platforms I've tested, the last-bound socket is the one which receives packets. + // TODO: should we support Linux's flavor of SO_REUSEPORT, in which packets can be received on *all* the + // sockets, spreading load? + get() = sockets.last().socket + + fun send(packet: DatagramPacket) { + val sendSocket = getSendSocket() + sendSocket.socket.send(packet) + returnSocket(sendSocket) + } + + /** Gets a socket on which packets can be sent, chosen from among all the available send sockets. */ + internal fun getSendSocket(): SocketAndIndex { + if (numSockets == 1) { + return sockets.first() + } + synchronized(sockets) { + val min = sockets.minBy { it.count } + min.count++ + + return min + } + } + + internal fun returnSocket(socket: SocketAndIndex) { + synchronized(sockets) { + socket.count-- + } + } + + fun close() { + sockets.forEach { it.socket.close() } + } +} diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf index 09955622..9906d4cb 100644 --- a/src/main/resources/reference.conf +++ b/src/main/resources/reference.conf @@ -57,6 +57,10 @@ ice4j { // Whether to allocate ephemeral ports for local candidates. This is the default value, and can be overridden // for Agent instances. use-dynamic-ports = true + + // The size of the socket pool to use to send packets on the "single port" harvester. 0 means the + // default (Java's reported number of available processors). 1 is equivalent to not using a socket pool. + socket-pool-size = 0 } // The list of IP addresses that are allowed to be used for host candidate allocations. When empty, any address is diff --git a/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java b/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java index 3535e227..d6f8aeb5 100644 --- a/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java +++ b/src/test/java/org/ice4j/ice/harvest/SinglePortUdpHarvesterTest.java @@ -31,51 +31,6 @@ */ public class SinglePortUdpHarvesterTest { - /** - * Verifies that, without closing, the address used by a harvester cannot be re-used. - * - * @see https://github.com/jitsi/ice4j/issues/139 - */ - @Test - public void testRebindWithoutCloseThrows() throws Exception - { - // Setup test fixture. - final TransportAddress address = new TransportAddress( "127.0.0.1", 10000, Transport.UDP ); - SinglePortUdpHarvester firstHarvester; - try - { - firstHarvester = new SinglePortUdpHarvester( address ); - } - catch (BindException ex) - { - // This is not expected at this stage (the port is likely already in use by another process, voiding this - // test). Rethrow as a different exception than the BindException, that is expected to be thrown later in - // this test. - throw new Exception( "Test fixture is invalid.", ex ); - } - - // Execute system under test. - SinglePortUdpHarvester secondHarvester = null; - try - { - secondHarvester = new SinglePortUdpHarvester( address ); - fail("expected BindException to be thrown at this point"); - } - catch (BindException ex) - { - //expected, do nothing - } - finally - { - // Tear down - firstHarvester.close(); - if (secondHarvester != null) - { - secondHarvester.close(); - } - } - } - /** * Verifies that, after closing, the address used by a harvester can be re-used. * diff --git a/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt b/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt index 0fbf396a..8ed65d26 100644 --- a/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt +++ b/src/test/kotlin/org/ice4j/ice/harvest/HarvestConfigTest.kt @@ -30,6 +30,7 @@ class HarvestConfigTest : ConfigTest() { config.useIpv6 shouldBe true config.useLinkLocalAddresses shouldBe true config.udpReceiveBufferSize shouldBe null + config.udpSocketPoolSize shouldBe 0 config.stunMappingCandidateHarvesterAddresses shouldBe emptyList() } context("Setting via legacy config (system properties)") { @@ -39,6 +40,7 @@ class HarvestConfigTest : ConfigTest() { config.useIpv6 shouldBe false config.useLinkLocalAddresses shouldBe false config.udpReceiveBufferSize shouldBe 555 + config.udpSocketPoolSize shouldBe 0 config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.legacy:555", "stun2.legacy") } } @@ -49,6 +51,7 @@ class HarvestConfigTest : ConfigTest() { config.useIpv6 shouldBe false config.useLinkLocalAddresses shouldBe false config.udpReceiveBufferSize shouldBe 666 + config.udpSocketPoolSize shouldBe 3 config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.new:666", "stun2.new") } } @@ -60,6 +63,7 @@ class HarvestConfigTest : ConfigTest() { config.useIpv6 shouldBe false config.useLinkLocalAddresses shouldBe false config.udpReceiveBufferSize shouldBe 555 + config.udpSocketPoolSize shouldBe 0 config.stunMappingCandidateHarvesterAddresses shouldBe listOf("stun1.legacy:555", "stun2.legacy") } } @@ -153,6 +157,7 @@ private val newConfigNonDefault = """ udp { receive-buffer-size = 666 use-dynamic-ports = false + socket-pool-size = 3 } mapping { stun { diff --git a/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt new file mode 100644 index 00000000..b82bcb4f --- /dev/null +++ b/src/test/kotlin/org/ice4j/socket/SocketPoolTest.kt @@ -0,0 +1,266 @@ +package org.ice4j.socket + +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.core.test.Enabled +import io.kotest.core.test.TestCase +import io.kotest.matchers.comparables.shouldBeLessThan +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.kotest.matchers.types.beInstanceOf +import java.net.DatagramPacket +import java.net.DatagramSocket +import java.net.InetSocketAddress +import java.net.SocketAddress +import java.time.Clock +import java.time.Duration +import java.time.Instant +import java.util.concurrent.CyclicBarrier + +private val loopbackAny = InetSocketAddress("127.0.0.1", 0) +private val loopbackDiscard = InetSocketAddress("127.0.0.1", 9) + +@OptIn(io.kotest.common.ExperimentalKotest::class) +class SocketPoolTest : ShouldSpec() { + init { + context("Creating a new socket pool") { + val pool = SocketPool(loopbackAny) + should("Bind to a random port") { + val local = pool.receiveSocket.localSocketAddress + local should beInstanceOf() + (local as InetSocketAddress).port shouldNotBe 0 + } + pool.close() + } + + context("Getting multiple send sockets from a pool") { + val numSockets = 4 + val pool = SocketPool(loopbackAny, numSockets) + val sockets = mutableListOf() + should("be possible") { + repeat(numSockets) { + sockets.add(pool.getSendSocket().socket) + } + } + // All sockets should be distinct + sockets.toSet().size shouldBe sockets.size + pool.close() + } + + context("Packets sent from each of the send sockets in the pool") { + val numSockets = 4 + val pool = SocketPool(loopbackAny, numSockets) + val local = pool.receiveSocket.localSocketAddress + val sockets = mutableListOf() + repeat(numSockets) { + sockets.add(pool.getSendSocket().socket) + } + sockets.forEachIndexed { i, it -> + val buf = i.toString().toByteArray() + val packet = DatagramPacket(buf, buf.size, local) + it.send(packet) + } + + should("be received") { + for (i in 0 until numSockets) { + val buf = ByteArray(1500) + val packet = DatagramPacket(buf, buf.size) + pool.receiveSocket.soTimeout = 1 // Don't block if something's wrong + pool.receiveSocket.receive(packet) + packet.data.decodeToString(0, packet.length).toInt() shouldBe i + packet.socketAddress shouldBe local + } + } + pool.close() + } + + context("The number of send sockets") { + val numSockets = 4 + val pool = SocketPool(loopbackAny, numSockets) + + val sockets = mutableSetOf() + + repeat(2 * numSockets) { + // This should cycle through all the available send sockets + sockets.add(pool.getSendSocket().socket) + } + + should("be correct") { + sockets.size shouldBe numSockets + } + + pool.close() + } + + val disableIfOnlyOneCore: (TestCase) -> Enabled = { + if (Runtime.getRuntime().availableProcessors() > 1) { + Enabled.enabled + } else { + Enabled.disabled("Need multiple processors to run test") + } + } + + context("Sending packets from multiple threads").config(enabledOrReasonIf = disableIfOnlyOneCore) { + val poolWarmup = SocketPool(loopbackAny, 1) + sendTimeOnAllSockets(poolWarmup) + + val pool1 = SocketPool(loopbackAny, 1) + val elapsed1 = sendTimeOnAllSockets(pool1) + + // 0 means pick the default value, currently Runtime.getRuntime().availableProcessors(). + val poolN = SocketPool(loopbackAny, 0) + val elapsedN = sendTimeOnAllSockets(poolN) + + elapsedN shouldBeLessThan elapsed1 // Very weak test + } + + val enableOnlyIfPropertySet: (TestCase) -> Enabled = { + if (System.getProperty("doPerfTests") != null) { + Enabled.enabled + } else { + Enabled.disabled("Set \"doPerfTests\" property to enable SocketPool performance tests") + } + } + + context("Test sending packets from multiple threads").config(enabledOrReasonIf = enableOnlyIfPropertySet) { + testSending() + } + } + private class Sender( + private val count: Int, + private val pool: SocketPool, + private val destAddr: SocketAddress + ) : Runnable { + private val buf = ByteArray(BUFFER_SIZE) + + private fun sendToSocket(count: Int) { + for (i in 0 until count) { + pool.send(DatagramPacket(buf, BUFFER_SIZE, destAddr)) + } + } + + override fun run() { + barrier.await() + + start() + sendToSocket(count) + end() + } + + companion object { + private const val BUFFER_SIZE = 1500 + const val NUM_PACKETS = 600000 + private val clock = Clock.systemUTC() + + private var start = Instant.MAX + private var end = Instant.MIN + + val elapsed: Duration + get() = Duration.between(start, end) + + fun start() { + val now = clock.instant() + synchronized(this) { + if (start.isAfter(now)) { + start = now + } + } + } + + fun end() { + val now = clock.instant() + synchronized(this) { + if (end.isBefore(now)) { + end = now + } + } + } + + private var barrier: CyclicBarrier = CyclicBarrier(1) + + fun reset(numThreads: Int) { + barrier = CyclicBarrier(numThreads) + start = Instant.MAX + end = Instant.MIN + } + } + } + + companion object { + private fun sendTimeOnAllSockets( + pool: SocketPool, + numThreads: Int = pool.numSockets, + numPackets: Int = Sender.NUM_PACKETS + ): Duration { + val threads = mutableListOf() + Sender.reset(numThreads) + repeat(numThreads) { + val thread = Thread(Sender(numPackets / numThreads, pool, loopbackDiscard)) + threads.add(thread) + thread.start() + } + threads.forEach { it.join() } + return Sender.elapsed + } + + private fun testSendingOnce( + numSockets: Int, + numThreads: Int, + numPackets: Int = Sender.NUM_PACKETS, + warmup: Boolean = false + ) { + val pool = SocketPool(loopbackAny, numSockets) + val elapsed = sendTimeOnAllSockets(pool, numThreads, numPackets) + if (!warmup) { + println( + "Send $numPackets packets on $numSockets sockets on $numThreads threads " + + "took $elapsed" + ) + } + } + + fun testSending() { + val numProcessors = Runtime.getRuntime().availableProcessors() + + testSendingOnce(1, 1, warmup = true) + testSendingOnce(2 * numProcessors, 2 * numProcessors, warmup = true) + + testSendingOnce(1, 1) + testSendingOnce(1, numProcessors) + testSendingOnce(1, 2 * numProcessors) + testSendingOnce(1, 4 * numProcessors) + testSendingOnce(1, 8 * numProcessors) + + testSendingOnce(numProcessors, numProcessors) + testSendingOnce(numProcessors, 2 * numProcessors) + testSendingOnce(numProcessors, 4 * numProcessors) + testSendingOnce(numProcessors, 8 * numProcessors) + + testSendingOnce(2 * numProcessors, 2 * numProcessors) + testSendingOnce(2 * numProcessors, 4 * numProcessors) + testSendingOnce(2 * numProcessors, 8 * numProcessors) + + testSendingOnce(4 * numProcessors, 4 * numProcessors) + testSendingOnce(4 * numProcessors, 8 * numProcessors) + + testSendingOnce(8 * numProcessors, 8 * numProcessors) + } + + @JvmStatic + fun main(args: Array) { + if (args.size >= 2) { + val numSockets = args[0].toInt() + val numThreads = args[1].toInt() + val numPackets = if (args.size > 2) { + args[2].toInt() + } else { + Sender.NUM_PACKETS + } + testSendingOnce(numThreads = numThreads, numSockets = numSockets, numPackets = 10000, warmup = true) + testSendingOnce(numThreads = numThreads, numSockets = numSockets, numPackets = numPackets) + } else { + testSending() + } + } + } +}