diff --git a/src/stateless_transportstate.rs b/src/stateless_transportstate.rs index 7eeb0a8..904ab4a 100644 --- a/src/stateless_transportstate.rs +++ b/src/stateless_transportstate.rs @@ -40,7 +40,8 @@ impl StatelessTransportState { /// doesn't necessitate a remote static key, *or* if the remote /// static key is not yet known (as can be the case in the `XX` /// pattern, for example). - #[must_use] pub fn get_remote_static(&self) -> Option<&[u8]> { + #[must_use] + pub fn get_remote_static(&self) -> Option<&[u8]> { self.rs.get().map(|rs| &rs[..self.dh_len]) } @@ -74,6 +75,7 @@ impl StatelessTransportState { /// Returns the number of bytes written to `payload`. /// /// # Errors + /// Will result in `Error::Input` if the message is more than 65535 bytes. /// /// Will result in `Error::Decrypt` if the contents couldn't be decrypted and/or the /// authentication tag didn't verify. @@ -85,11 +87,14 @@ impl StatelessTransportState { payload: &[u8], message: &mut [u8], ) -> Result { - if self.initiator && self.pattern.is_oneway() { - return Err(StateProblem::OneWay.into()); + if payload.len() > MAXMSGLEN { + Err(Error::Input) + } else if self.initiator && self.pattern.is_oneway() { + Err(StateProblem::OneWay.into()) + } else { + let cipher = if self.initiator { &self.cipherstates.1 } else { &self.cipherstates.0 }; + cipher.decrypt(nonce, payload, message) } - let cipher = if self.initiator { &self.cipherstates.1 } else { &self.cipherstates.0 }; - cipher.decrypt(nonce, payload, message) } /// Generate a new key for the egress symmetric cipher according to Section 4.2 @@ -141,7 +146,8 @@ impl StatelessTransportState { } /// Check if this session was started with the "initiator" role. - #[must_use] pub fn is_initiator(&self) -> bool { + #[must_use] + pub fn is_initiator(&self) -> bool { self.initiator } } diff --git a/src/transportstate.rs b/src/transportstate.rs index b5ab403..b927b37 100644 --- a/src/transportstate.rs +++ b/src/transportstate.rs @@ -40,7 +40,8 @@ impl TransportState { /// doesn't necessitate a remote static key, *or* if the remote /// static key is not yet known (as can be the case in the `XX` /// pattern, for example). - #[must_use] pub fn get_remote_static(&self) -> Option<&[u8]> { + #[must_use] + pub fn get_remote_static(&self) -> Option<&[u8]> { self.rs.get().map(|rs| &rs[..self.dh_len]) } @@ -70,19 +71,22 @@ impl TransportState { /// Returns the number of bytes written to `payload`. /// /// # Errors + /// Will result in `Error::Input` if the message is more than 65535 bytes. /// /// Will result in `Error::Decrypt` if the contents couldn't be decrypted and/or the /// authentication tag didn't verify. /// /// Will result in `StateProblem::Exhausted` if the max nonce overflows. pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result { - if self.initiator && self.pattern.is_oneway() { - return Err(StateProblem::OneWay.into()); + if message.len() > MAXMSGLEN { + Err(Error::Input) + } else if self.initiator && self.pattern.is_oneway() { + Err(StateProblem::OneWay.into()) + } else { + let cipher = + if self.initiator { &mut self.cipherstates.1 } else { &mut self.cipherstates.0 }; + cipher.decrypt(message, payload) } - let cipher = - if self.initiator { &mut self.cipherstates.1 } else { &mut self.cipherstates.0 }; - - cipher.decrypt(message, payload) } /// Generate a new key for the egress symmetric cipher according to Section 4.2 @@ -147,7 +151,8 @@ impl TransportState { /// # Errors /// /// Will result in `Error::State` if not in transport mode. - #[must_use] pub fn receiving_nonce(&self) -> u64 { + #[must_use] + pub fn receiving_nonce(&self) -> u64 { if self.initiator { self.cipherstates.1.nonce() } else { @@ -160,7 +165,8 @@ impl TransportState { /// # Errors /// /// Will result in `Error::State` if not in transport mode. - #[must_use] pub fn sending_nonce(&self) -> u64 { + #[must_use] + pub fn sending_nonce(&self) -> u64 { if self.initiator { self.cipherstates.0.nonce() } else { @@ -169,7 +175,8 @@ impl TransportState { } /// Check if this session was started with the "initiator" role. - #[must_use] pub fn is_initiator(&self) -> bool { + #[must_use] + pub fn is_initiator(&self) -> bool { self.initiator } } diff --git a/tests/general.rs b/tests/general.rs index 57e428d..d97662b 100644 --- a/tests/general.rs +++ b/tests/general.rs @@ -5,7 +5,7 @@ use hex::FromHex; use snow::{ resolvers::{CryptoResolver, DefaultResolver}, - Builder, + Builder, Error, }; use rand_core::{impls, CryptoRng, RngCore}; @@ -514,6 +514,36 @@ fn test_handshake_message_undersized_output_buffer() -> TestResult { Ok(()) } +#[test] +fn test_handshake_message_receive_oversized_message() -> TestResult { + let params: NoiseParams = "Noise_NN_25519_ChaChaPoly_SHA256".parse()?; + let mut h_i = Builder::new(params.clone()).build_initiator()?; + let mut h_r = Builder::new(params).build_responder()?; + + let mut buffer_msg = [0u8; 100_000]; + let mut buffer_out = [0u8; 100_000]; + let len = h_i.write_message(b"abc", &mut buffer_msg)?; + assert_eq!(Error::Input, h_r.read_message(&buffer_msg, &mut buffer_out).unwrap_err()); + h_r.read_message(&buffer_msg[..len], &mut buffer_out)?; + + let len = h_r.write_message(b"defg", &mut buffer_msg)?; + h_i.read_message(&buffer_msg[..len], &mut buffer_out)?; + + let h_i = h_i.into_stateless_transport_mode()?; + let mut h_r = h_r.into_transport_mode()?; + + let len = h_i.write_message(0, b"hack the planet", &mut buffer_msg)?; + assert_eq!(Error::Input, h_r.read_message(&buffer_msg, &mut buffer_out).unwrap_err()); + h_r.read_message(&buffer_msg[..len], &mut buffer_out)?; + + let len = h_r.write_message(b"hack the planet", &mut buffer_msg)?; + assert_eq!(Error::Input, h_i.read_message(0, &buffer_msg, &mut buffer_out).unwrap_err()); + let len = h_i.read_message(0, &buffer_msg[..len], &mut buffer_out)?; + assert_eq!(&buffer_out[..len], b"hack the planet"); + + Ok(()) +} + #[test] fn test_transport_message_exceeds_max_len() -> TestResult { let params: NoiseParams = "Noise_N_25519_ChaChaPoly_SHA256".parse()?;