diff --git a/src/device/socket/connectionmanager.rs b/src/device/socket/connectionmanager.rs index 518b5547..0c9e4e81 100644 --- a/src/device/socket/connectionmanager.rs +++ b/src/device/socket/connectionmanager.rs @@ -1,6 +1,6 @@ use super::{ protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket, - VsockEvent, VsockEventType, + VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE, }; use crate::{transport::Transport, Hal, Result}; use alloc::{boxed::Box, vec::Vec}; @@ -10,12 +10,15 @@ use core::hint::spin_loop; use log::debug; use zerocopy::FromZeroes; -const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; +const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024; /// A higher level interface for VirtIO socket (vsock) devices. /// /// This keeps track of multiple vsock connections. /// +/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be +/// bigger than `size_of::()`. +/// /// # Example /// /// ``` @@ -40,8 +43,13 @@ const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; /// # Ok(()) /// # } /// ``` -pub struct VsockConnectionManager { - driver: VirtIOSocket, +pub struct VsockConnectionManager< + H: Hal, + T: Transport, + const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE, +> { + driver: VirtIOSocket, + per_connection_buffer_capacity: u32, connections: Vec, listening_ports: Vec, } @@ -56,24 +64,36 @@ struct Connection { } impl Connection { - fn new(peer: VsockAddr, local_port: u32) -> Self { + fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self { let mut info = ConnectionInfo::new(peer, local_port); - info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); + info.buf_alloc = buffer_capacity; Self { info, - buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()), peer_requested_shutdown: false, } } } -impl VsockConnectionManager { +impl + VsockConnectionManager +{ /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. - pub fn new(driver: VirtIOSocket) -> Self { + pub fn new(driver: VirtIOSocket) -> Self { + Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY) + } + + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with + /// the given per-connection buffer capacity. + pub fn new_with_capacity( + driver: VirtIOSocket, + per_connection_buffer_capacity: u32, + ) -> Self { Self { driver, connections: Vec::new(), listening_ports: Vec::new(), + per_connection_buffer_capacity, } } @@ -106,7 +126,8 @@ impl VsockConnectionManager { return Err(SocketError::ConnectionExists.into()); } - let new_connection = Connection::new(destination, src_port); + let new_connection = + Connection::new(destination, src_port, self.per_connection_buffer_capacity); self.driver.connect(&new_connection.info)?; debug!("Connection requested: {:?}", new_connection.info); @@ -125,6 +146,7 @@ impl VsockConnectionManager { pub fn poll(&mut self) -> Result> { let guest_cid = self.driver.guest_cid(); let connections = &mut self.connections; + let per_connection_buffer_capacity = self.per_connection_buffer_capacity; let result = self.driver.poll(|event, body| { let connection = get_connection_for_event(connections, &event, guest_cid); @@ -140,7 +162,11 @@ impl VsockConnectionManager { } // Add the new connection to our list, at least for now. It will be removed again // below if we weren't listening on the port. - connections.push(Connection::new(event.source, event.destination.port)); + connections.push(Connection::new( + event.source, + event.destination.port, + per_connection_buffer_capacity, + )); connections.last_mut().unwrap() } else { return Ok(None); diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs index 8d2de2b1..3b59d655 100644 --- a/src/device/socket/mod.rs +++ b/src/device/socket/mod.rs @@ -20,3 +20,7 @@ pub use error::SocketError; pub use protocol::{VsockAddr, VMADDR_CID_HOST}; #[cfg(feature = "alloc")] pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType}; + +/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than +/// `size_of::()`. +const DEFAULT_RX_BUFFER_SIZE: usize = 512; diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs index e3ba1107..372734ff 100644 --- a/src/device/socket/vsock.rs +++ b/src/device/socket/vsock.rs @@ -5,6 +5,7 @@ use super::error::SocketError; use super::protocol::{ Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr, }; +use super::DEFAULT_RX_BUFFER_SIZE; use crate::hal::Hal; use crate::queue::VirtQueue; use crate::transport::Transport; @@ -23,9 +24,6 @@ const EVENT_QUEUE_IDX: u16 = 2; pub(crate) const QUEUE_SIZE: usize = 8; const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX; -/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::(). -const RX_BUFFER_SIZE: usize = 512; - #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct ConnectionInfo { pub dst: VsockAddr, @@ -212,7 +210,11 @@ pub enum VsockEventType { /// /// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than /// using this directly. -pub struct VirtIOSocket { +/// +/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be +/// bigger than `size_of::()`. +pub struct VirtIOSocket +{ transport: T, /// Virtqueue to receive packets. rx: VirtQueue, @@ -237,7 +239,9 @@ unsafe impl Sync for VirtIOSocket where { } -impl Drop for VirtIOSocket { +impl Drop + for VirtIOSocket +{ fn drop(&mut self) { // Clear any pointers pointing to DMA regions, so the device doesn't try to access them // after they have been freed. @@ -253,9 +257,11 @@ impl Drop for VirtIOSocket { } } -impl VirtIOSocket { +impl VirtIOSocket { /// Create a new VirtIO Vsock driver. pub fn new(mut transport: T) -> Result { + assert!(RX_BUFFER_SIZE > size_of::()); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); let config = transport.config_space::()?;