From 04eaa5c3d01a8f3a599a3a1abf7205eed80df4a2 Mon Sep 17 00:00:00 2001 From: mat Date: Wed, 25 Dec 2024 06:16:10 +0000 Subject: [PATCH] remove dependency on bytes crate for azalea-protocol and fix memory leak --- Cargo.lock | 1 - Cargo.toml | 1 - azalea-client/src/raw_connection.rs | 12 ++--- azalea-protocol/Cargo.toml | 5 +- azalea-protocol/src/connect.rs | 17 +++--- azalea-protocol/src/lib.rs | 12 +++-- azalea-protocol/src/read.rs | 81 +++++++++++++++++------------ azalea-protocol/src/write.rs | 4 +- 8 files changed, 74 insertions(+), 59 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d42f4594f..6cb02f415 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,7 +484,6 @@ dependencies = [ "azalea-registry", "azalea-world", "bevy_ecs", - "bytes", "flate2", "futures", "futures-lite", diff --git a/Cargo.toml b/Cargo.toml index 33f6fcb18..24f832a57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,6 @@ bevy_log = "0.15.0" bevy_tasks = "0.15.0" bevy_time = "0.15.0" byteorder = "1.5.0" -bytes = "1.9.0" cfb8 = "0.8.1" chrono = { version = "0.4.39", default-features = false } criterion = "0.5.1" diff --git a/azalea-client/src/raw_connection.rs b/azalea-client/src/raw_connection.rs index 3eacf5289..2091c14e3 100644 --- a/azalea-client/src/raw_connection.rs +++ b/azalea-client/src/raw_connection.rs @@ -33,12 +33,12 @@ pub struct RawConnection { #[derive(Clone)] struct RawConnectionReader { - pub incoming_packet_queue: Arc>>>, + pub incoming_packet_queue: Arc>>>, pub run_schedule_sender: mpsc::UnboundedSender<()>, } #[derive(Clone)] struct RawConnectionWriter { - pub outgoing_packets_sender: mpsc::UnboundedSender>, + pub outgoing_packets_sender: mpsc::UnboundedSender>, } #[derive(Error, Debug)] @@ -54,7 +54,7 @@ pub enum WritePacketError { SendError { #[from] #[backtrace] - source: SendError>, + source: SendError>, }, } @@ -93,7 +93,7 @@ impl RawConnection { } } - pub fn write_raw_packet(&self, raw_packet: Vec) -> Result<(), WritePacketError> { + pub fn write_raw_packet(&self, raw_packet: Box<[u8]>) -> Result<(), WritePacketError> { self.writer.outgoing_packets_sender.send(raw_packet)?; Ok(()) } @@ -120,7 +120,7 @@ impl RawConnection { !self.read_packets_task.is_finished() } - pub fn incoming_packet_queue(&self) -> Arc>>> { + pub fn incoming_packet_queue(&self) -> Arc>>> { self.reader.incoming_packet_queue.clone() } @@ -161,7 +161,7 @@ impl RawConnectionWriter { pub async fn write_task( self, mut write_conn: RawWriteConnection, - mut outgoing_packets_receiver: mpsc::UnboundedReceiver>, + mut outgoing_packets_receiver: mpsc::UnboundedReceiver>, ) { while let Some(raw_packet) = outgoing_packets_receiver.recv().await { if let Err(err) = write_conn.write(&raw_packet).await { diff --git a/azalea-protocol/Cargo.toml b/azalea-protocol/Cargo.toml index 3d5d2ec19..202bd08ac 100644 --- a/azalea-protocol/Cargo.toml +++ b/azalea-protocol/Cargo.toml @@ -33,12 +33,11 @@ azalea-protocol-macros = { path = "./azalea-protocol-macros", version = "0.11.0" azalea-registry = { path = "../azalea-registry", version = "0.11.0" } azalea-world = { path = "../azalea-world", version = "0.11.0" } bevy_ecs = { workspace = true } -#byteorder = { workspace = true } -bytes = { workspace = true } +# byteorder = { workspace = true } flate2 = { workspace = true } futures = { workspace = true } futures-lite = { workspace = true } -#futures-util = { workspace = true } +# futures-util = { workspace = true } serde = { workspace = true, features = ["serde_derive"] } serde_json = { workspace = true } simdnbt = { workspace = true } diff --git a/azalea-protocol/src/connect.rs b/azalea-protocol/src/connect.rs index f33ce2a53..ef2023782 100755 --- a/azalea-protocol/src/connect.rs +++ b/azalea-protocol/src/connect.rs @@ -8,7 +8,6 @@ use std::net::SocketAddr; use azalea_auth::game_profile::GameProfile; use azalea_auth::sessionserver::{ClientSessionServerError, ServerSessionServerError}; use azalea_crypto::{Aes128CfbDec, Aes128CfbEnc}; -use bytes::BytesMut; use thiserror::Error; use tokio::io::{AsyncWriteExt, BufStream}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; @@ -28,7 +27,7 @@ use crate::write::{serialize_packet, write_raw_packet}; pub struct RawReadConnection { pub read_stream: OwnedReadHalf, - pub buffer: BytesMut, + pub buffer: Cursor>, pub compression_threshold: Option, pub dec_cipher: Option, } @@ -135,7 +134,7 @@ pub struct Connection { } impl RawReadConnection { - pub async fn read(&mut self) -> Result, Box> { + pub async fn read(&mut self) -> Result, Box> { read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -145,7 +144,7 @@ impl RawReadConnection { .await } - pub fn try_read(&mut self) -> Result>, Box> { + pub fn try_read(&mut self) -> Result>, Box> { try_read_raw_packet::<_>( &mut self.read_stream, &mut self.buffer, @@ -190,7 +189,7 @@ where /// Read a packet from the stream. pub async fn read(&mut self) -> Result> { let raw_packet = self.raw.read().await?; - deserialize_packet(&mut Cursor::new(raw_packet.as_slice())) + deserialize_packet(&mut Cursor::new(&raw_packet)) } /// Try to read a packet from the stream, or return Ok(None) if there's no @@ -199,9 +198,7 @@ where let Some(raw_packet) = self.raw.try_read()? else { return Ok(None); }; - Ok(Some(deserialize_packet(&mut Cursor::new( - raw_packet.as_slice(), - ))?)) + Ok(Some(deserialize_packet(&mut Cursor::new(&raw_packet))?)) } } impl WriteConnection @@ -304,7 +301,7 @@ impl Connection { reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, @@ -562,7 +559,7 @@ where reader: ReadConnection { raw: RawReadConnection { read_stream, - buffer: BytesMut::new(), + buffer: Cursor::new(Vec::new()), compression_threshold: None, dec_cipher: None, }, diff --git a/azalea-protocol/src/lib.rs b/azalea-protocol/src/lib.rs index 5e663c8f2..12243de61 100644 --- a/azalea-protocol/src/lib.rs +++ b/azalea-protocol/src/lib.rs @@ -9,7 +9,7 @@ //! //! See [`crate::connect::Connection`] for an example. -// these two are necessary for thiserror backtraces +// this is necessary for thiserror backtraces #![feature(error_generic_member_access)] use std::{fmt::Display, net::SocketAddr, str::FromStr}; @@ -111,7 +111,6 @@ impl serde::Serialize for ServerAddress { mod tests { use std::io::Cursor; - use bytes::BytesMut; use uuid::Uuid; use crate::{ @@ -135,11 +134,16 @@ mod tests { .await .unwrap(); + assert_eq!( + stream, + [22, 0, 4, 116, 101, 115, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + let mut stream = Cursor::new(stream); let _ = read_packet::( &mut stream, - &mut BytesMut::new(), + &mut Cursor::new(Vec::new()), None, &mut None, ) @@ -163,7 +167,7 @@ mod tests { .unwrap(); let mut stream = Cursor::new(stream); - let mut buffer = BytesMut::new(); + let mut buffer = Cursor::new(Vec::new()); let _ = read_packet::(&mut stream, &mut buffer, None, &mut None) .await diff --git a/azalea-protocol/src/read.rs b/azalea-protocol/src/read.rs index 8569ca734..6f9b754ab 100755 --- a/azalea-protocol/src/read.rs +++ b/azalea-protocol/src/read.rs @@ -9,13 +9,12 @@ use std::{ use azalea_buf::AzaleaReadVar; use azalea_buf::BufReadError; use azalea_crypto::Aes128CfbDec; -use bytes::Buf; -use bytes::BytesMut; use flate2::read::ZlibDecoder; use futures::StreamExt; use futures_lite::future; use thiserror::Error; use tokio::io::AsyncRead; +use tokio_util::bytes::Buf; use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::trace; @@ -79,12 +78,12 @@ pub enum FrameSplitterError { ConnectionClosed, } -/// Read a length, then read that amount of bytes from `BytesMut`. If there's -/// not enough data, return None -fn parse_frame(buffer: &mut BytesMut) -> Result { +/// Read a length, then read that amount of bytes from the `Cursor>`. If +/// there's not enough data, return None +fn parse_frame(buffer: &mut Cursor>) -> Result, FrameSplitterError> { // copy the buffer first and read from the copy, then once we make sure // the packet is all good we read it fully - let mut buffer_copy = Cursor::new(&buffer[..]); + let mut buffer_copy = Cursor::new(&buffer.get_ref()[buffer.position() as usize..]); // Packet Length let length = match u32::azalea_read_var(&mut buffer_copy) { Ok(length) => length as usize, @@ -106,18 +105,28 @@ fn parse_frame(buffer: &mut BytesMut) -> Result { // the length of the varint that says the length of the whole packet let varint_length = buffer.remaining() - buffer_copy.remaining(); + drop(buffer_copy); buffer.advance(varint_length); - let data = buffer.split_to(length); + let data = + buffer.get_ref()[buffer.position() as usize..buffer.position() as usize + length].to_vec(); + buffer.advance(length); + + if buffer.position() == buffer.get_ref().len() as u64 { + // reset the inner vec once we've reached the end of the buffer so we don't keep + // leaking memory + *buffer.get_mut() = Vec::new(); + buffer.set_position(0); + } - Ok(data) + Ok(data.into_boxed_slice()) } -fn frame_splitter(buffer: &mut BytesMut) -> Result>, FrameSplitterError> { +fn frame_splitter(buffer: &mut Cursor>) -> Result>, FrameSplitterError> { // https://tokio.rs/tokio/tutorial/framing let read_frame = parse_frame(buffer); match read_frame { - Ok(frame) => return Ok(Some(frame.to_vec())), + Ok(frame) => return Ok(Some(frame)), Err(err) => match err { FrameSplitterError::BadLength { .. } | FrameSplitterError::Io { .. } => { // we probably just haven't read enough yet @@ -141,7 +150,7 @@ pub fn deserialize_packet( // this is always true in multiplayer, false in singleplayer static VALIDATE_DECOMPRESSED: bool = true; -pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2097152; +pub static MAXIMUM_UNCOMPRESSED_LENGTH: u32 = 2_097_152; #[derive(Error, Debug)] pub enum DecompressionError { @@ -169,13 +178,15 @@ pub enum DecompressionError { pub fn compression_decoder( stream: &mut Cursor<&[u8]>, compression_threshold: u32, -) -> Result, DecompressionError> { +) -> Result, DecompressionError> { // Data Length let n = u32::azalea_read_var(stream)?; if n == 0 { // no data size, no compression - let mut buf = vec![]; - std::io::Read::read_to_end(stream, &mut buf)?; + let buf = stream.get_ref()[stream.position() as usize..] + .to_vec() + .into_boxed_slice(); + stream.set_position(stream.get_ref().len() as u64); return Ok(buf); } @@ -194,11 +205,14 @@ pub fn compression_decoder( } } - let mut decoded_buf = vec![]; + // VALIDATE_DECOMPRESSED should always be true, so the max they can make us + // allocate here is 2mb + let mut decoded_buf = Vec::with_capacity(n as usize); + let mut decoder = ZlibDecoder::new(stream); decoder.read_to_end(&mut decoded_buf)?; - Ok(decoded_buf) + Ok(decoded_buf.into_boxed_slice()) } /// Read a single packet from a stream. @@ -211,7 +225,7 @@ pub fn compression_decoder( /// For the non-waiting version, see [`try_read_packet`]. pub async fn read_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, ) -> Result> @@ -219,7 +233,7 @@ where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { let raw_packet = read_raw_packet(stream, buffer, compression_threshold, cipher).await?; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(packet) } @@ -227,7 +241,7 @@ where /// received a full packet yet. pub fn try_read_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, ) -> Result, Box> @@ -238,18 +252,18 @@ where else { return Ok(None); }; - let packet = deserialize_packet(&mut Cursor::new(raw_packet.as_slice()))?; + let packet = deserialize_packet(&mut Cursor::new(&raw_packet))?; Ok(Some(packet)) } pub async fn read_raw_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, // this has to be a &mut Option instead of an Option<&mut T> because // otherwise the borrow checker complains about the cipher being moved cipher: &mut Option, -) -> Result, Box> +) -> Result, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -260,15 +274,15 @@ where }; let bytes = read_and_decrypt_frame(stream, cipher).await?; - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } pub fn try_read_raw_packet( stream: &mut R, - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, cipher: &mut Option, -) -> Result>, Box> +) -> Result>, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { @@ -282,14 +296,14 @@ where return Ok(None); }; // we got some data, so add it to the buffer and try again - buffer.extend_from_slice(&bytes); + buffer.get_mut().extend_from_slice(&bytes); } } async fn read_and_decrypt_frame( stream: &mut R, cipher: &mut Option, -) -> Result> +) -> Result, Box> where R: AsyncRead + Unpin + Send + Sync, { @@ -298,7 +312,9 @@ where let Some(message) = framed.next().await else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -310,7 +326,7 @@ where fn try_read_and_decrypt_frame( stream: &mut R, cipher: &mut Option, -) -> Result, Box> +) -> Result>, Box> where R: AsyncRead + Unpin + Send + Sync, { @@ -323,7 +339,8 @@ where let Some(message) = message else { return Err(Box::new(ReadPacketError::ConnectionClosed)); }; - let mut bytes = message.map_err(ReadPacketError::from)?; + let bytes = message.map_err(ReadPacketError::from)?; + let mut bytes = bytes.to_vec().into_boxed_slice(); // decrypt if necessary if let Some(cipher) = cipher { @@ -334,9 +351,9 @@ where } pub fn read_raw_packet_from_buffer( - buffer: &mut BytesMut, + buffer: &mut Cursor>, compression_threshold: Option, -) -> Result>, Box> +) -> Result>, Box> where R: AsyncRead + std::marker::Unpin + std::marker::Send + std::marker::Sync, { diff --git a/azalea-protocol/src/write.rs b/azalea-protocol/src/write.rs index 512d08ad7..f1ffd82e8 100755 --- a/azalea-protocol/src/write.rs +++ b/azalea-protocol/src/write.rs @@ -31,7 +31,7 @@ where pub fn serialize_packet( packet: &P, -) -> Result, PacketEncodeError> { +) -> Result, PacketEncodeError> { let mut buf = Vec::new(); packet.id().azalea_write_var(&mut buf)?; packet.write(&mut buf)?; @@ -42,7 +42,7 @@ pub fn serialize_packet( packet_string: format!("{packet:?}"), }); } - Ok(buf) + Ok(buf.into_boxed_slice()) } pub async fn write_raw_packet(