From a90e1698690b22f3443adb8dba6e2b0a9d414a79 Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Wed, 27 Sep 2023 22:42:19 +0800 Subject: [PATCH 01/13] [Yamux] Don't send frame if send buffer is not empty (#332) --- .../io/libp2p/etc/types/ByteArrayExt.kt | 2 +- .../io/libp2p/mux/yamux/YamuxHandler.kt | 21 +++++--- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 50 +++++++++++++++++++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt index cd1515c6e..8bcf79edc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/ByteArrayExt.kt @@ -13,7 +13,7 @@ fun String.fromHex() = operator fun ByteArray.compareTo(other: ByteArray): Int { if (size != other.size) return size - other.size - for (i in 0 until size) { + for (i in indices) { if (this[i] != other[i]) return this[i].toInt().and(0xFF) - other[i].toInt().and(0xFF) } return 0 diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 79fc1049f..6bf1cfa24 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -46,12 +46,8 @@ open class YamuxHandler( bufferedData.add(data) } - fun bufferedBytes(): Int { - return bufferedData.sumOf { it.readableBytes() } - } - fun flush(windowSize: AtomicInteger) { - while (!bufferedData.isEmpty() && windowSize.get() > 0) { + while (!isEmpty() && windowSize.get() > 0) { val data = bufferedData.removeFirst() val length = data.readableBytes() windowSize.addAndGet(-length) @@ -60,6 +56,14 @@ open class YamuxHandler( } } + fun bufferedBytes(): Int { + return bufferedData.sumOf { it.readableBytes() } + } + + fun isEmpty(): Boolean { + return bufferedData.isEmpty() + } + fun close() { bufferedData.forEach { releaseMessage(it) } bufferedData.clear() @@ -186,7 +190,7 @@ open class YamuxHandler( ) { data.sliceMaxSize(maxFrameDataLength) .forEach { slicedData -> - if (windowSize.get() > 0) { + if (windowSize.get() > 0 && sendBufferIsEmpty(child.id)) { val length = slicedData.readableBytes() windowSize.addAndGet(-length) val frame = YamuxFrame(child.id, YamuxType.DATA, 0, length.toLong(), slicedData) @@ -212,6 +216,11 @@ open class YamuxHandler( } } + private fun sendBufferIsEmpty(id: MuxId): Boolean { + val sendBuffer = sendBuffers[id] ?: return true + return sendBuffer.isEmpty() + } + override fun onLocalOpen(child: MuxChannel) { onStreamCreate(child.id) getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0)) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index f6d203929..7f66d1c6f 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -235,6 +235,56 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(frame.flags).isEqualTo(YamuxFlags.RST) } + @Test + fun `frames are sent in order when send buffer is used`() { + val handler = openStreamLocal() + val streamId = readFrameOrThrow().streamId + + val createMessage: (String) -> ByteBuf = + { it.toByteArray().toByteBuf(allocateBuf()) } + + val sendWindowUpdate: (Long) -> Unit = { + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + it + ) + ) + } + + // approximately every 5 messages window size will be depleted + val messagesToSend = 500 + val customWindowSize = 14 + sendWindowUpdate(-INITIAL_WINDOW_SIZE.toLong() + customWindowSize) + + val range = 1..messagesToSend + + // 100 window updates should be sent to ensure buffer is flushed and all messages are sent so will send them at random times + val windowUpdatesIndices = (range).shuffled().take(100).toSet() + + for (i in range) { + if (i in windowUpdatesIndices) { + sendWindowUpdate(customWindowSize.toLong()) + } + handler.ctx.writeAndFlush(createMessage(i.toString())) + } + + // verify the order of messages + for (i in range) { + val frame = readYamuxFrame() + assertThat(frame).overridingErrorMessage( + "Expected to send %s messages but it sent only %s", + messagesToSend, + messagesToSend - i + ).isNotNull() + assertThat(frame!!.data).isNotNull() + val data = String(frame.data!!.readAllBytesAndRelease()) + assertThat(data).isEqualTo(i.toString()) + } + } + @Test fun `test ping`() { val id: Long = YamuxId.SESSION_STREAM_ID From 48a66ab7c994280162bea5f63887c1792d6e3d72 Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Thu, 28 Sep 2023 12:47:50 +0800 Subject: [PATCH 02/13] Fix unit test consistency (#333) --- .../test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 7f66d1c6f..098df2f47 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -261,8 +261,11 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val range = 1..messagesToSend - // 100 window updates should be sent to ensure buffer is flushed and all messages are sent so will send them at random times - val windowUpdatesIndices = (range).shuffled().take(100).toSet() + // 100 window updates should be sent to ensure buffer is flushed and all messages are sent + // so will send them at random times ensuring maxBufferedConnectionWrites can never be reached + val windowUpdatesIndices = (range).chunked(100).flatMap { + it.shuffled().take(20) + } for (i in range) { if (i in windowUpdatesIndices) { From b902c1e5a3c25ebe0d33a650802efd811c84230b Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Thu, 5 Oct 2023 20:17:04 +0300 Subject: [PATCH 03/13] Fix .gitattributes (#335) * Change the eol attribute to CRLF for *.bat files * Treat *.png as binary (I believe these are the only binary files at the moment) * For some reason the gradle.bat needs to be recommitted (no tool detect any changes in it) --- .gitattributes | 4 +- gradlew.bat | 178 ++++++++++++++++++++++++------------------------- 2 files changed, 92 insertions(+), 90 deletions(-) diff --git a/.gitattributes b/.gitattributes index 07764a78d..e63673cab 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,3 @@ -* text eol=lf \ No newline at end of file +* text eol=lf +*.bat text eol=crlf +*.png binary diff --git a/gradlew.bat b/gradlew.bat index ac1b06f93..107acd32c 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,89 +1,89 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. -@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega From 3ba6272d25afce47b4da8245fc649573cbd24017 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Mon, 9 Oct 2023 10:36:36 +0300 Subject: [PATCH 04/13] Refactor YamuxHandler (#326) * Refactor YamuxHandler: * Create YamuxStreamHandler inner class * incapsulate all stream related operations there * Add Muxer specific exceptions * Extract writeFrame() method * Group class properties * Adjust exception message. Fix the test --- .../kotlin/io/libp2p/mux/MuxerException.kt | 15 + .../io/libp2p/mux/yamux/YamuxHandler.kt | 345 ++++++++++-------- .../io/libp2p/mux/MuxHandlerAbstractTest.kt | 10 +- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 8 +- 4 files changed, 215 insertions(+), 163 deletions(-) create mode 100644 libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt new file mode 100644 index 000000000..1ba4eaa1b --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt @@ -0,0 +1,15 @@ +package io.libp2p.mux + +import io.libp2p.core.Libp2pException +import io.libp2p.etc.util.netty.mux.MuxId + +open class MuxerException(message: String, ex: Exception?) : Libp2pException(message, ex) + +open class ReadMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) +open class WriteMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) + +class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream with id $muxId not found", null) + +class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null) + +class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 6bf1cfa24..bc3daa787 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -8,7 +8,10 @@ import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.InvalidFrameMuxerException import io.libp2p.mux.MuxHandler +import io.libp2p.mux.UnknownStreamIdMuxerException +import io.libp2p.mux.WriteBufferOverflowMuxerException import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture @@ -24,23 +27,115 @@ open class YamuxHandler( ready: CompletableFuture?, inboundStreamHandler: StreamHandler<*>, private val connectionInitiator: Boolean, - private val maxBufferedConnectionWrites: Int + private val maxBufferedConnectionWrites: Int, + private val initialWindowSize: Int = INITIAL_WINDOW_SIZE ) : MuxHandler(ready, inboundStreamHandler) { - private val idGenerator = YamuxStreamIdGenerator(connectionInitiator) - private val windowSizes = ConcurrentHashMap() - private val sendBuffers = ConcurrentHashMap() - private data class WindowSize(val send: AtomicInteger, val receive: AtomicInteger) + private inner class YamuxStreamHandler( + val id: MuxId + ) { + val sendWindowSize = AtomicInteger(initialWindowSize) + val receiveWindowSize = AtomicInteger(initialWindowSize) + val sendBuffer = SendBuffer(id) - /** - * Would contain GoAway error code when received, or would be completed with [ConnectionClosedException] - * when the connection closed without GoAway message - */ - val goAwayPromise = CompletableFuture() + fun dispose() { + sendBuffer.dispose() + } + + fun handleDataRead(msg: YamuxFrame) { + handleFlags(msg) + + val size = msg.length.toInt() + if (size == 0) { + return + } + + val newWindow = receiveWindowSize.addAndGet(-size) + // send a window update frame once half of the window is depleted + if (newWindow < initialWindowSize / 2) { + val delta = initialWindowSize - newWindow + receiveWindowSize.addAndGet(delta) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) + } + childRead(msg.id, msg.data!!) + } + + fun handleWindowUpdate(msg: YamuxFrame) { + handleFlags(msg) + + val delta = msg.length.toInt() + if (delta == 0) { + return + } + + sendWindowSize.addAndGet(delta) + // try to send any buffered messages after the window update + sendBuffer.flush(sendWindowSize) + } + + private fun handleFlags(msg: YamuxFrame) { + when (msg.flags) { + YamuxFlags.SYN -> { + // ACK the new stream + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + } + + YamuxFlags.FIN -> onRemoteDisconnect(msg.id) + YamuxFlags.RST -> onRemoteClose(msg.id) + } + } + + fun sendData( + data: ByteBuf + ) { + data.sliceMaxSize(maxFrameDataLength) + .forEach { slicedData -> + if (sendWindowSize.get() > 0 && sendBuffer.isEmpty()) { + val length = slicedData.readableBytes() + sendWindowSize.addAndGet(-length) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) + } else { + // wait until the window is increased to send + addToSendBuffer(data) + } + } + } + + private fun addToSendBuffer(data: ByteBuf) { + sendBuffer.add(data) + val totalBufferedWrites = calculateTotalBufferedWrites() + if (totalBufferedWrites > maxBufferedConnectionWrites) { + onLocalClose() + throw WriteBufferOverflowMuxerException( + "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites). Last stream attempting to write: $id" + ) + } + } + + fun onLocalOpen() { + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0)) + } + + fun onRemoteOpen() { + // nothing + } + + fun onLocalDisconnect() { + // TODO: this implementation drops remaining data + sendBuffer.flush(sendWindowSize) + sendBuffer.dispose() + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + } + + fun onLocalClose() { + // close stream immediately so not transferring buffered data + sendBuffer.dispose() + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.RST, 0)) + } + } private inner class SendBuffer(val id: MuxId) { private val bufferedData = ArrayDeque() - private val ctx = getChannelHandlerContext() fun add(data: ByteBuf) { bufferedData.add(data) @@ -51,29 +146,47 @@ open class YamuxHandler( val data = bufferedData.removeFirst() val length = data.readableBytes() windowSize.addAndGet(-length) - val frame = YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data) - ctx.writeAndFlush(frame) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data)) } } - fun bufferedBytes(): Int { - return bufferedData.sumOf { it.readableBytes() } - } - fun isEmpty(): Boolean { return bufferedData.isEmpty() } - fun close() { + fun bufferedBytes(): Int { + return bufferedData.sumOf { it.readableBytes() } + } + + fun dispose() { bufferedData.forEach { releaseMessage(it) } bufferedData.clear() } } + private val idGenerator = YamuxStreamIdGenerator(connectionInitiator) + + private val streamHandlers: MutableMap = ConcurrentHashMap() + + /** + * Would contain GoAway error code when received, or would be completed with [ConnectionClosedException] + * when the connection closed without GoAway message + */ + val goAwayPromise = CompletableFuture() + + private fun getStreamHandlerOrThrow(id: MuxId): YamuxStreamHandler = getStreamHandlerOrReleaseAndThrow(id, null) + + private fun getStreamHandlerOrReleaseAndThrow(id: MuxId, msgToRelease: ByteBuf?): YamuxStreamHandler = + streamHandlers[id] ?: run { + if (msgToRelease != null) { + releaseMessage(msgToRelease) + } + throw UnknownStreamIdMuxerException(id) + } + override fun channelUnregistered(ctx: ChannelHandlerContext?) { - windowSizes.clear() - sendBuffers.values.forEach { it.close() } - sendBuffers.clear() + streamHandlers.values.forEach { it.dispose() } + if (!goAwayPromise.isDone) { goAwayPromise.completeExceptionally(ConnectionClosedException("Connection was closed without Go Away message")) } @@ -82,180 +195,100 @@ open class YamuxHandler( override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { msg as YamuxFrame + when (msg.type) { - YamuxType.DATA -> handleDataRead(msg) - YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) YamuxType.PING -> handlePing(msg) YamuxType.GO_AWAY -> handleGoAway(msg) - } - } - - private fun handlePing(msg: YamuxFrame) { - val ctx = getChannelHandlerContext() - when (msg.flags) { - YamuxFlags.SYN -> ctx.writeAndFlush( - YamuxFrame( - YamuxId.sessionId(msg.id.parentId), - YamuxType.PING, - YamuxFlags.ACK, - msg.length - ) - ) + else -> { + if (msg.flags == YamuxFlags.SYN) { + // remote opens a new stream + validateSynRemoteMuxId(msg.id) + onRemoteYamuxOpen(msg.id) + } - YamuxFlags.ACK -> {} + val streamHandler = getStreamHandlerOrReleaseAndThrow(msg.id, msg.data) + when (msg.type) { + YamuxType.DATA -> streamHandler.handleDataRead(msg) + YamuxType.WINDOW_UPDATE -> streamHandler.handleWindowUpdate(msg) + } + } } } - private fun handleDataRead(msg: YamuxFrame) { - handleFlags(msg) - val size = msg.length.toInt() - if (size == 0) { - return - } - val windowSize = windowSizes[msg.id]?.receive - if (windowSize == null) { - releaseMessage(msg.data!!) - throw Libp2pException("Unable to retrieve receive window size for ${msg.id}") - } - - val newWindow = windowSize.addAndGet(-size) - // send a window update frame once half of the window is depleted - if (newWindow < INITIAL_WINDOW_SIZE / 2) { - val delta = INITIAL_WINDOW_SIZE - newWindow - windowSize.addAndGet(delta) - val frame = YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong()) - getChannelHandlerContext().writeAndFlush(frame) - } - childRead(msg.id, msg.data!!) + private fun writeAndFlushFrame(yamuxFrame: YamuxFrame) { + getChannelHandlerContext().writeAndFlush(yamuxFrame) } - private fun handleWindowUpdate(msg: YamuxFrame) { - handleFlags(msg) - val delta = msg.length.toInt() - if (delta == 0) { - return - } - val windowSize = - windowSizes[msg.id]?.send ?: throw Libp2pException("Unable to retrieve send window size for ${msg.id}") - windowSize.addAndGet(delta) - // try to send any buffered messages after the window update - sendBuffers[msg.id]?.flush(windowSize) + private fun abruptlyCloseConnection() { + getChannelHandlerContext().close() } private fun validateSynRemoteMuxId(id: MuxId) { val isRemoteConnectionInitiator = !connectionInitiator if (!YamuxStreamIdGenerator.isRemoteSynStreamIdValid(isRemoteConnectionInitiator, id.id)) { - getChannelHandlerContext().close() + abruptlyCloseConnection() throw Libp2pException("Invalid remote SYN StreamID: $id, isRemoteInitiator: $isRemoteConnectionInitiator") } } - private fun handleFlags(msg: YamuxFrame) { - val ctx = getChannelHandlerContext() - when (msg.flags) { - YamuxFlags.SYN -> { - validateSynRemoteMuxId(msg.id) - onRemoteYamuxOpen(msg.id) - // ACK the new stream - ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) - } - - YamuxFlags.FIN -> onRemoteDisconnect(msg.id) - YamuxFlags.RST -> onRemoteClose(msg.id) - } - } - - private fun handleGoAway(msg: YamuxFrame) { - goAwayPromise.complete(msg.length) - } - override fun onChildWrite(child: MuxChannel, data: ByteBuf) { - val windowSize = windowSizes[child.id]?.send - if (windowSize == null) { - releaseMessage(data) - throw Libp2pException("Unable to retrieve send window size for ${child.id}") - } - - sendData(child, windowSize, data) + getStreamHandlerOrReleaseAndThrow(child.id, data).sendData(data) } - private fun calculateTotalBufferedWrites(): Int { - return sendBuffers.values.sumOf { it.bufferedBytes() } + override fun onLocalOpen(child: MuxChannel) { + createYamuxStreamHandler(child.id).onLocalOpen() } - private fun sendData( - child: MuxChannel, - windowSize: AtomicInteger, - data: ByteBuf - ) { - data.sliceMaxSize(maxFrameDataLength) - .forEach { slicedData -> - if (windowSize.get() > 0 && sendBufferIsEmpty(child.id)) { - val length = slicedData.readableBytes() - windowSize.addAndGet(-length) - val frame = YamuxFrame(child.id, YamuxType.DATA, 0, length.toLong(), slicedData) - getChannelHandlerContext().writeAndFlush(frame) - } else { - // wait until the window is increased to send - addToSendBuffer(child, data) - } - } + private fun onRemoteYamuxOpen(id: MuxId) { + createYamuxStreamHandler(id).onRemoteOpen() + onRemoteOpen(id) } - private fun addToSendBuffer(child: MuxChannel, data: ByteBuf) { - val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id) } - buffer.add(data) - val totalBufferedWrites = calculateTotalBufferedWrites() - if (totalBufferedWrites > maxBufferedConnectionWrites) { - onLocalClose(child) - throw Libp2pException( - "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites) for connection ${ - getChannelHandlerContext().channel().id().asLongText() - }" - ) - } + private fun createYamuxStreamHandler(id: MuxId): YamuxStreamHandler { + val streamHandler = YamuxStreamHandler(id) + streamHandlers[id] = streamHandler + return streamHandler } - private fun sendBufferIsEmpty(id: MuxId): Boolean { - val sendBuffer = sendBuffers[id] ?: return true - return sendBuffer.isEmpty() + override fun onLocalDisconnect(child: MuxChannel) { + getStreamHandlerOrThrow(child.id).onLocalDisconnect() } - override fun onLocalOpen(child: MuxChannel) { - onStreamCreate(child.id) - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0)) + override fun onLocalClose(child: MuxChannel) { + streamHandlers.remove(child.id)?.onLocalClose() } - private fun onRemoteYamuxOpen(id: MuxId) { - onStreamCreate(id) - onRemoteOpen(id) + override fun onChildClosed(child: MuxChannel) { + streamHandlers.remove(child.id)?.dispose() } - private fun onStreamCreate(id: MuxId) { - windowSizes.putIfAbsent(id, WindowSize(AtomicInteger(INITIAL_WINDOW_SIZE), AtomicInteger(INITIAL_WINDOW_SIZE))) + private fun calculateTotalBufferedWrites(): Int { + return streamHandlers.values.sumOf { it.sendBuffer.bufferedBytes() } } - override fun onLocalDisconnect(child: MuxChannel) { - // transfer buffered data before sending FIN - val windowSize = windowSizes[child.id]?.send - val sendBuffer = sendBuffers.remove(child.id) - if (windowSize != null && sendBuffer != null) { - sendBuffer.flush(windowSize) - sendBuffer.close() + private fun handlePing(msg: YamuxFrame) { + if (msg.id.id != YamuxId.SESSION_STREAM_ID) { + throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}") } - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0)) - } + when (msg.flags) { + YamuxFlags.SYN -> writeAndFlushFrame( + YamuxFrame( + YamuxId.sessionId(msg.id.parentId), + YamuxType.PING, + YamuxFlags.ACK, + msg.length + ) + ) - override fun onLocalClose(child: MuxChannel) { - // close stream immediately so not transferring buffered data - windowSizes.remove(child.id) - sendBuffers.remove(child.id)?.close() - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) + YamuxFlags.ACK -> {} + } } - override fun onChildClosed(child: MuxChannel) { - windowSizes.remove(child.id) - sendBuffers.remove(child.id)?.close() + private fun handleGoAway(msg: YamuxFrame) { + if (msg.id.id != YamuxId.SESSION_STREAM_ID) { + throw InvalidFrameMuxerException("Invalid StreamId for GoAway frame type: ${msg.id}") + } + goAwayPromise.complete(msg.length) } override fun generateNextId() = diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index a1f7403a1..bb0f21313 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -310,9 +310,15 @@ abstract class MuxHandlerAbstractTest { } @Test - fun canResetNonExistentStream() { - resetStream(99) + @SuppressWarnings("SwallowedException") + fun `resetting non existing stream doesnt close connection`() { + try { + resetStream(99) + } catch (e: UnknownStreamIdMuxerException) { + // Muxer is free to either throw an exception or just ignore + } assertHandlerCount(0) + assertThat(ech.isOpen).isTrue() } @Test diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 098df2f47..a2b13d637 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -205,11 +205,11 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { @Test fun `overflowing buffer sends RST flag and throws an exception`() { val handler = openStreamLocal() - val streamId = readFrameOrThrow().streamId + val muxId = readFrameOrThrow().streamId.toMuxId() ech.writeInbound( YamuxFrame( - streamId.toMuxId(), + muxId, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, -INITIAL_WINDOW_SIZE.toLong() @@ -229,7 +229,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(writeResult.isSuccess).isFalse() assertThat(writeResult.cause()) .isInstanceOf(Libp2pException::class.java) - .hasMessage("Overflowed send buffer (612/512) for connection test") + .hasMessage("Overflowed send buffer (612/512). Last stream attempting to write: $muxId") val frame = readYamuxFrameOrThrow() assertThat(frame.flags).isEqualTo(YamuxFlags.RST) @@ -306,8 +306,6 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(pingFrame.flags).isEqualTo(YamuxFlags.ACK) assertThat(pingFrame.type).isEqualTo(YamuxType.PING) assertThat(pingFrame.length).isEqualTo(3) - - closeStream(id) } @Test From 26efe02cad1d75cc74f9965aa4921def099eb7f3 Mon Sep 17 00:00:00 2001 From: diegomrsantos Date: Mon, 9 Oct 2023 15:49:50 +0200 Subject: [PATCH 05/13] Change maxPrunePeers and maxPeersPerPruneMessage usage (#336) * Use maxPrunePeers to limit the amount of peers in PX * Use maxPeersPerPruneMessage instead of maxPrunePeers to limit the amount of peers processed in PX * Test for maxPeersSentInPruneMsg * Add testMaxPeersAcceptedInPruneMsg test --------- Co-authored-by: Anton Nashatyrev Co-authored-by: Anton Nashatyrev --- .../io/libp2p/pubsub/gossip/GossipParams.kt | 10 +-- .../io/libp2p/pubsub/gossip/GossipRouter.kt | 5 +- .../gossip/builders/GossipParamsBuilder.kt | 18 ++--- .../gossip/GossipRouterListLimitsTest.kt | 14 ++-- .../libp2p/pubsub/gossip/GossipV1_1Tests.kt | 74 ++++++++++++++++++- 5 files changed, 97 insertions(+), 24 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt index 6823ecc4e..e63e780dc 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipParams.kt @@ -193,17 +193,17 @@ data class GossipParams( val maxGraftMessages: Int? = null, /** - * [maxPrunePeers] controls the number of peers to include in prune Peer eXchange. + * [maxPeersSentInPruneMsg] controls the number of peers to include in prune Peer eXchange. * When we prune a peer that's eligible for PX (has a good score, etc), we will try to - * send them signed peer records for up to [maxPrunePeers] other peers that we + * send them signed peer records for up to [maxPeersSentInPruneMsg] other peers that we * know of. */ - val maxPrunePeers: Int = 16, + val maxPeersSentInPruneMsg: Int = 16, /** - * [maxPeersPerPruneMessage] is the maximum number of peers allowed in an incoming prune message + * [maxPeersAcceptedInPruneMsg] is the maximum number of peers allowed in an incoming prune message */ - val maxPeersPerPruneMessage: Int? = null, + val maxPeersAcceptedInPruneMsg: Int = 16, /** * [pruneBackoff] controls the backoff time for pruned peers. This is how long diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index c09f0a67e..b1de2bd05 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -260,7 +260,7 @@ open class GossipRouter( params.maxIWantMessageIds?.let { iWantMessageIdCount <= it } ?: true && params.maxGraftMessages?.let { (msg.control?.graftCount ?: 0) <= it } ?: true && params.maxPruneMessages?.let { (msg.control?.pruneCount ?: 0) <= it } ?: true && - params.maxPeersPerPruneMessage?.let { msg.control?.pruneList?.none { p -> p.peersCount > it } } ?: true + params.maxPeersAcceptedInPruneMsg.let { msg.control?.pruneList?.none { p -> p.peersCount > it } } ?: true } private fun processControlMessage(controlMsg: Any, receivedFrom: PeerHandler) { @@ -349,7 +349,7 @@ open class GossipRouter( } private fun processPrunePeers(peersList: List) { - peersList.shuffled(random).take(params.maxPrunePeers) + peersList.shuffled(random).take(params.maxPeersAcceptedInPruneMsg) .map { PeerId(it.peerID.toByteArray()) to it.signedPeerRecord.toByteArray() } .filter { (id, _) -> !isConnected(id) } .forEach { (id, record) -> params.connectCallback(id, record) } @@ -572,6 +572,7 @@ open class GossipRouter( val peerQueue = pendingRpcParts.getQueue(peer) if (peer.getPeerProtocol() == PubsubProtocol.Gossip_V_1_1 && this.protocol == PubsubProtocol.Gossip_V_1_1) { val backoffPeers = (getTopicPeers(topic) - peer) + .take(params.maxPeersSentInPruneMsg) .filter { score.score(it.peerId) >= 0 } .map { it.peerId } peerQueue.addPrune(topic, params.pruneBackoff.seconds, backoffPeers) diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt index ed272ced7..7e9c9ebfe 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipParamsBuilder.kt @@ -34,9 +34,9 @@ class GossipParamsBuilder { private var seenTTL: Duration? = null - private var maxPrunePeers: Int? = null + private var maxPeersSentInPruneMsg: Int? = null - private var maxPeersPerPruneMessage: Int? = null + private var maxPeersAcceptedInPruneMsg: Int? = null private var pruneBackoff: Duration? = null @@ -81,8 +81,8 @@ class GossipParamsBuilder { this.gossipHistoryLength = source.gossipHistoryLength this.heartbeatInterval = source.heartbeatInterval this.seenTTL = source.seenTTL - this.maxPrunePeers = source.maxPrunePeers - this.maxPeersPerPruneMessage = source.maxPeersPerPruneMessage + this.maxPeersSentInPruneMsg = source.maxPeersSentInPruneMsg + this.maxPeersAcceptedInPruneMsg = source.maxPeersAcceptedInPruneMsg this.pruneBackoff = source.pruneBackoff this.floodPublish = source.floodPublish this.gossipFactor = source.gossipFactor @@ -126,9 +126,9 @@ class GossipParamsBuilder { fun seenTTL(value: Duration): GossipParamsBuilder = apply { seenTTL = value } - fun maxPrunePeers(value: Int): GossipParamsBuilder = apply { maxPrunePeers = value } + fun maxPeersSentInPruneMsg(value: Int): GossipParamsBuilder = apply { maxPeersSentInPruneMsg = value } - fun maxPeersPerPruneMessage(value: Int): GossipParamsBuilder = apply { maxPeersPerPruneMessage = value } + fun maxPeersAcceptedInPruneMsg(value: Int): GossipParamsBuilder = apply { maxPeersAcceptedInPruneMsg = value } fun pruneBackoff(value: Duration): GossipParamsBuilder = apply { pruneBackoff = value } @@ -201,8 +201,8 @@ class GossipParamsBuilder { maxIWantMessageIds = maxIWantMessageIds, iWantFollowupTime = iWantFollowupTime!!, maxGraftMessages = maxGraftMessages, - maxPrunePeers = maxPrunePeers!!, - maxPeersPerPruneMessage = maxPeersPerPruneMessage, + maxPeersSentInPruneMsg = maxPeersSentInPruneMsg!!, + maxPeersAcceptedInPruneMsg = maxPeersAcceptedInPruneMsg!!, pruneBackoff = pruneBackoff!!, maxPruneMessages = maxPruneMessages, gossipRetransmission = gossipRetransmission!!, @@ -232,7 +232,7 @@ class GossipParamsBuilder { check(gossipHistoryLength != null, { "gossipHistoryLength must not be null" }) check(heartbeatInterval != null, { "heartbeatInterval must not be null" }) check(seenTTL != null, { "seenTTL must not be null" }) - check(maxPrunePeers != null, { "maxPrunePeers must not be null" }) + check(maxPeersSentInPruneMsg != null, { "maxPeersSentInPruneMsg must not be null" }) check(pruneBackoff != null, { "pruneBackoff must not be null" }) check(floodPublish != null, { "floodPublish must not be null" }) check(gossipFactor != null, { "gossipFactor must not be null" }) diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt index d29ec1e53..6942cc979 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterListLimitsTest.kt @@ -15,7 +15,7 @@ class GossipRouterListLimitsTest { private val maxIWantMessageIds = 14 private val maxGraftMessages = 15 private val maxPruneMessages = 16 - private val maxPeersPerPruneMessage = 17 + private val maxPeersAcceptedInPruneMsg = 17 private val gossipParamsWithLimits = GossipParamsBuilder() .maxPublishedMessages(maxPublishedMessages) @@ -25,7 +25,7 @@ class GossipRouterListLimitsTest { .maxIWantMessageIds(maxIWantMessageIds) .maxGraftMessages(maxGraftMessages) .maxPruneMessages(maxPruneMessages) - .maxPeersPerPruneMessage(maxPeersPerPruneMessage) + .maxPeersAcceptedInPruneMsg(maxPeersAcceptedInPruneMsg) .build() private val gossipParamsNoLimits = GossipParamsBuilder() @@ -44,7 +44,7 @@ class GossipRouterListLimitsTest { @Test fun validateProtobufLists_validMessageWithLargeLists_noLimits() { - val msg = fullMsgBuilder(20).build() + val msg = fullMsgBuilder(16).build() Assertions.assertThat(routerWithNoLimits.validateMessageListLimits(msg)).isTrue() } @@ -148,9 +148,9 @@ class GossipRouterListLimitsTest { } @Test - fun validateProtobufLists_tooManyPrunePeers() { + fun validateProtobufLists_tooManyPeersToAcceptInPruneMsg() { val builder = fullMsgBuilder() - builder.addPrunes(1, maxPeersPerPruneMessage + 1) + builder.addPrunes(1, maxPeersAcceptedInPruneMsg + 1) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isFalse() @@ -238,9 +238,9 @@ class GossipRouterListLimitsTest { } @Test - fun validateProtobufLists_maxPrunePeers() { + fun validateProtobufLists_maxPeersAcceptedInPruneMsg() { val builder = fullMsgBuilder() - builder.addPrunes(1, maxPeersPerPruneMessage - 1) + builder.addPrunes(1, maxPeersAcceptedInPruneMsg - 1) val msg = builder.build() Assertions.assertThat(routerWithLimits.validateMessageListLimits(msg)).isTrue() diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt index 6518d07fd..3750803e4 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt @@ -3,6 +3,7 @@ package io.libp2p.pubsub.gossip import com.google.common.util.concurrent.AtomicDouble +import com.google.protobuf.ByteString import io.libp2p.core.PeerId import io.libp2p.core.pubsub.MessageApi import io.libp2p.core.pubsub.RESULT_IGNORE @@ -102,7 +103,7 @@ class GossipV1_1Tests { class TwoRoutersTest( val coreParams: GossipParams = GossipParams(), val scoreParams: GossipScoreParams = GossipScoreParams(), - mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory() + val mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory() ) { val fuzz = DeterministicFuzz() val gossipRouterBuilderFactory = { GossipRouterBuilder(params = coreParams, scoreParams = scoreParams) } @@ -1140,4 +1141,75 @@ class GossipV1_1Tests { assertEquals(5, iWandIds1.size) assertEquals(5, iWandIds1.distinct().size) } + + @Test + fun testMaxPeersSentInPruneMsg() { + val test = TwoRoutersTest() + + val topic = "topic1" + test.mockRouter.subscribe(topic) + test.gossipRouter.subscribe(topic) + + for (i in 0..20) { + val router = test.fuzz.createTestRouter(test.mockRouterFactory) + (router.router as MockRouter).subscribe(topic) + test.router1.connectSemiDuplex(router, null, LogLevel.ERROR) + } + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + test.mockRouter.waitForMessage { it.hasControl() && it.control.graftCount > 0 } + + test.gossipRouter.unsubscribe(topic) + test.fuzz.timeController.addTime(2.seconds) + assertEquals( + 1, + test.mockRouter.inboundMessages.count { + it.hasControl() && it.control.pruneCount == 1 && + it.control.getPrune(0).peersCount == test.gossipRouter.params.maxPeersSentInPruneMsg + } + ) + } + + @Test + fun testMaxPeersAcceptedInPruneMsg() { + val test = TwoRoutersTest() + val topic = "topic1" + + test.mockRouter.subscribe(topic) + test.gossipRouter.subscribe(topic) + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + + fun createPruneMessage(peersCount: Int): Rpc.RPC { + val peerInfos = List(peersCount) { + Rpc.PeerInfo.newBuilder() + .setPeerID(PeerId.random().bytes.toProtobuf()) + .setSignedPeerRecord(ByteString.EMPTY) + .build() + } + return Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addPrune( + Rpc.ControlPrune.newBuilder() + .setTopicID(topic) + .addAllPeers(peerInfos) + ) + ).build() + } + + test.mockRouter.sendToSingle( + createPruneMessage(test.gossipRouter.params.maxPeersAcceptedInPruneMsg + 1) + ) + + // prune message should be dropped because too many peers + assertEquals(1, test.gossipRouter.mesh[topic]!!.size) + + test.mockRouter.sendToSingle( + createPruneMessage(test.gossipRouter.params.maxPeersAcceptedInPruneMsg) + ) + + // prune message should now be processed + assertEquals(0, test.gossipRouter.mesh[topic]!!.size) + } } From 7dc2fa287bb466b3ac3dda18e45707a28df956f5 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Tue, 10 Oct 2023 15:06:28 +0300 Subject: [PATCH 06/13] Refactor YamuxHandler.SendBuffer (#328) * Introduce ByteBufQueue * Add ByteBufQueue tests * Writing exec path is always through fill/drain buffer * Adopt/fix existing tests * Add new test checking correct handling of negative sendWindowSize --- .../io/libp2p/etc/util/netty/ByteBufQueue.kt | 44 +++++ .../io/libp2p/mux/yamux/YamuxHandler.kt | 100 ++++------- .../libp2p/etc/util/netty/ByteBufQueueTest.kt | 167 ++++++++++++++++++ .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 104 ++++++++--- 4 files changed, 328 insertions(+), 87 deletions(-) create mode 100644 libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt create mode 100644 libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt new file mode 100644 index 000000000..1ddef99a2 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt @@ -0,0 +1,44 @@ +package io.libp2p.etc.util.netty + +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled + +class ByteBufQueue { + private val data: MutableList = mutableListOf() + + fun push(buf: ByteBuf) { + data += buf + } + + fun take(maxLength: Int): ByteBuf { + val wholeBuffers = mutableListOf() + var size = 0 + while (data.isNotEmpty()) { + val bufLen = data.first().readableBytes() + if (size + bufLen > maxLength) break + size += bufLen + wholeBuffers += data.removeFirst() + if (size == maxLength) break + } + + val partialBufferSlice = + when { + data.isEmpty() -> null + size == maxLength -> null + else -> data.first() + } + ?.let { buf -> + val remainingBytes = maxLength - size + buf.readRetainedSlice(remainingBytes) + } + + val allBuffers = wholeBuffers + listOfNotNull(partialBufferSlice) + return Unpooled.wrappedBuffer(*allBuffers.toTypedArray()) + } + + fun dispose() { + data.forEach { it.release() } + } + + fun readableBytes(): Int = data.sumOf { it.readableBytes() } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index bc3daa787..ece12d4e7 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -6,6 +6,7 @@ import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.util.netty.ByteBufQueue import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.InvalidFrameMuxerException @@ -17,6 +18,7 @@ import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger +import kotlin.math.max const val INITIAL_WINDOW_SIZE = 256 * 1024 const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB @@ -36,15 +38,21 @@ open class YamuxHandler( ) { val sendWindowSize = AtomicInteger(initialWindowSize) val receiveWindowSize = AtomicInteger(initialWindowSize) - val sendBuffer = SendBuffer(id) + val sendBuffer = ByteBufQueue() fun dispose() { sendBuffer.dispose() } - fun handleDataRead(msg: YamuxFrame) { + fun handleFrameRead(msg: YamuxFrame) { handleFlags(msg) + when (msg.type) { + YamuxType.DATA -> handleDataRead(msg) + YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) + } + } + private fun handleDataRead(msg: YamuxFrame) { val size = msg.length.toInt() if (size == 0) { return @@ -60,17 +68,11 @@ open class YamuxHandler( childRead(msg.id, msg.data!!) } - fun handleWindowUpdate(msg: YamuxFrame) { - handleFlags(msg) - + private fun handleWindowUpdate(msg: YamuxFrame) { val delta = msg.length.toInt() - if (delta == 0) { - return - } - sendWindowSize.addAndGet(delta) // try to send any buffered messages after the window update - sendBuffer.flush(sendWindowSize) + drainBuffer() } private fun handleFlags(msg: YamuxFrame) { @@ -85,26 +87,10 @@ open class YamuxHandler( } } - fun sendData( - data: ByteBuf - ) { - data.sliceMaxSize(maxFrameDataLength) - .forEach { slicedData -> - if (sendWindowSize.get() > 0 && sendBuffer.isEmpty()) { - val length = slicedData.readableBytes() - sendWindowSize.addAndGet(-length) - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) - } else { - // wait until the window is increased to send - addToSendBuffer(data) - } - } - } - - private fun addToSendBuffer(data: ByteBuf) { - sendBuffer.add(data) + private fun fillBuffer(data: ByteBuf) { + sendBuffer.push(data) val totalBufferedWrites = calculateTotalBufferedWrites() - if (totalBufferedWrites > maxBufferedConnectionWrites) { + if (totalBufferedWrites > maxBufferedConnectionWrites + sendWindowSize.get()) { onLocalClose() throw WriteBufferOverflowMuxerException( "Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites). Last stream attempting to write: $id" @@ -112,6 +98,22 @@ open class YamuxHandler( } } + private fun drainBuffer() { + val maxSendLength = max(0, sendWindowSize.get()) + val data = sendBuffer.take(maxSendLength) + sendWindowSize.addAndGet(-data.readableBytes()) + data.sliceMaxSize(maxFrameDataLength) + .forEach { slicedData -> + val length = slicedData.readableBytes() + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) + } + } + + fun sendData(data: ByteBuf) { + fillBuffer(data) + drainBuffer() + } + fun onLocalOpen() { writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0)) } @@ -122,7 +124,7 @@ open class YamuxHandler( fun onLocalDisconnect() { // TODO: this implementation drops remaining data - sendBuffer.flush(sendWindowSize) + drainBuffer() sendBuffer.dispose() writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) } @@ -134,36 +136,6 @@ open class YamuxHandler( } } - private inner class SendBuffer(val id: MuxId) { - private val bufferedData = ArrayDeque() - - fun add(data: ByteBuf) { - bufferedData.add(data) - } - - fun flush(windowSize: AtomicInteger) { - while (!isEmpty() && windowSize.get() > 0) { - val data = bufferedData.removeFirst() - val length = data.readableBytes() - windowSize.addAndGet(-length) - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data)) - } - } - - fun isEmpty(): Boolean { - return bufferedData.isEmpty() - } - - fun bufferedBytes(): Int { - return bufferedData.sumOf { it.readableBytes() } - } - - fun dispose() { - bufferedData.forEach { releaseMessage(it) } - bufferedData.clear() - } - } - private val idGenerator = YamuxStreamIdGenerator(connectionInitiator) private val streamHandlers: MutableMap = ConcurrentHashMap() @@ -206,11 +178,7 @@ open class YamuxHandler( onRemoteYamuxOpen(msg.id) } - val streamHandler = getStreamHandlerOrReleaseAndThrow(msg.id, msg.data) - when (msg.type) { - YamuxType.DATA -> streamHandler.handleDataRead(msg) - YamuxType.WINDOW_UPDATE -> streamHandler.handleWindowUpdate(msg) - } + getStreamHandlerOrReleaseAndThrow(msg.id, msg.data).handleFrameRead(msg) } } } @@ -263,7 +231,7 @@ open class YamuxHandler( } private fun calculateTotalBufferedWrites(): Int { - return streamHandlers.values.sumOf { it.sendBuffer.bufferedBytes() } + return streamHandlers.values.sumOf { it.sendBuffer.readableBytes() } } private fun handlePing(msg: YamuxFrame) { diff --git a/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt b/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt new file mode 100644 index 000000000..848d783de --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt @@ -0,0 +1,167 @@ +package io.libp2p.etc.util.netty + +import io.libp2p.tools.readAllBytesAndRelease +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Test + +class ByteBufQueueTest { + + val queue = ByteBufQueue() + + val allocatedBufs = mutableListOf() + + @AfterEach + fun cleanUpAndCheck() { + allocatedBufs.forEach { + assertThat(it.refCnt()).isEqualTo(1) + } + } + + fun allocateBuf(): ByteBuf { + val buf = Unpooled.buffer() + buf.retain() // ref counter to 2 to check that exactly 1 ref remains at the end + allocatedBufs += buf + return buf + } + + fun allocateData(data: String): ByteBuf = + allocateBuf().writeBytes(data.toByteArray()) + + fun ByteBuf.readString() = String(this.readAllBytesAndRelease()) + + @Test + fun emptyTest() { + assertThat(queue.take(100).readString()).isEqualTo("") + } + + @Test + fun zeroTest() { + queue.push(allocateData("abc")) + assertThat(queue.take(0).readString()).isEqualTo("") + assertThat(queue.take(100).readString()).isEqualTo("abc") + } + + @Test + fun emptyZeroTest() { + assertThat(queue.take(0).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest1() { + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest2() { + queue.push(allocateData("")) + assertThat(queue.take(0).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest3() { + queue.push(allocateData("")) + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("a") + } + + @Test + fun emptyBuffersTest4() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(10).readString()).isEqualTo("a") + } + + @Test + fun emptyBuffersTest5() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + assertThat(queue.take(1).readString()).isEqualTo("a") + assertThat(queue.take(1).readString()).isEqualTo("") + } + + @Test + fun emptyBuffersTest6() { + queue.push(allocateData("a")) + queue.push(allocateData("")) + queue.push(allocateData("")) + queue.push(allocateData("b")) + assertThat(queue.take(10).readString()).isEqualTo("ab") + } + + @Test + fun pushTake1() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(4).readString()).isEqualTo("abcd") + assertThat(queue.take(1).readString()).isEqualTo("e") + assertThat(queue.take(100).readString()).isEqualTo("f") + assertThat(queue.take(100).readString()).isEqualTo("") + } + + @Test + fun pushTake2() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(2).readString()).isEqualTo("ab") + assertThat(queue.take(2).readString()).isEqualTo("cd") + assertThat(queue.take(2).readString()).isEqualTo("ef") + assertThat(queue.take(2).readString()).isEqualTo("") + } + + @Test + fun pushTake3() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + + assertThat(queue.take(1).readString()).isEqualTo("a") + assertThat(queue.take(1).readString()).isEqualTo("b") + assertThat(queue.take(1).readString()).isEqualTo("c") + assertThat(queue.take(1).readString()).isEqualTo("d") + assertThat(queue.take(1).readString()).isEqualTo("e") + assertThat(queue.take(1).readString()).isEqualTo("f") + assertThat(queue.take(1).readString()).isEqualTo("") + } + + @Test + fun pushTakePush1() { + queue.push(allocateData("abc")) + assertThat(queue.take(2).readString()).isEqualTo("ab") + queue.push(allocateData("def")) + assertThat(queue.take(2).readString()).isEqualTo("cd") + assertThat(queue.take(100).readString()).isEqualTo("ef") + } + + @Test + fun pushTakePush2() { + queue.push(allocateData("abc")) + assertThat(queue.take(3).readString()).isEqualTo("abc") + queue.push(allocateData("def")) + assertThat(queue.take(2).readString()).isEqualTo("de") + assertThat(queue.take(100).readString()).isEqualTo("f") + } + + @Test + fun pushTakePush3() { + queue.push(allocateData("abc")) + queue.push(allocateData("def")) + assertThat(queue.take(1).readString()).isEqualTo("a") + queue.push(allocateData("ghi")) + assertThat(queue.take(100).readString()).isEqualTo("bcdefghi") + } + + @Test + fun pushTakePush4() { + queue.push(allocateData("abc")) + assertThat(queue.take(1).readString()).isEqualTo("a") + queue.push(allocateData("def")) + queue.push(allocateData("ghi")) + assertThat(queue.take(100).readString()).isEqualTo("bcdefghi") + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index a2b13d637..5f239f081 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -19,6 +19,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 private val maxBufferedConnectionWrites = 512 + private val initialWindowSize = 300 override val localMuxIdGenerator = YamuxStreamIdGenerator(isLocalConnectionInitiator).toIterator() override val remoteMuxIdGenerator = YamuxStreamIdGenerator(!isLocalConnectionInitiator).toIterator() @@ -32,7 +33,8 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { null, streamHandler, true, - maxBufferedConnectionWrites + maxBufferedConnectionWrites, + initialWindowSize ) { // MuxHandler consumes the exception. Override this behaviour for testing @Deprecated("Deprecated in Java") @@ -112,7 +114,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val streamId = readFrameOrThrow().streamId // > 1/2 window size - val length = (INITIAL_WINDOW_SIZE / 2) + 42 + val length = (initialWindowSize / 2) + 42 ech.writeInbound( YamuxFrame( streamId.toMuxId(), @@ -141,7 +143,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, - -INITIAL_WINDOW_SIZE.toLong() + -initialWindowSize.toLong() ) ) @@ -164,7 +166,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, - -INITIAL_WINDOW_SIZE.toLong() + -initialWindowSize.toLong() ) ) @@ -199,7 +201,18 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { ) ) frame = readFrameOrThrow() - assertThat(frame.data).isEqualTo("1984") + assertThat(frame.data).isEqualTo("19") + + ech.writeInbound( + YamuxFrame( + streamId.toMuxId(), + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + 10000 + ) + ) + frame = readFrameOrThrow() + assertThat(frame.data).isEqualTo("84") } @Test @@ -212,7 +225,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { muxId, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, - -INITIAL_WINDOW_SIZE.toLong() + -initialWindowSize.toLong() ) ) @@ -243,13 +256,13 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val createMessage: (String) -> ByteBuf = { it.toByteArray().toByteBuf(allocateBuf()) } - val sendWindowUpdate: (Long) -> Unit = { + val sendWindowUpdate: (Int) -> Unit = { ech.writeInbound( YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, - it + it.toLong() ) ) } @@ -257,7 +270,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { // approximately every 5 messages window size will be depleted val messagesToSend = 500 val customWindowSize = 14 - sendWindowUpdate(-INITIAL_WINDOW_SIZE.toLong() + customWindowSize) + sendWindowUpdate(-initialWindowSize + customWindowSize) val range = 1..messagesToSend @@ -269,23 +282,23 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { for (i in range) { if (i in windowUpdatesIndices) { - sendWindowUpdate(customWindowSize.toLong()) + sendWindowUpdate(customWindowSize) } handler.ctx.writeAndFlush(createMessage(i.toString())) } - // verify the order of messages - for (i in range) { - val frame = readYamuxFrame() - assertThat(frame).overridingErrorMessage( - "Expected to send %s messages but it sent only %s", - messagesToSend, - messagesToSend - i - ).isNotNull() - assertThat(frame!!.data).isNotNull() - val data = String(frame.data!!.readAllBytesAndRelease()) - assertThat(data).isEqualTo(i.toString()) + val receivedData = generateSequence { + readYamuxFrame() } + .map { + assertThat(it.data).isNotNull() + String(it.data!!.readAllBytesAndRelease()) + } + .joinToString(separator = "") + + val expectedData = range.joinToString(separator = "") + + assertThat(receivedData).isEqualTo(expectedData) } @Test @@ -345,6 +358,55 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(ech.isOpen).isFalse() } + @Test + fun `negative sendWindowSize should be correctly handled`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf()) + // writing a message which is larger than sendWindowSize + handler.ctx.writeAndFlush(msg) + + // sendWindowSize is 0 now + + // remote party wants to reduce the window by 10 + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + -10 + ) + ) + + // sendWindowSize is -10 now + + val msgPart1 = readYamuxFrameOrThrow() + assertThat(msgPart1.length).isEqualTo(256L) + assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256) + msgPart1.data!!.release() + + val msgPart2 = readYamuxFrameOrThrow() + assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256) + assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256) + msgPart2.data!!.release() + + // ACKing message receive + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + initialWindowSize.toLong() + ) + ) + + val msgPart3 = readYamuxFrameOrThrow() + assertThat(msgPart3.length).isEqualTo(1L) + assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1) + msgPart3.data!!.release() + } + companion object { private fun YamuxStreamIdGenerator.toIterator() = iterator { while (true) { From 3ac83c4c3b3dd255ed4226e355eb3b2fbb3dd3a3 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Tue, 10 Oct 2023 16:20:32 +0300 Subject: [PATCH 07/13] Add large blob test (#337) * Add HostTestJava.largeBlob() test * Add some flexibility to the HostBuilder * Add DebugHandlerBuilder.addCompactLogger() which adds a logger which shrinks too long messages Co-authored-by: Dr Ian Preston [ianopolous@protonmail.com](mailto:ianopolous@protonmail.com) --- .../java/io/libp2p/core/dsl/HostBuilder.java | 8 + .../kotlin/io/libp2p/core/dsl/Builders.kt | 5 + .../java/io/libp2p/core/HostTestJava.java | 71 ++++++++- .../test/kotlin/io/libp2p/protocol/Blob.kt | 144 ++++++++++++++++++ 4 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt diff --git a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java index 2915b5b29..6eba6a226 100644 --- a/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java +++ b/libp2p/src/main/java/io/libp2p/core/dsl/HostBuilder.java @@ -62,6 +62,12 @@ public final HostBuilder listen(String... addresses) { return this; } + public final HostBuilder builderModifier(Consumer builderModifier) { + this.builderModifier = builderModifier; + return this; + } + + @SuppressWarnings("unchecked") public Host build() { return BuilderJKt.hostJ( defaultMode_.asBuilderDefault(), @@ -74,6 +80,7 @@ public Host build() { muxers_.forEach(m -> b.getMuxers().add(m.get())); b.getProtocols().addAll(protocols_); listenAddresses_.forEach(a -> b.getNetwork().listen(a)); + builderModifier.accept(b); }); } // build @@ -84,4 +91,5 @@ public Host build() { private List> muxers_ = new ArrayList<>(); private List> protocols_ = new ArrayList<>(); private List listenAddresses_ = new ArrayList<>(); + private Consumer builderModifier = b -> {}; } diff --git a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt index 971d35b07..ce1416dfd 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/dsl/Builders.kt @@ -23,6 +23,7 @@ import io.libp2p.core.security.SecureChannel import io.libp2p.core.transport.Transport import io.libp2p.etc.types.lazyVar import io.libp2p.etc.types.toProtobuf +import io.libp2p.etc.util.netty.LoggingHandlerShort import io.libp2p.host.HostImpl import io.libp2p.host.MemoryAddressBook import io.libp2p.network.NetworkImpl @@ -273,6 +274,10 @@ class DebugHandlerBuilder(var name: String) { fun addLogger(level: LogLevel, loggerName: String = name) { addNettyHandler(LoggingHandler(loggerName, level)) } + + fun addCompactLogger(level: LogLevel, loggerName: String = name) { + addNettyHandler(LoggingHandlerShort(loggerName, level)) + } } open class Enumeration(val values: MutableList = mutableListOf()) : MutableList by values { diff --git a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java index 1f943159c..bd4f509e0 100644 --- a/libp2p/src/test/java/io/libp2p/core/HostTestJava.java +++ b/libp2p/src/test/java/io/libp2p/core/HostTestJava.java @@ -7,10 +7,11 @@ import io.libp2p.core.dsl.HostBuilder; import io.libp2p.core.multiformats.Multiaddr; import io.libp2p.core.mux.StreamMuxerProtocol; -import io.libp2p.protocol.Ping; -import io.libp2p.protocol.PingController; +import io.libp2p.protocol.*; +import io.libp2p.security.noise.*; import io.libp2p.security.tls.*; import io.libp2p.transport.tcp.TcpTransport; +import io.netty.handler.logging.LogLevel; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -153,6 +154,72 @@ void largePing() throws Exception { System.out.println("Server stopped"); } + @Test + void largeBlob() throws Exception { + int blobSize = 1024 * 1024; + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host clientHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "client")) + .build(); + + Host serverHost = + new HostBuilder() + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .listen(localListenAddress) + .builderModifier( + b -> b.getDebug().getMuxFramesHandler().addCompactLogger(LogLevel.ERROR, "server")) + .build(); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Assertions.assertEquals(0, clientHost.listenAddresses().size()); + Assertions.assertEquals(1, serverHost.listenAddresses().size()); + Assertions.assertEquals( + localListenAddress + "/p2p/" + serverHost.getPeerId(), + serverHost.listenAddresses().get(0).toString()); + + StreamPromise blob = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), new Multiaddr(localListenAddress)) + .thenApply(it -> it.muxerSession().createStream(new Blob(blobSize))) + .join(); + + Stream blobStream = blob.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream created"); + BlobController blobCtr = blob.getController().get(5, TimeUnit.SECONDS); + System.out.println("Blob controller created"); + + for (int i = 0; i < 10; i++) { + long latency = blobCtr.blob().join(); + System.out.println("Blob round trip is " + latency); + } + blobStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> blobCtr.blob().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + @Test void addPingAfterHostStart() throws Exception { String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; diff --git a/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt b/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt new file mode 100644 index 000000000..a763ee60f --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/protocol/Blob.kt @@ -0,0 +1,144 @@ +package io.libp2p.protocol + +import io.libp2p.core.BadPeerException +import io.libp2p.core.ConnectionClosedException +import io.libp2p.core.Libp2pException +import io.libp2p.core.Stream +import io.libp2p.core.multistream.StrictProtocolBinding +import io.libp2p.etc.types.completedExceptionally +import io.libp2p.etc.types.lazyVar +import io.libp2p.etc.types.toByteArray +import io.libp2p.etc.types.toHex +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.ByteToMessageCodec +import java.time.Duration +import java.util.Collections +import java.util.Random +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit + +interface BlobController { + fun blob(): CompletableFuture +} + +class Blob(blobSize: Int) : BlobBinding(BlobProtocol(blobSize)) + +open class BlobBinding(blob: BlobProtocol) : + StrictProtocolBinding("/ipfs/blob-echo/1.0.0", blob) + +class BlobTimeoutException : Libp2pException() + +open class BlobProtocol(var blobSize: Int) : ProtocolHandler(Long.MAX_VALUE, Long.MAX_VALUE) { + var timeoutScheduler by lazyVar { Executors.newSingleThreadScheduledExecutor() } + var curTime: () -> Long = { System.currentTimeMillis() } + var random = Random() + var blobTimeout = Duration.ofSeconds(10) + + override fun onStartInitiator(stream: Stream): CompletableFuture { + val handler = BlobInitiator() + stream.pushHandler(BlobCodec()) + stream.pushHandler(handler) + stream.pushHandler(BlobCodec()) + return handler.activeFuture + } + + override fun onStartResponder(stream: Stream): CompletableFuture { + val handler = BlobResponder() + stream.pushHandler(BlobCodec()) + stream.pushHandler(BlobResponder()) + stream.pushHandler(BlobCodec()) + return CompletableFuture.completedFuture(handler) + } + + open class BlobCodec : ByteToMessageCodec() { + override fun encode(ctx: ChannelHandlerContext?, msg: ByteArray, out: ByteBuf) { + println("Codec::encode") + out.writeInt(msg.size) + out.writeBytes(msg) + } + + override fun decode(ctx: ChannelHandlerContext?, msg: ByteBuf, out: MutableList) { + println("Codec::decode " + msg.readableBytes()) + val readerIndex = msg.readerIndex() + if (msg.readableBytes() < 4) { + return + } + val len = msg.readInt() + if (msg.readableBytes() < len) { + // not enough data to read the full array + // will wait for more ... + msg.readerIndex(readerIndex) + return + } + val data = msg.readSlice(len) + out.add(data.toByteArray()) + } + } + + open inner class BlobResponder : ProtocolMessageHandler, BlobController { + override fun onMessage(stream: Stream, msg: ByteArray) { + println("Responder::onMessage") + stream.writeAndFlush(msg) + } + + override fun blob(): CompletableFuture { + throw Libp2pException("This is blob responder only") + } + } + + open inner class BlobInitiator : ProtocolMessageHandler, BlobController { + val activeFuture = CompletableFuture() + val requests = Collections.synchronizedMap(mutableMapOf>>()) + lateinit var stream: Stream + var closed = false + + override fun onActivated(stream: Stream) { + this.stream = stream + activeFuture.complete(this) + } + + override fun onMessage(stream: Stream, msg: ByteArray) { + println("Initiator::onMessage") + val dataS = msg.toHex() + val (sentT, future) = requests.remove(dataS) + ?: throw BadPeerException("Unknown or expired blob data in response: $dataS") + future.complete(curTime() - sentT) + } + + override fun onClosed(stream: Stream) { + synchronized(requests) { + closed = true + requests.values.forEach { it.second.completeExceptionally(ConnectionClosedException()) } + requests.clear() + timeoutScheduler.shutdownNow() + } + activeFuture.completeExceptionally(ConnectionClosedException()) + } + + override fun blob(): CompletableFuture { + val ret = CompletableFuture() + val arr = ByteArray(blobSize) + random.nextBytes(arr) + val dataS = arr.toHex() + + synchronized(requests) { + if (closed) return completedExceptionally(ConnectionClosedException()) + requests[dataS] = curTime() to ret + + timeoutScheduler.schedule( + { + requests.remove(dataS)?.second?.completeExceptionally(BlobTimeoutException()) + }, + blobTimeout.toMillis(), + TimeUnit.MILLISECONDS + ) + } + + println("Sender writing " + blobSize) + stream.writeAndFlush(arr) + return ret + } + } +} From ee02cf94d5bcc467e1f116b4b012d916a19009a2 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 11 Oct 2023 15:37:48 +0300 Subject: [PATCH 08/13] Fix the case when a stream is closed while still having buffered data for write (#330) * Fix the case when a stream is closed while still having buffered data for write * Add unit test for close case when outbound data buffered --- .../kotlin/io/libp2p/etc/types/Delegates.kt | 16 +++++++ .../kotlin/io/libp2p/mux/MuxerException.kt | 1 + .../io/libp2p/mux/yamux/YamuxHandler.kt | 23 +++++++--- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 43 +++++++++++++++++++ 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt index ea44d904a..a67a3c864 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/Delegates.kt @@ -1,5 +1,6 @@ package io.libp2p.etc.types +import kotlin.properties.Delegates import kotlin.properties.ReadWriteProperty import kotlin.reflect.KProperty @@ -92,3 +93,18 @@ data class CappedValueDelegate>( } } } + +fun Delegates.writeOnce(initialValue: T): ReadWriteProperty = object : ReadWriteProperty { + private var value: T = initialValue + private var wasSet = false + + public override fun getValue(thisRef: Any?, property: KProperty<*>): T { + return value + } + + public override fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { + if (wasSet) throw IllegalStateException("Property ${property.name} cannot be set more than once.") + this.value = value + wasSet = true + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt index 1ba4eaa1b..b156aaf32 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt @@ -13,3 +13,4 @@ class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream w class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null) class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null) +class ClosedForWritingMuxerException(muxId: MuxId) : WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index ece12d4e7..659c40ab6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.types.sliceMaxSize +import io.libp2p.etc.types.writeOnce import io.libp2p.etc.util.netty.ByteBufQueue import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId +import io.libp2p.mux.ClosedForWritingMuxerException import io.libp2p.mux.InvalidFrameMuxerException import io.libp2p.mux.MuxHandler import io.libp2p.mux.UnknownStreamIdMuxerException @@ -19,6 +21,7 @@ import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import kotlin.math.max +import kotlin.properties.Delegates const val INITIAL_WINDOW_SIZE = 256 * 1024 const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB @@ -39,6 +42,7 @@ open class YamuxHandler( val sendWindowSize = AtomicInteger(initialWindowSize) val receiveWindowSize = AtomicInteger(initialWindowSize) val sendBuffer = ByteBufQueue() + var closedForWriting by Delegates.writeOnce(false) fun dispose() { sendBuffer.dispose() @@ -72,7 +76,7 @@ open class YamuxHandler( val delta = msg.length.toInt() sendWindowSize.addAndGet(delta) // try to send any buffered messages after the window update - drainBuffer() + drainBufferAndMaybeClose() } private fun handleFlags(msg: YamuxFrame) { @@ -98,7 +102,7 @@ open class YamuxHandler( } } - private fun drainBuffer() { + private fun drainBufferAndMaybeClose() { val maxSendLength = max(0, sendWindowSize.get()) val data = sendBuffer.take(maxSendLength) sendWindowSize.addAndGet(-data.readableBytes()) @@ -107,11 +111,18 @@ open class YamuxHandler( val length = slicedData.readableBytes() writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) } + + if (closedForWriting && sendBuffer.readableBytes() == 0) { + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + } } fun sendData(data: ByteBuf) { + if (closedForWriting) { + throw ClosedForWritingMuxerException(id) + } fillBuffer(data) - drainBuffer() + drainBufferAndMaybeClose() } fun onLocalOpen() { @@ -123,10 +134,8 @@ open class YamuxHandler( } fun onLocalDisconnect() { - // TODO: this implementation drops remaining data - drainBuffer() - sendBuffer.dispose() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + closedForWriting = true + drainBufferAndMaybeClose() } fun onLocalClose() { diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 5f239f081..8b7218dca 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -407,6 +407,49 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { msgPart3.data!!.release() } + @Test + fun `local close for writing should flush buffered data and send close frame on writeWindow update`() { + val handler = openStreamLocal() + val muxId = readFrameOrThrow().streamId.toMuxId() + + val msg = "42".repeat(initialWindowSize + 1).fromHex().toByteBuf(allocateBuf()) + // writing a message which is larger than sendWindowSize + handler.ctx.writeAndFlush(msg) + + val msgPart1 = readYamuxFrameOrThrow() + assertThat(msgPart1.length).isEqualTo(256L) + assertThat(msgPart1.data!!.readableBytes()).isEqualTo(256) + msgPart1.data!!.release() + + val msgPart2 = readYamuxFrameOrThrow() + assertThat(msgPart2.length.toInt()).isEqualTo(initialWindowSize - 256) + assertThat(msgPart2.data!!.readableBytes()).isEqualTo(initialWindowSize - 256) + msgPart2.data!!.release() + + // locally close for writing while some outbound data is still buffered + handler.ctx.disconnect() + + // ACKing message receive + ech.writeInbound( + YamuxFrame( + muxId, + YamuxType.WINDOW_UPDATE, + YamuxFlags.ACK, + initialWindowSize.toLong() + ) + ) + + val msgPart3 = readYamuxFrameOrThrow() + assertThat(msgPart3.length).isEqualTo(1L) + assertThat(msgPart3.data!!.readableBytes()).isEqualTo(1) + msgPart3.data!!.release() + + val closeFrame = readYamuxFrameOrThrow() + assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN) + assertThat(closeFrame.length).isEqualTo(0L) + assertThat(closeFrame.data).isNull() + } + companion object { private fun YamuxStreamIdGenerator.toIterator() = iterator { while (true) { From e8438361eb07518e77fcd18ec73048d7a06d6d6f Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 11 Oct 2023 16:05:10 +0300 Subject: [PATCH 09/13] Refactor Yamux flags (#338) * Convert YamuxType to enum * Refactor YamuxFlags: convert them to Set of enum values. --- .../kotlin/io/libp2p/mux/yamux/YamuxFlag.kt | 34 +++++++++++ .../kotlin/io/libp2p/mux/yamux/YamuxFlags.kt | 11 ---- .../kotlin/io/libp2p/mux/yamux/YamuxFrame.kt | 12 ++-- .../io/libp2p/mux/yamux/YamuxFrameCodec.kt | 17 +++--- .../io/libp2p/mux/yamux/YamuxHandler.kt | 31 +++++----- .../kotlin/io/libp2p/mux/yamux/YamuxType.kt | 19 +++++-- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 56 +++++++++---------- 7 files changed, 106 insertions(+), 74 deletions(-) create mode 100644 libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt delete mode 100644 libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt new file mode 100644 index 000000000..34f9a10d2 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlag.kt @@ -0,0 +1,34 @@ +package io.libp2p.mux.yamux + +import io.libp2p.mux.InvalidFrameMuxerException + +/** + * Contains all the permissible values for flags in the yamux protocol. + */ +enum class YamuxFlag(val intFlag: Int) { + SYN(1), + ACK(2), + FIN(4), + RST(8); + + val asSet: Set = setOf(this) + + companion object { + val NONE = emptySet() + + private val validFlagCombinations = mapOf( + 0 to NONE, + SYN.intFlag to SYN.asSet, + ACK.intFlag to ACK.asSet, + FIN.intFlag to FIN.asSet, + RST.intFlag to RST.asSet, + ) + + fun fromInt(flags: Int) = + validFlagCombinations[flags] ?: throw InvalidFrameMuxerException("Invalid Yamux flags value: $flags") + + fun Set.toInt() = this + .fold(0) { acc, flag -> acc or flag.intFlag } + .also { require(it in validFlagCombinations) { "Invalid Yamux flags combination: $this" } } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt deleted file mode 100644 index 85499d0dd..000000000 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFlags.kt +++ /dev/null @@ -1,11 +0,0 @@ -package io.libp2p.mux.yamux - -/** - * Contains all the permissible values for flags in the yamux protocol. - */ -object YamuxFlags { - const val SYN = 1 - const val ACK = 2 - const val FIN = 4 - const val RST = 8 -} diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt index 32bd32e6a..c35dcea88 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt @@ -1,6 +1,5 @@ package io.libp2p.mux.yamux -import io.libp2p.etc.types.toByteArray import io.libp2p.etc.util.netty.mux.MuxId import io.netty.buffer.ByteBuf import io.netty.buffer.DefaultByteBufHolder @@ -9,17 +8,16 @@ import io.netty.buffer.Unpooled /** * Contains the fields that comprise a yamux frame. * @param id the ID of the stream. - * @param flags the flags value for this frame. + * @param flags the flags for this frame. * @param length the length field for this frame. * @param data the data segment. */ -class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) : +class YamuxFrame(val id: MuxId, val type: YamuxType, val flags: Set, val length: Long, val data: ByteBuf? = null) : DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) { override fun toString(): String { - if (data == null) { - return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)" - } - return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})" + val dataString = if (data == null) "" else ", len=${data.readableBytes()}, $data" + val flagsString = flags.joinToString("+") + return "YamuxFrame(id=$id, type=$type, flags=$flagsString, length=$length$dataString)" } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt index d85696508..f2db941ec 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt @@ -1,6 +1,7 @@ package io.libp2p.mux.yamux import io.libp2p.core.ProtocolViolationException +import io.libp2p.mux.yamux.YamuxFlag.Companion.toInt import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext @@ -24,8 +25,8 @@ class YamuxFrameCodec( */ override fun encode(ctx: ChannelHandlerContext, msg: YamuxFrame, out: ByteBuf) { out.writeByte(0) // version - out.writeByte(msg.type) - out.writeShort(msg.flags) + out.writeByte(msg.type.intValue) + out.writeShort(msg.flags.toInt()) out.writeInt(msg.id.id.toInt()) out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt()) out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER) @@ -46,15 +47,17 @@ class YamuxFrameCodec( val readerIndex = msg.readerIndex() msg.readByte(); // version always 0 val type = msg.readUnsignedByte() + val yamuxType = YamuxType.fromInt(type.toInt()) val flags = msg.readUnsignedShort() val streamId = msg.readUnsignedInt() val length = msg.readUnsignedInt() val yamuxId = YamuxId(ctx.channel().id(), streamId) - if (type.toInt() != YamuxType.DATA) { + val yamuxFlags = YamuxFlag.fromInt(flags) + if (yamuxType != YamuxType.DATA) { val yamuxFrame = YamuxFrame( yamuxId, - type.toInt(), - flags, + yamuxType, + yamuxFlags, length ) out.add(yamuxFrame) @@ -74,8 +77,8 @@ class YamuxFrameCodec( data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed val yamuxFrame = YamuxFrame( yamuxId, - type.toInt(), - flags, + yamuxType, + yamuxFlags, length, data ) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 659c40ab6..65339c57f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -53,6 +53,7 @@ open class YamuxHandler( when (msg.type) { YamuxType.DATA -> handleDataRead(msg) YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) + else -> { /* ignore */ } } } @@ -67,7 +68,7 @@ open class YamuxHandler( if (newWindow < initialWindowSize / 2) { val delta = initialWindowSize - newWindow receiveWindowSize.addAndGet(delta) - writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, 0, delta.toLong())) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.NONE, delta.toLong())) } childRead(msg.id, msg.data!!) } @@ -80,14 +81,14 @@ open class YamuxHandler( } private fun handleFlags(msg: YamuxFrame) { - when (msg.flags) { - YamuxFlags.SYN -> { + when { + YamuxFlag.SYN in msg.flags -> { // ACK the new stream - writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0)) } - YamuxFlags.FIN -> onRemoteDisconnect(msg.id) - YamuxFlags.RST -> onRemoteClose(msg.id) + YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id) + YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id) } } @@ -109,11 +110,11 @@ open class YamuxHandler( data.sliceMaxSize(maxFrameDataLength) .forEach { slicedData -> val length = slicedData.readableBytes() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.NONE, length.toLong(), slicedData)) } if (closedForWriting && sendBuffer.readableBytes() == 0) { - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.FIN.asSet, 0)) } } @@ -126,7 +127,7 @@ open class YamuxHandler( } fun onLocalOpen() { - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.SYN.asSet, 0)) } fun onRemoteOpen() { @@ -141,7 +142,7 @@ open class YamuxHandler( fun onLocalClose() { // close stream immediately so not transferring buffered data sendBuffer.dispose() - writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.RST, 0)) + writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlag.RST.asSet, 0)) } } @@ -181,7 +182,7 @@ open class YamuxHandler( YamuxType.PING -> handlePing(msg) YamuxType.GO_AWAY -> handleGoAway(msg) else -> { - if (msg.flags == YamuxFlags.SYN) { + if (YamuxFlag.SYN in msg.flags) { // remote opens a new stream validateSynRemoteMuxId(msg.id) onRemoteYamuxOpen(msg.id) @@ -247,17 +248,15 @@ open class YamuxHandler( if (msg.id.id != YamuxId.SESSION_STREAM_ID) { throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}") } - when (msg.flags) { - YamuxFlags.SYN -> writeAndFlushFrame( + if (YamuxFlag.SYN in msg.flags) { + writeAndFlushFrame( YamuxFrame( YamuxId.sessionId(msg.id.parentId), YamuxType.PING, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, msg.length ) ) - - YamuxFlags.ACK -> {} } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt index 0746c8cf8..db779e7f9 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt @@ -1,11 +1,20 @@ package io.libp2p.mux.yamux +import io.libp2p.mux.InvalidFrameMuxerException + /** * Contains all the permissible values for types in the yamux protocol. */ -object YamuxType { - const val DATA = 0 - const val WINDOW_UPDATE = 1 - const val PING = 2 - const val GO_AWAY = 3 +enum class YamuxType(val intValue: Int) { + DATA(0), + WINDOW_UPDATE(1), + PING(2), + GO_AWAY(3); + + companion object { + private val intToTypeCache = values().associateBy { it.intValue } + + fun fromInt(intValue: Int): YamuxType = + intToTypeCache[intValue] ?: throw InvalidFrameMuxerException("Invalid Yamux type value: $intValue") + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 8b7218dca..b85e95733 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -46,20 +46,20 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override fun writeFrame(frame: AbstractTestMuxFrame) { val muxId = frame.streamId.toMuxId() val yamuxFrame = when (frame.flag) { - Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) + Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.SYN.asSet, 0) Data -> { val data = frame.data.fromHex() YamuxFrame( muxId, YamuxType.DATA, - 0, + YamuxFlag.NONE, data.size.toLong(), data.toByteBuf(allocateBuf()) ) } - Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.FIN, 0) - Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.RST, 0) + Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.FIN.asSet, 0) + Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlag.RST.asSet, 0) } ech.writeInbound(yamuxFrame) } @@ -67,8 +67,8 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override fun readFrame(): AbstractTestMuxFrame? { val yamuxFrame = readYamuxFrame() if (yamuxFrame != null) { - when (yamuxFrame.flags) { - YamuxFlags.SYN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open) + when { + YamuxFlag.SYN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Open) } val data = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: "" @@ -77,9 +77,9 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Data, data) } - when (yamuxFrame.flags) { - YamuxFlags.FIN -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close) - YamuxFlags.RST -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset) + when { + YamuxFlag.FIN in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Close) + YamuxFlag.RST in yamuxFrame.flags -> readFrameQueue += AbstractTestMuxFrame(yamuxFrame.id.id, Reset) } } @@ -102,7 +102,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val ackFrame = readYamuxFrameOrThrow() // receives ack stream - assertThat(ackFrame.flags).isEqualTo(YamuxFlags.ACK) + assertThat(ackFrame.flags).containsExactly(YamuxFlag.ACK) assertThat(ackFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) closeStream(12) @@ -119,7 +119,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.DATA, - 0, + YamuxFlag.NONE, length.toLong(), "42".repeat(length).fromHex().toByteBuf(allocateBuf()) ) @@ -128,7 +128,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val windowUpdateFrame = readYamuxFrameOrThrow() // window frame is sent based on the new window - assertThat(windowUpdateFrame.flags).isZero() + assertThat(windowUpdateFrame.flags).isEmpty() assertThat(windowUpdateFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE) assertThat(windowUpdateFrame.length).isEqualTo(length.toLong()) } @@ -142,7 +142,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -151,7 +151,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(readFrame()).isNull() - ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 5000)) + ech.writeInbound(YamuxFrame(streamId.toMuxId(), YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 5000)) val frame = readFrameOrThrow() assertThat(frame.data).isEqualTo("1984") } @@ -165,7 +165,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -181,7 +181,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 2 ) ) @@ -196,7 +196,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 1 ) ) @@ -207,7 +207,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, 10000 ) ) @@ -224,7 +224,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -initialWindowSize.toLong() ) ) @@ -245,7 +245,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { .hasMessage("Overflowed send buffer (612/512). Last stream attempting to write: $muxId") val frame = readYamuxFrameOrThrow() - assertThat(frame.flags).isEqualTo(YamuxFlags.RST) + assertThat(frame.flags).containsExactly(YamuxFlag.RST) } @Test @@ -261,7 +261,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( streamId.toMuxId(), YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, it.toLong() ) ) @@ -308,7 +308,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( id.toMuxId(), YamuxType.PING, - YamuxFlags.SYN, + YamuxFlag.SYN.asSet, // opaque value, echoed back 3 ) @@ -316,7 +316,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { val pingFrame = readYamuxFrameOrThrow() - assertThat(pingFrame.flags).isEqualTo(YamuxFlags.ACK) + assertThat(pingFrame.flags).containsExactly(YamuxFlag.ACK) assertThat(pingFrame.type).isEqualTo(YamuxType.PING) assertThat(pingFrame.length).isEqualTo(3) } @@ -328,7 +328,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( id.toMuxId(), YamuxType.GO_AWAY, - 0, + YamuxFlag.NONE, // normal termination 0x2 ) @@ -374,7 +374,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, -10 ) ) @@ -396,7 +396,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, initialWindowSize.toLong() ) ) @@ -434,7 +434,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxFrame( muxId, YamuxType.WINDOW_UPDATE, - YamuxFlags.ACK, + YamuxFlag.ACK.asSet, initialWindowSize.toLong() ) ) @@ -445,7 +445,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { msgPart3.data!!.release() val closeFrame = readYamuxFrameOrThrow() - assertThat(closeFrame.flags).isEqualTo(YamuxFlags.FIN) + assertThat(closeFrame.flags).containsExactly(YamuxFlag.FIN) assertThat(closeFrame.length).isEqualTo(0L) assertThat(closeFrame.data).isNull() } From e363c49b5eae73e526cc73649195e0644e78d32d Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Thu, 26 Oct 2023 13:56:03 +0100 Subject: [PATCH 10/13] [Yamux] Allow max ACK backlog of 256 streams (#340) --- .../io/libp2p/core/mux/StreamMuxerProtocol.kt | 10 ++- .../kotlin/io/libp2p/etc/types/AsyncExt.kt | 8 +++ .../etc/util/netty/mux/AbstractMuxHandler.kt | 7 +- .../main/kotlin/io/libp2p/mux/MuxHandler.kt | 5 +- .../kotlin/io/libp2p/mux/MuxerException.kt | 5 +- .../io/libp2p/mux/yamux/YamuxHandler.kt | 65 ++++++++++++++----- .../io/libp2p/mux/yamux/YamuxStreamMuxer.kt | 6 +- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 39 +++++++++++ 8 files changed, 119 insertions(+), 26 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt index 3f7f460a0..878e74d05 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/mux/StreamMuxerProtocol.kt @@ -3,6 +3,7 @@ package io.libp2p.core.mux import io.libp2p.core.multistream.MultistreamProtocol import io.libp2p.core.multistream.ProtocolBinding import io.libp2p.mux.mplex.MplexStreamMuxer +import io.libp2p.mux.yamux.DEFAULT_ACK_BACKLOG_LIMIT import io.libp2p.mux.yamux.DEFAULT_MAX_BUFFERED_CONNECTION_WRITES import io.libp2p.mux.yamux.YamuxStreamMuxer @@ -23,17 +24,22 @@ fun interface StreamMuxerProtocol { /** * @param maxBufferedConnectionWrites the maximum amount of bytes in the write buffer per connection + * @param ackBacklogLimit the maximum amount of opened streams per connection which have not been acknowledged */ @JvmStatic @JvmOverloads - fun getYamux(maxBufferedConnectionWrites: Int = DEFAULT_MAX_BUFFERED_CONNECTION_WRITES): StreamMuxerProtocol { + fun getYamux( + maxBufferedConnectionWrites: Int = DEFAULT_MAX_BUFFERED_CONNECTION_WRITES, + ackBacklogLimit: Int = DEFAULT_ACK_BACKLOG_LIMIT + ): StreamMuxerProtocol { return StreamMuxerProtocol { multistreamProtocol, protocols -> YamuxStreamMuxer( multistreamProtocol.createMultistream( protocols ).toStreamHandler(), multistreamProtocol, - maxBufferedConnectionWrites + maxBufferedConnectionWrites, + ackBacklogLimit ) } } diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt b/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt index 7dcbf57c9..176650fd3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/types/AsyncExt.kt @@ -19,6 +19,14 @@ fun CompletableFuture.bind(result: CompletableFuture) { fun CompletableFuture.forward(forwardTo: CompletableFuture) = forwardTo.bind(this) +fun CompletableFuture.forwardException(forwardTo: CompletableFuture): CompletableFuture { + return whenComplete { _, t -> + if (t != null) { + forwardTo.completeExceptionally(t) + } + } +} + /** * The same as [CompletableFuture.get] but unwraps [ExecutionException] */ diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt index e072ef8b5..a5c49b175 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt @@ -9,7 +9,6 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter import org.slf4j.LoggerFactory import java.util.concurrent.CompletableFuture -import java.util.function.Function typealias MuxChannelInitializer = (MuxChannel) -> Unit @@ -61,10 +60,12 @@ abstract class AbstractMuxHandler() : releaseMessage(msg) throw ConnectionClosedException("Channel with id $id not opened") } + child.remoteDisconnected -> { releaseMessage(msg) throw ConnectionClosedException("Channel with id $id was closed for sending by remote") } + else -> { pendingReadComplete += id child.pipeline().fireChannelRead(msg) @@ -136,7 +137,7 @@ abstract class AbstractMuxHandler() : ): MuxChannel { val child = MuxChannel(this, id, initializer, initiator) streamMap[id] = child - ctx!!.channel().eventLoop().register(child) + ctx!!.channel().eventLoop().register(child).sync() return child } @@ -148,7 +149,7 @@ abstract class AbstractMuxHandler() : try { checkClosed() // if already closed then event loop is already down and async task may never execute return activeFuture.thenApplyAsync( - Function { + { checkClosed() // close may happen after above check and before this point val child = createChild( generateNextId(), diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt index 08a6bd12b..51bcc275c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt @@ -9,6 +9,7 @@ import io.libp2p.core.mux.StreamMuxer import io.libp2p.etc.CONNECTION import io.libp2p.etc.STREAM import io.libp2p.etc.types.forward +import io.libp2p.etc.types.forwardException import io.libp2p.etc.util.netty.mux.AbstractMuxHandler import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxChannelInitializer @@ -49,7 +50,9 @@ abstract class MuxHandler( val controller = CompletableFuture() val stream = newStream { streamHandler.handleStream(createStream(it)).forward(controller) - }.thenApply { it.attr(STREAM).get() } + } + .thenApply { it.attr(STREAM).get() } + .forwardException(controller) return StreamPromise(stream, controller) } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt index b156aaf32..b424d7caa 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxerException.kt @@ -5,6 +5,8 @@ import io.libp2p.etc.util.netty.mux.MuxId open class MuxerException(message: String, ex: Exception?) : Libp2pException(message, ex) +class AckBacklogLimitExceededMuxerException(message: String) : MuxerException(message, null) + open class ReadMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) open class WriteMuxerException(message: String, ex: Exception?) : MuxerException(message, ex) @@ -13,4 +15,5 @@ class UnknownStreamIdMuxerException(muxId: MuxId) : ReadMuxerException("Stream w class InvalidFrameMuxerException(message: String) : ReadMuxerException(message, null) class WriteBufferOverflowMuxerException(message: String) : WriteMuxerException(message, null) -class ClosedForWritingMuxerException(muxId: MuxId) : WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null) +class ClosedForWritingMuxerException(muxId: MuxId) : + WriteMuxerException("Couldn't write, stream was closed for writing: $muxId", null) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 65339c57f..bdde3478b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -10,21 +10,20 @@ import io.libp2p.etc.types.writeOnce import io.libp2p.etc.util.netty.ByteBufQueue import io.libp2p.etc.util.netty.mux.MuxChannel import io.libp2p.etc.util.netty.mux.MuxId -import io.libp2p.mux.ClosedForWritingMuxerException -import io.libp2p.mux.InvalidFrameMuxerException -import io.libp2p.mux.MuxHandler -import io.libp2p.mux.UnknownStreamIdMuxerException -import io.libp2p.mux.WriteBufferOverflowMuxerException +import io.libp2p.mux.* import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger import kotlin.math.max import kotlin.properties.Delegates -const val INITIAL_WINDOW_SIZE = 256 * 1024 const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB +const val DEFAULT_ACK_BACKLOG_LIMIT = 256 + +const val INITIAL_WINDOW_SIZE = 256 * 1024 open class YamuxHandler( override val multistreamProtocol: MultistreamProtocol, @@ -33,12 +32,15 @@ open class YamuxHandler( inboundStreamHandler: StreamHandler<*>, private val connectionInitiator: Boolean, private val maxBufferedConnectionWrites: Int, + private val ackBacklogLimit: Int, private val initialWindowSize: Int = INITIAL_WINDOW_SIZE ) : MuxHandler(ready, inboundStreamHandler) { private inner class YamuxStreamHandler( - val id: MuxId + val id: MuxId, + val outbound: Boolean ) { + val acknowledged = AtomicBoolean(false) val sendWindowSize = AtomicInteger(initialWindowSize) val receiveWindowSize = AtomicInteger(initialWindowSize) val sendBuffer = ByteBufQueue() @@ -53,7 +55,9 @@ open class YamuxHandler( when (msg.type) { YamuxType.DATA -> handleDataRead(msg) YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg) - else -> { /* ignore */ } + else -> { + /* ignore */ + } } } @@ -62,7 +66,7 @@ open class YamuxHandler( if (size == 0) { return } - + acknowledgeInboundStreamIfNeeded() val newWindow = receiveWindowSize.addAndGet(-size) // send a window update frame once half of the window is depleted if (newWindow < initialWindowSize / 2) { @@ -87,11 +91,27 @@ open class YamuxHandler( writeAndFlushFrame(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlag.ACK.asSet, 0)) } + YamuxFlag.ACK in msg.flags -> { + acknowledgeOutboundStreamIfNeeded() + } + YamuxFlag.FIN in msg.flags -> onRemoteDisconnect(msg.id) YamuxFlag.RST in msg.flags -> onRemoteClose(msg.id) } } + private fun acknowledgeInboundStreamIfNeeded() { + if (!outbound) { + acknowledged.set(true) + } + } + + private fun acknowledgeOutboundStreamIfNeeded() { + if (outbound) { + acknowledged.set(true) + } + } + private fun fillBuffer(data: ByteBuf) { sendBuffer.push(data) val totalBufferedWrites = calculateTotalBufferedWrites() @@ -122,6 +142,7 @@ open class YamuxHandler( if (closedForWriting) { throw ClosedForWritingMuxerException(id) } + acknowledgeInboundStreamIfNeeded() fillBuffer(data) drainBufferAndMaybeClose() } @@ -214,16 +235,26 @@ open class YamuxHandler( } override fun onLocalOpen(child: MuxChannel) { - createYamuxStreamHandler(child.id).onLocalOpen() + verifyAckBacklogLimitNotReached(child.id, true) + createYamuxStreamHandler(child.id, true).onLocalOpen() } private fun onRemoteYamuxOpen(id: MuxId) { - createYamuxStreamHandler(id).onRemoteOpen() + verifyAckBacklogLimitNotReached(id, false) + createYamuxStreamHandler(id, false).onRemoteOpen() onRemoteOpen(id) } - private fun createYamuxStreamHandler(id: MuxId): YamuxStreamHandler { - val streamHandler = YamuxStreamHandler(id) + private fun verifyAckBacklogLimitNotReached(id: MuxId, outbound: Boolean) { + val totalUnacknowledgedStreams = + streamHandlers.values.count { it.outbound == outbound && !it.acknowledged.get() } + if (totalUnacknowledgedStreams >= ackBacklogLimit) { + throw AckBacklogLimitExceededMuxerException("The ACK backlog limit of $ackBacklogLimit streams has been reached. Will not open new stream: $id") + } + } + + private fun createYamuxStreamHandler(id: MuxId, outbound: Boolean): YamuxStreamHandler { + val streamHandler = YamuxStreamHandler(id, outbound) streamHandlers[id] = streamHandler return streamHandler } @@ -240,10 +271,6 @@ open class YamuxHandler( streamHandlers.remove(child.id)?.dispose() } - private fun calculateTotalBufferedWrites(): Int { - return streamHandlers.values.sumOf { it.sendBuffer.readableBytes() } - } - private fun handlePing(msg: YamuxFrame) { if (msg.id.id != YamuxId.SESSION_STREAM_ID) { throw InvalidFrameMuxerException("Invalid StreamId for Ping frame type: ${msg.id}") @@ -267,6 +294,10 @@ open class YamuxHandler( goAwayPromise.complete(msg.length) } + private fun calculateTotalBufferedWrites(): Int { + return streamHandlers.values.sumOf { it.sendBuffer.readableBytes() } + } + override fun generateNextId() = YamuxId(getChannelHandlerContext().channel().id(), idGenerator.next()) } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt index bffec7941..b64d81389 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxStreamMuxer.kt @@ -13,7 +13,8 @@ import java.util.concurrent.CompletableFuture class YamuxStreamMuxer( val inboundStreamHandler: StreamHandler<*>, private val multistreamProtocol: MultistreamProtocol, - private val maxBufferedConnectionWrites: Int + private val maxBufferedConnectionWrites: Int, + private val ackBacklogLimit: Int ) : StreamMuxer, StreamMuxerDebug { override val protocolDescriptor = ProtocolDescriptor("/yamux/1.0.0") @@ -32,7 +33,8 @@ class YamuxStreamMuxer( muxSessionReady, inboundStreamHandler, ch.isInitiator, - maxBufferedConnectionWrites + maxBufferedConnectionWrites, + ackBacklogLimit ) ) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index b85e95733..2bac4bdf3 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -5,6 +5,7 @@ import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocolV1 import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.toHex +import io.libp2p.mux.AckBacklogLimitExceededMuxerException import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* @@ -14,11 +15,15 @@ import io.netty.channel.ChannelHandlerContext import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource class YamuxHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 private val maxBufferedConnectionWrites = 512 + private val ackBacklogLimit = 42 private val initialWindowSize = 300 override val localMuxIdGenerator = YamuxStreamIdGenerator(isLocalConnectionInitiator).toIterator() override val remoteMuxIdGenerator = YamuxStreamIdGenerator(!isLocalConnectionInitiator).toIterator() @@ -34,6 +39,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { streamHandler, true, maxBufferedConnectionWrites, + ackBacklogLimit, initialWindowSize ) { // MuxHandler consumes the exception. Override this behaviour for testing @@ -450,6 +456,39 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { assertThat(closeFrame.data).isNull() } + @ParameterizedTest + @ValueSource(booleans = [true, false]) + fun `does not create new stream if ACK backlog limit is reached`(outbound: Boolean) { + val openStream: () -> Unit = { + if (outbound) { + openStreamLocal() + } else { + openStreamRemote() + } + } + for (i in 1..ackBacklogLimit) { + openStream() + } + // opening new stream should fail + val exception = assertThrows { openStream() } + + if (outbound) { + assertThat(exception).hasCauseInstanceOf(AckBacklogLimitExceededMuxerException::class.java) + // expected number of SYN frames have been sent + var synFlagFrames = 0 + do { + val frame = readYamuxFrame() + frame?.let { + assertThat(it.flags).isEqualTo(YamuxFlag.SYN.asSet) + synFlagFrames += 1 + } + } while (frame != null) + assertThat(synFlagFrames).isEqualTo(ackBacklogLimit) + } else { + assertThat(exception).isInstanceOf(AckBacklogLimitExceededMuxerException::class.java) + } + } + companion object { private fun YamuxStreamIdGenerator.toIterator() = iterator { while (true) { From 10514b13673ee69db32a4a3282f8aa5f66beba04 Mon Sep 17 00:00:00 2001 From: Dr Ian Preston Date: Fri, 24 Nov 2023 06:36:39 +0000 Subject: [PATCH 11/13] Don't try and dial DNSADDR addresses (#343) --- .../src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt | 4 +++- .../test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt index 71d34c2df..92dcb6e42 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt @@ -2,6 +2,7 @@ package io.libp2p.transport.tcp import io.libp2p.core.InternalErrorException import io.libp2p.core.multiformats.Multiaddr +import io.libp2p.core.multiformats.Protocol.DNSADDR import io.libp2p.core.multiformats.Protocol.IP4 import io.libp2p.core.multiformats.Protocol.IP6 import io.libp2p.core.multiformats.Protocol.TCP @@ -27,7 +28,8 @@ open class TcpTransport( override fun handles(addr: Multiaddr) = handlesHost(addr) && addr.has(TCP) && - !addr.has(WS) + !addr.has(WS) && + !addr.has(DNSADDR) override fun serverTransportBuilder( connectionBuilder: ConnectionBuilder, diff --git a/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt b/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt index 167092911..834f2b20f 100644 --- a/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/transport/tcp/TcpTransportTest.kt @@ -35,8 +35,7 @@ class TcpTransportTest : TransportTests() { "/ip4/0.0.0.0/tcp/1234", "/ip6/fe80::6f77:b303:aa6e:a16/tcp/42", "/dns4/localhost/tcp/9999", - "/dns6/localhost/tcp/9999", - "/dnsaddr/ipfs.io/tcp/97" + "/dns6/localhost/tcp/9999" ).map { Multiaddr(it) } @JvmStatic From ac6127b66caf1773ee68235deaaad5736d60eb75 Mon Sep 17 00:00:00 2001 From: Dr Ian Preston Date: Mon, 15 Jan 2024 13:37:44 +0000 Subject: [PATCH 12/13] Implement circuit relay v2 (#345) * Initial import of circuit relay * Don't let tcp dial a circuit address * Use ScheduledExecutorService in RelayTransport maintenance * Add self-contained local relay test * Add unit test for relay bandwidth limit --- README.md | 2 +- .../protocol/circuit/CircuitHopProtocol.java | 417 ++++++++++++++++++ .../protocol/circuit/CircuitStopProtocol.java | 157 +++++++ .../libp2p/protocol/circuit/HostConsumer.java | 8 + .../protocol/circuit/RelayTransport.java | 326 ++++++++++++++ .../io/libp2p/core/multiformats/Multiaddr.kt | 6 +- .../io/libp2p/transport/tcp/TcpTransport.kt | 4 +- libp2p/src/main/proto/circuit.proto | 60 +++ libp2p/src/main/proto/envelope.proto | 46 ++ libp2p/src/main/proto/voucher.proto | 9 + .../java/io/libp2p/core/RelayTestJava.java | 225 ++++++++++ 11 files changed, 1257 insertions(+), 3 deletions(-) create mode 100644 libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java create mode 100644 libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java create mode 100644 libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java create mode 100644 libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java create mode 100644 libp2p/src/main/proto/circuit.proto create mode 100644 libp2p/src/main/proto/envelope.proto create mode 100644 libp2p/src/main/proto/voucher.proto create mode 100644 libp2p/src/test/java/io/libp2p/core/RelayTestJava.java diff --git a/README.md b/README.md index dc0a32ccb..3229f3d41 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ List of components in the Libp2p spec and their JVM implementation status | **Protocol Select** | [multistream](https://github.com/multiformats/multistream-select) | :green_apple: | | **Stream Multiplexing** | [yamux](https://github.com/libp2p/specs/blob/master/yamux/README.md) | :lemon: | | | [mplex](https://github.com/libp2p/specs/blob/master/mplex/README.md) | :green_apple: | -| **NAT Traversal** | [circuit-relay-v2](https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md) | | +| **NAT Traversal** | [circuit-relay-v2](https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md) | :lemon: | | | [autonat](https://github.com/libp2p/specs/tree/master/autonat) | | | | [hole-punching](https://github.com/libp2p/specs/blob/master/connections/hole-punching.md) | | | **Discovery** | [bootstrap](https://github.com/libp2p/specs/blob/master/kad-dht/README.md#bootstrap-process) | | diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java new file mode 100644 index 000000000..be2be179d --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitHopProtocol.java @@ -0,0 +1,417 @@ +package io.libp2p.protocol.circuit; + +import com.google.protobuf.*; +import io.libp2p.core.*; +import io.libp2p.core.Stream; +import io.libp2p.core.crypto.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.multistream.*; +import io.libp2p.etc.util.netty.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.crypto.pb.*; +import io.libp2p.protocol.circuit.pb.*; +import io.netty.buffer.*; +import io.netty.channel.*; +import io.netty.handler.codec.protobuf.*; +import java.io.*; +import java.nio.charset.*; +import java.time.*; +import java.time.Duration; +import java.time.temporal.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.function.*; +import java.util.stream.*; +import org.jetbrains.annotations.*; + +public class CircuitHopProtocol extends ProtobufProtocolHandler { + + private static final String HOP_HANDLER_NAME = "HOP_HANDLER"; + private static final String STREAM_CLEARER_NAME = "STREAM_CLEARER"; + + public static class Binding extends StrictProtocolBinding implements HostConsumer { + private final CircuitHopProtocol hop; + + private Binding(CircuitHopProtocol hop) { + super("/libp2p/circuit/relay/0.2.0/hop", hop); + this.hop = hop; + } + + public Binding(RelayManager manager, CircuitStopProtocol.Binding stop) { + this(new CircuitHopProtocol(manager, stop)); + } + + @Override + public void setHost(Host us) { + hop.setHost(us); + } + } + + private static void putUvarint(OutputStream out, long x) throws IOException { + while (x >= 0x80) { + out.write((byte) (x | 0x80)); + x >>= 7; + } + out.write((byte) x); + } + + public static byte[] createVoucher( + PrivKey priv, PeerId relay, PeerId requestor, LocalDateTime expiry) { + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + try { + putUvarint(bout, 0x0302); + } catch (IOException e) { + } + byte[] typeMulticodec = bout.toByteArray(); + byte[] payload = + VoucherOuterClass.Voucher.newBuilder() + .setRelay(ByteString.copyFrom(relay.getBytes())) + .setPeer(ByteString.copyFrom(requestor.getBytes())) + .setExpiration(expiry.toEpochSecond(ZoneOffset.UTC) * 1_000_000_000) + .build() + .toByteArray(); + byte[] signDomain = "libp2p-relay-rsvp".getBytes(StandardCharsets.UTF_8); + ByteArrayOutputStream toSign = new ByteArrayOutputStream(); + try { + putUvarint(toSign, signDomain.length); + toSign.write(signDomain); + putUvarint(toSign, typeMulticodec.length); + toSign.write(typeMulticodec); + putUvarint(toSign, payload.length); + toSign.write(payload); + } catch (IOException e) { + } + byte[] signature = priv.sign(toSign.toByteArray()); + return EnvelopeOuterClass.Envelope.newBuilder() + .setPayloadType(ByteString.copyFrom(typeMulticodec)) + .setPayload(ByteString.copyFrom(payload)) + .setPublicKey( + EnvelopeOuterClass.PublicKey.newBuilder() + .setTypeValue(priv.publicKey().getKeyType().getNumber()) + .setData(ByteString.copyFrom(priv.publicKey().raw()))) + .setSignature(ByteString.copyFrom(signature)) + .build() + .toByteArray(); + } + + public static class Reservation { + public final LocalDateTime expiry; + public final int durationSeconds; + public final long maxBytes; + public final byte[] voucher; + public final Multiaddr[] addrs; + + public Reservation( + LocalDateTime expiry, + int durationSeconds, + long maxBytes, + byte[] voucher, + Multiaddr[] addrs) { + this.expiry = expiry; + this.durationSeconds = durationSeconds; + this.maxBytes = maxBytes; + this.voucher = voucher; + this.addrs = addrs; + } + } + + public interface RelayManager { + boolean hasReservation(PeerId source); + + Optional createReservation(PeerId requestor, Multiaddr addr); + + Optional allowConnection(PeerId target, PeerId initiator); + + static RelayManager limitTo(PrivKey priv, PeerId relayPeerId, int concurrent) { + return new RelayManager() { + Map reservations = new HashMap<>(); + + @Override + public synchronized boolean hasReservation(PeerId source) { + return reservations.containsKey(source); + } + + @Override + public synchronized Optional createReservation( + PeerId requestor, Multiaddr addr) { + if (reservations.size() >= concurrent) return Optional.empty(); + LocalDateTime now = LocalDateTime.now(); + LocalDateTime expiry = now.plusHours(1); + byte[] voucher = createVoucher(priv, relayPeerId, requestor, now); + Reservation resv = new Reservation(expiry, 120, 4096, voucher, new Multiaddr[] {addr}); + reservations.put(requestor, resv); + return Optional.of(resv); + } + + @Override + public synchronized Optional allowConnection(PeerId target, PeerId initiator) { + return Optional.ofNullable(reservations.get(target)); + } + }; + } + } + + public interface HopController { + CompletableFuture rpc(Circuit.HopMessage req); + + default CompletableFuture reserve() { + return rpc(Circuit.HopMessage.newBuilder().setType(Circuit.HopMessage.Type.RESERVE).build()) + .thenApply( + msg -> { + if (msg.getStatus() == Circuit.Status.OK) { + long expiry = msg.getReservation().getExpire(); + return new Reservation( + LocalDateTime.ofEpochSecond(expiry, 0, ZoneOffset.UTC), + msg.getLimit().getDuration(), + msg.getLimit().getData(), + msg.getReservation().getVoucher().toByteArray(), + null); + } + throw new IllegalStateException(msg.getStatus().name()); + }); + } + + CompletableFuture connect(PeerId target); + } + + public static class HopRemover extends ChannelInitializer { + + @Override + protected void initChannel(@NotNull Channel ch) throws Exception { + ch.pipeline().remove(HOP_HANDLER_NAME); + // also remove associated protobuf handlers + ch.pipeline().remove(ProtobufDecoder.class); + ch.pipeline().remove(ProtobufEncoder.class); + ch.pipeline().remove(ProtobufVarint32FrameDecoder.class); + ch.pipeline().remove(ProtobufVarint32LengthFieldPrepender.class); + ch.pipeline().remove(STREAM_CLEARER_NAME); + } + } + + public static class Sender implements ProtocolMessageHandler, HopController { + private final Stream stream; + private final LinkedBlockingDeque> queue = + new LinkedBlockingDeque<>(); + + public Sender(Stream stream) { + this.stream = stream; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.HopMessage msg) { + queue.poll().complete(msg); + } + + public CompletableFuture rpc(Circuit.HopMessage req) { + CompletableFuture res = new CompletableFuture<>(); + queue.add(res); + stream.writeAndFlush(req); + return res; + } + + @Override + public CompletableFuture connect(PeerId target) { + return rpc(Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.CONNECT) + .setPeer(Circuit.Peer.newBuilder().setId(ByteString.copyFrom(target.getBytes()))) + .build()) + .thenApply( + msg -> { + if (msg.getType() == Circuit.HopMessage.Type.STATUS + && msg.getStatus() == Circuit.Status.OK) { + // remove handler for HOP to return bare stream + stream.pushHandler(STREAM_CLEARER_NAME, new HopRemover()); + return stream; + } + throw new IllegalStateException("Circuit dial returned " + msg.getStatus().name()); + }); + } + } + + public static class Receiver + implements ProtocolMessageHandler, HopController { + private final Host us; + private final RelayManager manager; + private final Supplier> publicAddresses; + private final CircuitStopProtocol.Binding stop; + private final AddressBook addressBook; + + public Receiver( + Host us, + RelayManager manager, + Supplier> publicAddresses, + CircuitStopProtocol.Binding stop, + AddressBook addressBook) { + this.us = us; + this.manager = manager; + this.publicAddresses = publicAddresses; + this.stop = stop; + this.addressBook = addressBook; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.HopMessage msg) { + switch (msg.getType()) { + case RESERVE: + { + PeerId requestor = stream.remotePeerId(); + Optional reservation = + manager.createReservation(requestor, stream.getConnection().remoteAddress()); + if (reservation.isEmpty() + || new Multiaddr(stream.getConnection().remoteAddress().toString()) + .has(Protocol.P2PCIRCUIT)) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.RESERVATION_REFUSED)); + return; + } + Reservation resv = reservation.get(); + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK) + .setReservation( + Circuit.Reservation.newBuilder() + .setExpire(resv.expiry.toEpochSecond(ZoneOffset.UTC)) + .addAllAddrs( + publicAddresses.get().stream() + .map(a -> ByteString.copyFrom(a.serialize())) + .collect(Collectors.toList())) + .setVoucher(ByteString.copyFrom(resv.voucher))) + .setLimit( + Circuit.Limit.newBuilder() + .setDuration(resv.durationSeconds) + .setData(resv.maxBytes))); + } + case CONNECT: + { + PeerId target = new PeerId(msg.getPeer().getId().toByteArray()); + if (manager.hasReservation(target)) { + PeerId initiator = stream.remotePeerId(); + Optional res = manager.allowConnection(target, initiator); + if (res.isPresent()) { + Reservation resv = res.get(); + try { + CircuitStopProtocol.StopController stop = + this.stop + .dial(us, target, resv.addrs) + .getController() + .orTimeout(15, TimeUnit.SECONDS) + .join(); + Circuit.StopMessage reply = + stop.connect(initiator, resv.durationSeconds, resv.maxBytes).join(); + if (reply.getStatus().equals(Circuit.Status.OK)) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK)); + Stream toTarget = stop.getStream(); + Stream fromRequestor = stream; + // remove hop and stop handlers from streams before proxying + fromRequestor.pushHandler(STREAM_CLEARER_NAME, new HopRemover()); + toTarget.pushHandler( + CircuitStopProtocol.STOP_REMOVER_NAME, + new CircuitStopProtocol.StopRemover()); + + // connect these streams with time + bytes enforcement + fromRequestor.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + fromRequestor.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + toTarget.pushHandler(new InboundTrafficLimitHandler(resv.maxBytes)); + toTarget.pushHandler( + new TotalTimeoutHandler( + Duration.of(resv.durationSeconds, ChronoUnit.SECONDS))); + fromRequestor.pushHandler(new ProxyHandler(toTarget)); + toTarget.pushHandler(new ProxyHandler(fromRequestor)); + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(reply.getStatus())); + } + } catch (Exception e) { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.CONNECTION_FAILED)); + } + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.RESOURCE_LIMIT_EXCEEDED)); + } + } else { + stream.writeAndFlush( + Circuit.HopMessage.newBuilder() + .setType(Circuit.HopMessage.Type.STATUS) + .setStatus(Circuit.Status.NO_RESERVATION)); + } + } + } + } + + @Override + public CompletableFuture connect(PeerId target) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send from a receiver!")); + } + + public CompletableFuture rpc(Circuit.HopMessage msg) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send from a receiver!")); + } + } + + private static class ProxyHandler extends ChannelInboundHandlerAdapter { + + private final Stream target; + + public ProxyHandler(Stream target) { + this.target = target; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + target.writeAndFlush(msg); + } + } + } + + private static final int TRAFFIC_LIMIT = 2 * 1024; + private final RelayManager manager; + private final CircuitStopProtocol.Binding stop; + private Host us; + + public CircuitHopProtocol(RelayManager manager, CircuitStopProtocol.Binding stop) { + super(Circuit.HopMessage.getDefaultInstance(), TRAFFIC_LIMIT, TRAFFIC_LIMIT); + this.manager = manager; + this.stop = stop; + } + + public void setHost(Host us) { + this.us = us; + } + + @NotNull + @Override + protected CompletableFuture onStartInitiator(@NotNull Stream stream) { + Sender replyPropagator = new Sender(stream); + stream.pushHandler( + HOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, replyPropagator)); + return CompletableFuture.completedFuture(replyPropagator); + } + + @NotNull + @Override + protected CompletableFuture onStartResponder(@NotNull Stream stream) { + if (us == null) throw new IllegalStateException("null Host for us!"); + Supplier> ourpublicAddresses = () -> us.listenAddresses(); + Receiver dialer = new Receiver(us, manager, ourpublicAddresses, stop, us.getAddressBook()); + stream.pushHandler(HOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, dialer)); + return CompletableFuture.completedFuture(dialer); + } +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java new file mode 100644 index 000000000..b10ee62d4 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/CircuitStopProtocol.java @@ -0,0 +1,157 @@ +package io.libp2p.protocol.circuit; + +import com.google.protobuf.*; +import io.libp2p.core.*; +import io.libp2p.core.multistream.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.pb.*; +import io.netty.channel.*; +import io.netty.handler.codec.protobuf.*; +import java.util.concurrent.*; +import org.jetbrains.annotations.*; + +public class CircuitStopProtocol + extends ProtobufProtocolHandler { + + private static final String STOP_HANDLER_NAME = "STOP_HANDLER"; + public static final String STOP_REMOVER_NAME = "STOP_REMOVER"; + + public static class Binding extends StrictProtocolBinding { + private final CircuitStopProtocol stop; + + public Binding(CircuitStopProtocol stop) { + super("/libp2p/circuit/relay/0.2.0/stop", stop); + this.stop = stop; + } + + public void setTransport(RelayTransport transport) { + stop.setTransport(transport); + } + } + + public interface StopController { + CompletableFuture rpc(Circuit.StopMessage req); + + Stream getStream(); + + default CompletableFuture connect( + PeerId source, int durationSeconds, long maxBytes) { + return rpc( + Circuit.StopMessage.newBuilder() + .setType(Circuit.StopMessage.Type.CONNECT) + .setPeer(Circuit.Peer.newBuilder().setId(ByteString.copyFrom(source.getBytes()))) + .setLimit(Circuit.Limit.newBuilder().setData(maxBytes).setDuration(durationSeconds)) + .build()); + } + } + + public static class Sender + implements ProtocolMessageHandler, StopController { + private final Stream stream; + private final LinkedBlockingDeque> queue = + new LinkedBlockingDeque<>(); + + public Sender(Stream stream) { + this.stream = stream; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.StopMessage msg) { + queue.poll().complete(msg); + } + + public CompletableFuture rpc(Circuit.StopMessage req) { + CompletableFuture res = new CompletableFuture<>(); + queue.add(res); + stream.writeAndFlush(req); + return res; + } + + public Stream getStream() { + return stream; + } + } + + public static class StopRemover extends ChannelInitializer { + + @Override + protected void initChannel(@NotNull Channel ch) throws Exception { + ch.pipeline().remove(ProtobufDecoder.class); + ch.pipeline().remove(ProtobufEncoder.class); + ch.pipeline().remove(ProtobufVarint32FrameDecoder.class); + ch.pipeline().remove(ProtobufVarint32LengthFieldPrepender.class); + ch.pipeline().remove(STOP_HANDLER_NAME); + ch.pipeline().remove(STOP_REMOVER_NAME); + } + } + + public static class Receiver + implements ProtocolMessageHandler, StopController { + private final Stream stream; + private final RelayTransport transport; + + public Receiver(Stream stream, RelayTransport transport) { + this.stream = stream; + this.transport = transport; + } + + @Override + public void onMessage(@NotNull Stream stream, Circuit.StopMessage msg) { + if (msg.getType() == Circuit.StopMessage.Type.CONNECT) { + PeerId remote = new PeerId(msg.getPeer().getId().toByteArray()); + int durationSeconds = msg.getLimit().getDuration(); + long limitBytes = msg.getLimit().getData(); + stream.writeAndFlush( + Circuit.StopMessage.newBuilder() + .setType(Circuit.StopMessage.Type.STATUS) + .setStatus(Circuit.Status.OK) + .build()); + // remove STOP handler from stream before upgrading + stream.pushHandler(STOP_REMOVER_NAME, new StopRemover()); + + // now upgrade connection with security and muxer protocol + ConnectionHandler connHandler = null; // TODO + RelayTransport.upgradeStream( + stream, false, transport.upgrader, transport, remote, connHandler); + } + } + + public Stream getStream() { + return stream; + } + + public CompletableFuture rpc(Circuit.StopMessage msg) { + return CompletableFuture.failedFuture( + new IllegalStateException("Cannot send form a receiver!")); + } + } + + private static final int TRAFFIC_LIMIT = 2 * 1024; + + private RelayTransport transport; + + public CircuitStopProtocol() { + super(Circuit.StopMessage.getDefaultInstance(), TRAFFIC_LIMIT, TRAFFIC_LIMIT); + } + + public void setTransport(RelayTransport transport) { + this.transport = transport; + } + + @NotNull + @Override + protected CompletableFuture onStartInitiator(@NotNull Stream stream) { + Sender replyPropagator = new Sender(stream); + stream.pushHandler( + STOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, replyPropagator)); + return CompletableFuture.completedFuture(replyPropagator); + } + + @NotNull + @Override + protected CompletableFuture onStartResponder(@NotNull Stream stream) { + Receiver acceptor = new Receiver(stream, transport); + stream.pushHandler(STOP_HANDLER_NAME, new ProtocolMessageHandlerAdapter<>(stream, acceptor)); + return CompletableFuture.completedFuture(acceptor); + } +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java new file mode 100644 index 000000000..c3848e699 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/HostConsumer.java @@ -0,0 +1,8 @@ +package io.libp2p.protocol.circuit; + +import io.libp2p.core.*; + +public interface HostConsumer { + + void setHost(Host us); +} diff --git a/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java b/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java new file mode 100644 index 000000000..a6d44e1f8 --- /dev/null +++ b/libp2p/src/main/java/io/libp2p/protocol/circuit/RelayTransport.java @@ -0,0 +1,326 @@ +package io.libp2p.protocol.circuit; + +import io.libp2p.core.*; +import io.libp2p.core.Stream; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.mux.*; +import io.libp2p.core.security.*; +import io.libp2p.core.transport.*; +import io.libp2p.etc.*; +import io.libp2p.transport.*; +import io.netty.channel.*; +import java.time.*; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.*; +import java.util.function.Function; +import java.util.stream.*; +import kotlin.*; +import org.jetbrains.annotations.*; + +public class RelayTransport implements Transport, HostConsumer { + private Host us; + private final Map listeners = new ConcurrentHashMap<>(); + private final Map dials = new ConcurrentHashMap<>(); + private final Function> candidateRelays; + private final CircuitHopProtocol.Binding hop; + private final CircuitStopProtocol.Binding stop; + public final ConnectionUpgrader upgrader; + private final AtomicInteger relayCount; + private final ScheduledExecutorService runner; + + public RelayTransport( + CircuitHopProtocol.Binding hop, + CircuitStopProtocol.Binding stop, + ConnectionUpgrader upgrader, + Function> candidateRelays, + ScheduledExecutorService runner) { + this.hop = hop; + this.stop = stop; + this.upgrader = upgrader; + this.candidateRelays = candidateRelays; + this.relayCount = new AtomicInteger(0); + this.runner = runner; + } + + @Override + public void setHost(Host us) { + this.us = us; + hop.setHost(us); + } + + public static class CandidateRelay { + public final PeerId id; + public final List addrs; + + public CandidateRelay(PeerId id, List addrs) { + this.id = id; + this.addrs = addrs; + } + } + + private static class RelayState { + List addrs; + CircuitHopProtocol.HopController controller; + Connection conn; + LocalDateTime renewAfter; + } + + public void setRelayCount(int count) { + relayCount.set(count); + } + + @Override + public int getActiveConnections() { + return dials.size(); + } + + @Override + public int getActiveListeners() { + return listeners.size(); + } + + @NotNull + @Override + public CompletableFuture close() { + return CompletableFuture.allOf( + dials.values().stream().map(Stream::close).toArray(CompletableFuture[]::new)) + .thenApply( + x -> { + dials.clear(); + return null; + }); + } + + static class ConnectionOverStream implements Connection { + private final boolean isInitiator; + private final Transport transport; + private final Stream stream; + private SecureChannel.Session security; + private StreamMuxer.Session muxer; + + public ConnectionOverStream(boolean isInitiator, Transport transport, Stream stream) { + this.isInitiator = isInitiator; + this.transport = transport; + this.stream = stream; + } + + @NotNull + @Override + public Multiaddr localAddress() { + return stream.getConnection().localAddress().withComponent(Protocol.P2PCIRCUIT); + } + + @NotNull + @Override + public Multiaddr remoteAddress() { + return stream.getConnection().remoteAddress().withComponent(Protocol.P2PCIRCUIT); + } + + public void setSecureSession(SecureChannel.Session sec) { + this.security = sec; + } + + @NotNull + @Override + public SecureChannel.Session secureSession() { + return security; + } + + public void setMuxerSession(StreamMuxer.Session mux) { + this.muxer = mux; + } + + @NotNull + @Override + public StreamMuxer.Session muxerSession() { + return muxer; + } + + @NotNull + @Override + public Transport transport() { + return transport; + } + + @Override + public boolean isInitiator() { + return isInitiator; + } + + @Override + public void addHandlerBefore( + @NotNull String s, @NotNull String s1, @NotNull ChannelHandler channelHandler) { + stream.addHandlerBefore(s, s1, channelHandler); + } + + @NotNull + @Override + public CompletableFuture close() { + return stream.close(); + } + + @NotNull + @Override + public CompletableFuture closeFuture() { + return stream.closeFuture(); + } + + @Override + public void pushHandler(@NotNull ChannelHandler channelHandler) { + stream.pushHandler(channelHandler); + } + + @Override + public void pushHandler(@NotNull String s, @NotNull ChannelHandler channelHandler) { + stream.pushHandler(s, channelHandler); + } + } + + @NotNull + @Override + public CompletableFuture dial( + @NotNull Multiaddr multiaddr, + @NotNull ConnectionHandler connHandler, + @Nullable ChannelVisitor channelVisitor) { + // first connect to relay over hop + List comps = multiaddr.getComponents(); + int split = comps.indexOf(new MultiaddrComponent(Protocol.P2PCIRCUIT, null)); + Multiaddr relay = new Multiaddr(comps.subList(0, split)); + Multiaddr target = new Multiaddr(comps.subList(split, comps.size())); + CircuitHopProtocol.HopController ctr = hop.dial(us, relay).getController().join(); + // request proxy to target + Stream stream = ctr.connect(target.getPeerId()).join(); + // upgrade with sec and muxer + return upgradeStream(stream, true, upgrader, this, target.getPeerId(), connHandler); + } + + public static CompletableFuture upgradeStream( + Stream stream, + boolean isInitiator, + ConnectionUpgrader upgrader, + Transport transport, + PeerId remote, + ConnectionHandler connHandler) { + ConnectionOverStream conn = new ConnectionOverStream(isInitiator, transport, stream); + CompletableFuture res = new CompletableFuture<>(); + stream.pushHandler( + new ChannelInitializer<>() { + @Override + protected void initChannel(Channel channel) throws Exception { + channel.attr(AttributesKt.getREMOTE_PEER_ID()).set(remote); + channel.attr(AttributesKt.getCONNECTION()).set(conn); + upgrader + .establishSecureChannel(conn) + .thenCompose( + sess -> { + conn.setSecureSession(sess); + if (sess.getEarlyMuxer() != null) { + return ConnectionUpgrader.Companion.establishMuxer( + sess.getEarlyMuxer(), conn); + } else { + return upgrader.establishMuxer(conn); + } + }) + .thenAccept( + sess -> { + conn.setMuxerSession(sess); + connHandler.handleConnection(conn); + res.complete(conn); + }) + .exceptionally( + t -> { + res.completeExceptionally(t); + return null; + }); + channel.pipeline().fireChannelActive(); + } + }); + return res; + } + + @Override + public boolean handles(@NotNull Multiaddr multiaddr) { + return multiaddr.hasAny(Protocol.P2PCIRCUIT); + } + + @Override + public void initialize() { + stop.setTransport(this); + // find relays and connect and reserve + runner.scheduleAtFixedRate(this::ensureEnoughCurrentRelays, 0, 2 * 60, TimeUnit.SECONDS); + } + + public void ensureEnoughCurrentRelays() { + int active = 0; + // renew existing relays before finding new ones + Set> currentRelays = listeners.entrySet(); + for (Map.Entry current : currentRelays) { + RelayState relay = current.getValue(); + LocalDateTime now = LocalDateTime.now(); + if (now.isBefore(relay.renewAfter)) { + active++; + } else { + try { + CircuitHopProtocol.Reservation reservation = relay.controller.reserve().join(); + relay.renewAfter = reservation.expiry.minusMinutes(1); + active++; + } catch (Exception e) { + listeners.remove(current.getKey()); + } + } + } + if (active >= relayCount.get()) return; + + List candidates = candidateRelays.apply(us); + for (CandidateRelay candidate : candidates) { + // connect to relay and get reservation + CircuitHopProtocol.HopController ctr = + hop.dial(us, candidate.id, candidate.addrs.toArray(new Multiaddr[0])) + .getController() + .join(); + CircuitHopProtocol.Reservation resv = ctr.reserve().join(); + active++; + listeners.put(candidate.id, new RelayState()); + if (active >= relayCount.get()) return; + } + } + + @NotNull + @Override + public CompletableFuture listen( + @NotNull Multiaddr relayAddr, + @NotNull ConnectionHandler connectionHandler, + @Nullable ChannelVisitor channelVisitor) { + List components = relayAddr.getComponents(); + Multiaddr withoutCircuit = new Multiaddr(components.subList(0, components.size() - 1)); + CircuitHopProtocol.HopController ctr = hop.dial(us, withoutCircuit).getController().join(); + return ctr.reserve().thenApply(res -> null); + } + + @NotNull + @Override + public List listenAddresses() { + return listeners.entrySet().stream() + .flatMap( + r -> + r.getValue().addrs.stream() + .map( + a -> + a.withP2P(r.getKey()) + .concatenated( + new Multiaddr( + List.of( + new MultiaddrComponent(Protocol.P2PCIRCUIT, null))) + .withP2P(us.getPeerId())))) + .collect(Collectors.toList()); + } + + @NotNull + @Override + public CompletableFuture unlisten(@NotNull Multiaddr multiaddr) { + RelayState relayState = listeners.get(multiaddr); + if (relayState == null) return CompletableFuture.completedFuture(null); + return relayState.conn.close(); + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt index d918c1bfb..8bc56d621 100644 --- a/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt +++ b/libp2p/src/main/kotlin/io/libp2p/core/multiformats/Multiaddr.kt @@ -61,7 +61,11 @@ data class Multiaddr(val components: List) { * @throws IllegalArgumentException if existing component value doesn't match [value] */ private fun withComponentImpl(protocol: Protocol, value: ByteArray?): Multiaddr { - val existingComponent = getFirstComponent(protocol) + val existingComponent = if (has(Protocol.P2PCIRCUIT)) { + split { it == Protocol.P2PCIRCUIT }.get(1).getFirstComponent(protocol) + } else { + getFirstComponent(protocol) + } val newComponent = MultiaddrComponent(protocol, value) return if (existingComponent != null) { if (!existingComponent.value.contentEquals(value)) { diff --git a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt index 92dcb6e42..a081ff67d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt +++ b/libp2p/src/main/kotlin/io/libp2p/transport/tcp/TcpTransport.kt @@ -5,6 +5,7 @@ import io.libp2p.core.multiformats.Multiaddr import io.libp2p.core.multiformats.Protocol.DNSADDR import io.libp2p.core.multiformats.Protocol.IP4 import io.libp2p.core.multiformats.Protocol.IP6 +import io.libp2p.core.multiformats.Protocol.P2PCIRCUIT import io.libp2p.core.multiformats.Protocol.TCP import io.libp2p.core.multiformats.Protocol.WS import io.libp2p.transport.ConnectionUpgrader @@ -29,7 +30,8 @@ open class TcpTransport( handlesHost(addr) && addr.has(TCP) && !addr.has(WS) && - !addr.has(DNSADDR) + !addr.has(DNSADDR) && + !addr.has(P2PCIRCUIT) override fun serverTransportBuilder( connectionBuilder: ConnectionBuilder, diff --git a/libp2p/src/main/proto/circuit.proto b/libp2p/src/main/proto/circuit.proto new file mode 100644 index 000000000..efc8d425a --- /dev/null +++ b/libp2p/src/main/proto/circuit.proto @@ -0,0 +1,60 @@ +syntax = "proto2"; + +package io.libp2p.protocol.circuit.pb; + +message HopMessage { + enum Type { + RESERVE = 0; + CONNECT = 1; + STATUS = 2; + } + + required Type type = 1; + + optional Peer peer = 2; + optional Reservation reservation = 3; + optional Limit limit = 4; + + optional Status status = 5; +} + +message StopMessage { + enum Type { + CONNECT = 0; + STATUS = 1; + } + + required Type type = 1; + + optional Peer peer = 2; + optional Limit limit = 3; + + optional Status status = 4; +} + +message Peer { + required bytes id = 1; + repeated bytes addrs = 2; +} + +message Reservation { + required uint64 expire = 1; // Unix expiration time (UTC) + repeated bytes addrs = 2; // relay addrs for reserving peer + optional bytes voucher = 3; // reservation voucher +} + +message Limit { + optional uint32 duration = 1; // seconds + optional uint64 data = 2; // bytes +} + +enum Status { + OK = 100; + RESERVATION_REFUSED = 200; + RESOURCE_LIMIT_EXCEEDED = 201; + PERMISSION_DENIED = 202; + CONNECTION_FAILED = 203; + NO_RESERVATION = 204; + MALFORMED_MESSAGE = 400; + UNEXPECTED_MESSAGE = 401; +} diff --git a/libp2p/src/main/proto/envelope.proto b/libp2p/src/main/proto/envelope.proto new file mode 100644 index 000000000..d303fd553 --- /dev/null +++ b/libp2p/src/main/proto/envelope.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package io.libp2p.protocol.circuit.crypto.pb; + +enum KeyType { + RSA = 0; + Ed25519 = 1; + Secp256k1 = 2; + ECDSA = 3; + Curve25519 = 4; +} + +message PublicKey { + KeyType Type = 1; + bytes Data = 2; +} + +message PrivateKey { + KeyType Type = 1; + bytes Data = 2; +} + +// Envelope encloses a signed payload produced by a peer, along with the public +// key of the keypair it was signed with so that it can be statelessly validated +// by the receiver. +// +// The payload is prefixed with a byte string that determines the type, so it +// can be deserialized deterministically. Often, this byte string is a +// multicodec. +message Envelope { + // public_key is the public key of the keypair the enclosed payload was + // signed with. + PublicKey public_key = 1; + + // payload_type encodes the type of payload, so that it can be deserialized + // deterministically. + bytes payload_type = 2; + + // payload is the actual payload carried inside this envelope. + bytes payload = 3; + + // signature is the signature produced by the private key corresponding to + // the enclosed public key, over the payload, prefixing a domain string for + // additional security. + bytes signature = 5; +} diff --git a/libp2p/src/main/proto/voucher.proto b/libp2p/src/main/proto/voucher.proto new file mode 100644 index 000000000..5b2dea19e --- /dev/null +++ b/libp2p/src/main/proto/voucher.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package io.libp2p.protocol.circuit.pb; + +message Voucher { + required bytes relay = 1; + required bytes peer = 2; + required uint64 expiration = 3; +} \ No newline at end of file diff --git a/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java b/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java new file mode 100644 index 000000000..03ce8c28a --- /dev/null +++ b/libp2p/src/test/java/io/libp2p/core/RelayTestJava.java @@ -0,0 +1,225 @@ +package io.libp2p.core; + +import io.libp2p.core.crypto.*; +import io.libp2p.core.dsl.*; +import io.libp2p.core.multiformats.*; +import io.libp2p.core.mux.*; +import io.libp2p.protocol.*; +import io.libp2p.protocol.circuit.*; +import io.libp2p.security.noise.*; +import io.libp2p.transport.tcp.*; +import java.util.*; +import java.util.concurrent.*; +import org.junit.jupiter.api.*; + +public class RelayTestJava { + + private static void enableRelay(BuilderJ b, List relays) { + PrivKey priv = b.getIdentity().random().getFactory().invoke(); + b.getIdentity().setFactory(() -> priv); + PeerId us = PeerId.fromPubKey(priv.publicKey()); + CircuitHopProtocol.RelayManager relayManager = + CircuitHopProtocol.RelayManager.limitTo(priv, us, 5); + CircuitStopProtocol.Binding stop = new CircuitStopProtocol.Binding(new CircuitStopProtocol()); + CircuitHopProtocol.Binding hop = new CircuitHopProtocol.Binding(relayManager, stop); + b.getProtocols().add(hop); + b.getProtocols().add(stop); + b.getTransports() + .add( + u -> new RelayTransport(hop, stop, u, h -> relays, new ScheduledThreadPoolExecutor(1))); + } + + @Test + void pingOverLocalRelay() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host relayHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, Collections.emptyList())) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .listen("/ip4/127.0.0.1/tcp/0") + .protocol(new Ping()) + .build(); + relayHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(relayHost); + CompletableFuture relayStarted = relayHost.start(); + relayStarted.get(5, TimeUnit.SECONDS); + + List relayAddrs = relayHost.listenAddresses(); + Multiaddr relayAddr = relayAddrs.get(0); + RelayTransport.CandidateRelay relay = + new RelayTransport.CandidateRelay(relayHost.getPeerId(), relayAddrs); + List relays = List.of(relay); + + Host clientHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .build(); + clientHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(clientHost); + + Host serverHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Ping()) + .listen(localListenAddress) + .listen(relayAddr + "/p2p-circuit") + .build(); + serverHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(serverHost); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Multiaddr toDial = + relayAddr.concatenated( + new Multiaddr("/p2p-circuit/p2p/" + serverHost.getPeerId().toBase58())); + System.out.println("Dialling " + toDial + " from " + clientHost.getPeerId()); + StreamPromise ping = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), toDial) + .thenApply(it -> it.muxerSession().createStream(new Ping())) + .get(5, TimeUnit.SECONDS); + + Stream pingStream = ping.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream created"); + PingController pingCtr = ping.getController().get(5, TimeUnit.SECONDS); + System.out.println("Ping controller created"); + + for (int i = 0; i < 10; i++) { + long latency = pingCtr.ping().get(1, TimeUnit.SECONDS); + System.out.println("Ping is " + latency); + } + pingStream.close().get(5, TimeUnit.SECONDS); + System.out.println("Ping stream closed"); + + Assertions.assertThrows( + ExecutionException.class, () -> pingCtr.ping().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } + + @Test + void relayStreamsAreLimited() throws Exception { + String localListenAddress = "/ip4/127.0.0.1/tcp/40002"; + + Host relayHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, Collections.emptyList())) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .listen("/ip4/127.0.0.1/tcp/0") + .build(); + relayHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(relayHost); + CompletableFuture relayStarted = relayHost.start(); + relayStarted.get(5, TimeUnit.SECONDS); + + List relayAddrs = relayHost.listenAddresses(); + Multiaddr relayAddr = relayAddrs.get(0); + RelayTransport.CandidateRelay relay = + new RelayTransport.CandidateRelay(relayHost.getPeerId(), relayAddrs); + List relays = List.of(relay); + + // Relay streams are limited to 4096 bytes in either direction + // This is the smallest value that triggers the limit + // not sure why there is so much overhead from 3 * multistream + noise + yamux! + int blobSize = 1469; + Host clientHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .build(); + clientHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(clientHost); + + Host serverHost = + new HostBuilder() + .builderModifier(b -> enableRelay(b, relays)) + .transport(TcpTransport::new) + .secureChannel(NoiseXXSecureChannel::new) + .muxer(StreamMuxerProtocol::getYamux) + .protocol(new Blob(blobSize)) + .listen(localListenAddress) + .listen(relayAddr + "/p2p-circuit") + .build(); + serverHost.getNetwork().getTransports().stream() + .filter(t -> t instanceof RelayTransport) + .map(t -> (RelayTransport) t) + .findFirst() + .get() + .setHost(serverHost); + + CompletableFuture clientStarted = clientHost.start(); + CompletableFuture serverStarted = serverHost.start(); + clientStarted.get(5, TimeUnit.SECONDS); + System.out.println("Client started"); + serverStarted.get(5, TimeUnit.SECONDS); + System.out.println("Server started"); + + Multiaddr toDial = + relayAddr.concatenated( + new Multiaddr("/p2p-circuit/p2p/" + serverHost.getPeerId().toBase58())); + System.out.println("Dialling " + toDial + " from " + clientHost.getPeerId()); + StreamPromise blob = + clientHost + .getNetwork() + .connect(serverHost.getPeerId(), toDial) + .thenApply(it -> it.muxerSession().createStream(new Blob(blobSize))) + .get(5, TimeUnit.SECONDS); + + Stream blobStream = blob.getStream().get(5, TimeUnit.SECONDS); + System.out.println("Blob stream created"); + BlobController blobCtr = blob.getController().get(5, TimeUnit.SECONDS); + System.out.println("Blob controller created"); + + Assertions.assertThrows( + ExecutionException.class, () -> blobCtr.blob().get(5, TimeUnit.SECONDS)); + + clientHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Client stopped"); + serverHost.stop().get(5, TimeUnit.SECONDS); + System.out.println("Server stopped"); + } +} From 2e441b79793811fcc8b86e093cbcf8530d862864 Mon Sep 17 00:00:00 2001 From: Stefan Bratanov Date: Wed, 17 Jan 2024 14:22:43 +0200 Subject: [PATCH 13/13] 1.1.0 release --- build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle.kts b/build.gradle.kts index 17448f184..5e5757281 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -37,7 +37,7 @@ configure( } ) { group = "io.libp2p" - version = "1.0.1-RELEASE" + version = "1.1.0-RELEASE" apply(plugin = "kotlin") apply(plugin = "idea")