diff --git a/src/certs/snp/ecdsa/mod.rs b/src/certs/snp/ecdsa/mod.rs index b8d6fa48..c6da6fbe 100644 --- a/src/certs/snp/ecdsa/mod.rs +++ b/src/certs/snp/ecdsa/mod.rs @@ -153,3 +153,157 @@ impl TryFrom<&Signature> for Vec { Ok(ecdsa::EcdsaSig::try_from(value)?.to_der()?) } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_signature_default() { + let sig: Signature = Signature::default(); + assert_eq!(sig.r(), &[0u8; 72]); + assert_eq!(sig.s(), &[0u8; 72]); + } + + #[test] + fn test_signature_getters() { + let sig: Signature = Signature { + r: [1u8; 72], + s: [2u8; 72], + _reserved: [0u8; 512 - (SIG_PIECE_SIZE * 2)], + }; + assert_eq!(sig.r(), &[1u8; 72]); + assert_eq!(sig.s(), &[2u8; 72]); + } + + #[test] + fn test_signature_eq() { + let sig1: Signature = Signature::default(); + let sig2: Signature = Signature::default(); + let sig3: Signature = Signature { + r: [1u8; 72], + s: [0u8; 72], + _reserved: [0u8; 512 - (SIG_PIECE_SIZE * 2)], + }; + + assert_eq!(sig1, sig2); + assert_ne!(sig1, sig3); + } + + #[test] + fn test_signature_ord() { + let sig1: Signature = Signature::default(); + let sig2: Signature = Signature { + r: [1u8; 72], + s: [0u8; 72], + _reserved: [0u8; 512 - (SIG_PIECE_SIZE * 2)], + }; + + assert!(sig1 < sig2); + } + + #[test] + fn test_signature_debug() { + let sig: Signature = Signature::default(); + let debug_str: String = format!("{:?}", sig); + assert!(debug_str.starts_with("Signature { r: ")); + assert!(debug_str.contains(", s: ")); + } + + #[test] + fn test_signature_display() { + let sig: Signature = Signature::default(); + let display_str: String = format!("{}", sig); + assert!(display_str.contains("Signature:")); + assert!(display_str.contains("R:")); + assert!(display_str.contains("S:")); + } + + #[cfg(feature = "openssl")] + mod openssl_tests { + use super::*; + use openssl::bn::BigNum; + use std::convert::TryInto; + + #[test] + fn test_from_ecdsa_sig() { + let r = BigNum::from_dec_str("123").unwrap(); + let s = BigNum::from_dec_str("456").unwrap(); + let ecdsa_sig = ecdsa::EcdsaSig::from_private_components(r, s).unwrap(); + let sig: Signature = ecdsa_sig.into(); + assert_ne!(sig.r(), &[0u8; 72]); + assert_ne!(sig.s(), &[0u8; 72]); + } + + #[test] + fn test_try_from_bytes() { + let r = BigNum::from_dec_str("123").unwrap(); + let s = BigNum::from_dec_str("456").unwrap(); + let ecdsa_sig = ecdsa::EcdsaSig::from_private_components(r, s).unwrap(); + let der = ecdsa_sig.to_der().unwrap(); + let sig = Signature::try_from(der.as_slice()).unwrap(); + assert_ne!(sig.r(), &[0u8; 72]); + assert_ne!(sig.s(), &[0u8; 72]); + } + + #[test] + fn test_try_into_ecdsa_sig() { + let sig = Signature::default(); + let ecdsa_sig: ecdsa::EcdsaSig = (&sig).try_into().unwrap(); + assert_eq!(ecdsa_sig.r().to_vec(), vec![]); + assert_eq!(ecdsa_sig.s().to_vec(), vec![]); + } + + #[test] + fn test_try_into_vec() { + let sig = Signature::default(); + let der: Vec = (&sig).try_into().unwrap(); + assert!(!der.is_empty()); + } + } + + #[cfg(feature = "crypto_nossl")] + mod crypto_nossl_tests { + use super::*; + use std::convert::TryInto; + + #[test] + #[should_panic] + fn test_try_into_p384_signature_failure() { + let signature: Signature = Signature::default(); + + let _p384_sig: p384::ecdsa::Signature = (&signature).try_into().unwrap(); + } + + #[test] + fn test_try_into_p384_signature() { + // Test with non-zero values + let sig = Signature { + r: [1u8; 72], + s: [2u8; 72], + _reserved: [0u8; 512 - (SIG_PIECE_SIZE * 2)], + }; + let p384_sig: p384::ecdsa::Signature = (&sig).try_into().unwrap(); + assert_eq!(p384_sig.r().to_bytes().as_slice(), &[1u8; 48]); + assert_eq!(p384_sig.s().to_bytes().as_slice(), &[2u8; 48]); + } + } + + #[test] + fn test_signature_serde() { + let sig: Signature = Signature::default(); + let serialized: Vec = bincode::serialize(&sig).unwrap(); + let deserialized: Signature = bincode::deserialize(&serialized).unwrap(); + assert_eq!(sig, deserialized); + } + + #[test] + fn test_signature_max_values() { + let sig: Signature = Signature { + r: [0xFF; 72], + s: [0xFF; 72], + _reserved: [0u8; 512 - (SIG_PIECE_SIZE * 2)], + }; + assert_eq!(sig.r(), &[0xFF; 72]); + assert_eq!(sig.s(), &[0xFF; 72]); + } +} diff --git a/src/error.rs b/src/error.rs index 57d41ac4..374b9501 100644 --- a/src/error.rs +++ b/src/error.rs @@ -141,7 +141,7 @@ impl From for (u32, u32) { impl From for (VmmError, SevError) { fn from(value: RawFwError) -> Self { - ((value.0 >> 0x20).into(), value.0.into()) + (((value.0 >> 0x20) as u32).into(), (value.0 as u32).into()) } } @@ -655,7 +655,7 @@ impl std::fmt::Display for HashstickError { } } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] /// Errors which may be encountered through misuse of the User API. pub enum CertError { /// Malformed GUID. @@ -1117,3 +1117,375 @@ impl From for SessionError { Self::OpenSSLStack(value) } } + +#[cfg(test)] +mod tests { + use bincode::ErrorKind; + + use super::*; + use std::{ + convert::{TryFrom, TryInto}, + error::Error, + }; + + #[test] + fn test_vmm_error_complete() { + // Test all variants + let variants = vec![ + (1u32, VmmError::InvalidCertificatePageLength), + (2u32, VmmError::RateLimitRetryRequest), + (999u32, VmmError::Unknown), + ]; + + for (code, expected) in variants { + // Test u32 conversion + assert_eq!(VmmError::from(code), expected); + // Test u64 conversion + assert_eq!(VmmError::from((code as u64) << 32), expected); + // Test display + assert!(!expected.to_string().is_empty()); + // Test error trait + assert!(std::error::Error::source(&expected).is_none()); + } + } + + #[test] + fn test_sev_error_complete() { + // Test all valid codes + for code in 0x01..=0x27u32 { + if code == 0x1E { + continue; + } // Skip gap + let err = SevError::from(code); + assert!(!matches!(err, SevError::UnknownError)); + + // Test u64 conversion + let err64 = SevError::from(code as u64); + assert_eq!(err, err64); + + // Test display + assert!(!err.to_string().is_empty()); + + // Test c_int conversion + let c_val: c_int = err.into(); + assert_eq!(c_val as u32, code); + } + + // Test invalid codes + assert_eq!(SevError::from(0u32), SevError::UnknownError); + assert_eq!(SevError::from(0x28u32), SevError::UnknownError); + assert!(!SevError::from(0u32).to_string().is_empty()); + assert!(!SevError::from(0x28u32).to_string().is_empty()); + let err: SevError = SevError::UnknownError; + let c_val: c_int = err.into(); + assert_eq!(c_val as u32, u32::MAX); + } + + #[test] + fn test_raw_fw_error_complete() { + let raw = RawFwError(0x100000000u64); + + // Test display and debug + assert!(raw.to_string().contains("RawFwError: 4294967296")); + assert!(format!("{:?}", raw).contains("RawFwError")); + + // Test From + assert_eq!(RawFwError::from(0x100000000u64), raw); + + // Test tuple conversions + let (upper, lower): (u32, u32) = raw.into(); + assert_eq!(upper, 1); + assert_eq!(lower, 0); + + let raw2 = RawFwError(0x100000000u64); + let (vmm, _sev): (VmmError, SevError) = raw2.into(); + assert_eq!(vmm, VmmError::InvalidCertificatePageLength); + } + + #[test] + fn test_firmware_error_complete() { + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "test"); + let variants = vec![ + FirmwareError::IoError(io_err), + FirmwareError::KnownSevError(SevError::InvalidPlatformState), + FirmwareError::UnknownSevError(999), + ]; + + for err in variants { + // Test display + assert!(!err.to_string().is_empty()); + + // Test c_int conversion + let c_val: c_int = err.into(); + assert!(c_val == -1 || c_val > 0); + } + + // Test conversions + let from_u32: FirmwareError = 0x0u32.into(); + assert!(matches!(from_u32, FirmwareError::IoError(_))); + let from_u32: FirmwareError = 0x1u32.into(); + assert!(matches!(from_u32, FirmwareError::KnownSevError(_))); + let from_u32: FirmwareError = 0x28u32.into(); + assert!(matches!(from_u32, FirmwareError::UnknownSevError(_))); + let from_u64: FirmwareError = 0x1u64.into(); + assert!(matches!(from_u64, FirmwareError::KnownSevError(_))); + } + + #[test] + fn test_firmware_error_conversions() { + // Test From + let sev_err = SevError::InvalidPlatformState; + let fw_err = FirmwareError::from(sev_err); + assert!(matches!( + fw_err, + FirmwareError::KnownSevError(SevError::InvalidPlatformState) + )); + + let unknown_sev = SevError::UnknownError; + let fw_err = FirmwareError::from(unknown_sev); + assert!(matches!(fw_err, FirmwareError::UnknownSevError(_))); + + // Test From + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "test"); + let fw_err = FirmwareError::from(io_err); + assert!(matches!(fw_err, FirmwareError::IoError(_))); + } + + #[test] + fn test_user_api_error_complete() { + let variants = vec![ + FirmwareError::UnknownSevError(0).into(), + std::io::Error::new(std::io::ErrorKind::Other, "test").into(), + CertError::UnknownError.into(), + VmmError::Unknown.into(), + uuid::Uuid::try_from("").unwrap_err().into(), + HashstickError::UnknownError.into(), + UserApiError::VmplError, // No From impl + UserApiError::Unknown, // No From impl + ]; + + for err in variants { + // Test display + assert!(!err.to_string().is_empty()); + // Test error source + match &err { + UserApiError::VmplError | UserApiError::Unknown => assert!(err.source().is_none()), + _ => assert!(err.source().is_some()), + } + // Test io::Error conversion + let _: std::io::Error = err.into(); + } + + let sev_error: SevError = SevError::InvalidPlatformState; + let uapi_error: UserApiError = sev_error.into(); + assert!(matches!(uapi_error, UserApiError::FirmwareError(_))); + } + + #[test] + fn test_hashstick_error_complete() { + let variants = vec![ + HashstickError::InvalidLength, + HashstickError::EmptyHashstickBuffer, + HashstickError::UnknownError, + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_cert_error_complete() { + let variants = vec![ + CertError::InvalidGUID, + CertError::PageMisalignment, + CertError::BufferOverflow, + CertError::EmptyCertBuffer, + CertError::UnknownError, + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_gctx_error_complete() { + let variants = vec![ + GCTXError::InvalidPageSize(100, 200), + GCTXError::InvalidBlockSize, + GCTXError::MissingData, + GCTXError::MissingBlockSize, + GCTXError::UnknownError, + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_ovmf_error_complete() { + let variants = vec![ + OVMFError::InvalidSectionType, + OVMFError::SEVMetadataVerification("test".into()), + OVMFError::EntryMissingInTable("test".into()), + OVMFError::GetTableItemError, + OVMFError::InvalidSize("test".into(), 1, 2), + OVMFError::MismatchingGUID, + OVMFError::UnknownError, + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_sev_hash_error_complete() { + let variants = vec![ + SevHashError::InvalidSize(1, 2), + SevHashError::InvalidOffset(1, 2), + SevHashError::UnknownError, + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_large_array_error_complete() { + let slice_err: Result<[u8; 2], TryFromSliceError> = vec![1u8].as_slice().try_into(); + let variants = vec![ + slice_err.unwrap_err().into(), + LargeArrayError::VectorError("test".into()), + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + } + + #[test] + fn test_id_block_error_complete() { + let slice_err: Result<[u8; 2], TryFromSliceError> = vec![1u8].as_slice().try_into(); + let bincode_err: ErrorKind = bincode::ErrorKind::Custom("test".into()); + + let variants = vec![ + LargeArrayError::VectorError("test".into()).into(), + std::io::Error::new(std::io::ErrorKind::Other, "test").into(), + bincode_err.into(), + slice_err.unwrap_err().into(), + IdBlockError::SevCurveError(), + IdBlockError::SevEcsdsaSigError("test".into()), + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!(std::error::Error::source(&err).is_none()); + } + + // Test conversions + let arr_err = LargeArrayError::VectorError("test".into()); + assert!(matches!( + IdBlockError::from(arr_err), + IdBlockError::LargeArrayError(_) + )); + } + + #[test] + fn test_measurement_error_complete() { + let slice_err: Result<[u8; 2], TryFromSliceError> = vec![1u8].as_slice().try_into(); + let bincode_err: ErrorKind = bincode::ErrorKind::Custom("test".into()); + + let uuid_err = uuid::Uuid::try_from("").unwrap_err(); + + let variants = vec![ + slice_err.unwrap_err().into(), + uuid_err.into(), + bincode_err.into(), + std::io::Error::new(std::io::ErrorKind::Other, "test").into(), + hex::FromHexError::OddLength.into(), + GCTXError::UnknownError.into(), + OVMFError::UnknownError.into(), + SevHashError::UnknownError.into(), + IdBlockError::SevCurveError().into(), + LargeArrayError::VectorError("test".into()).into(), + MeasurementError::InvalidVcpuTypeError("test".into()), + MeasurementError::InvalidVcpuSignatureError("test".into()), + MeasurementError::InvalidVmmError("test".into()), + MeasurementError::InvalidSevModeError("test".into()), + MeasurementError::InvalidOvmfKernelError, + MeasurementError::MissingSection("test".into()), + ]; + + for err in variants { + assert!(!err.to_string().is_empty()); + assert!( + err.source().is_some() + || matches!( + err, + MeasurementError::FromSliceError(_) + | MeasurementError::UUIDError(_) + | MeasurementError::BincodeError(_) + | MeasurementError::FileError(_) + | MeasurementError::FromHexError(_) + | MeasurementError::GCTXError(_) + | MeasurementError::OVMFError(_) + | MeasurementError::SevHashError(_) + | MeasurementError::IdBlockError(_) + | MeasurementError::LargeArrayError(_) + | MeasurementError::InvalidVcpuTypeError(_) + | MeasurementError::InvalidVcpuSignatureError(_) + | MeasurementError::InvalidVmmError(_) + | MeasurementError::InvalidSevModeError(_) + | MeasurementError::InvalidOvmfKernelError + | MeasurementError::MissingSection(_) + ) + ); + } + } + + #[cfg(feature = "openssl")] + #[test] + fn test_openssl_features_complete() { + // Test CertFormatError + let cert_err = CertFormatError::UnknownFormat; + assert!(!cert_err.to_string().is_empty()); + assert!(std::error::Error::source(&cert_err).is_none()); + + // Test SessionError + let io_err = std::io::Error::new(std::io::ErrorKind::Other, "test"); + let variants = vec![ + SessionError::RandError(ErrorCode::HardwareFailure), + SessionError::IOError(io_err), + SessionError::OpenSSLStack(ErrorStack::get()), + ]; + + for err in variants { + let debug_str = format!("{:?}", err); + match err { + SessionError::RandError(_) => assert!(debug_str.contains("RandError")), + SessionError::IOError(_) => assert!(debug_str.contains("IOError")), + SessionError::OpenSSLStack(_) => assert!(debug_str.contains("OpenSSLStack")), + } + } + + // Test conversions + let from_io = SessionError::from(std::io::Error::new(std::io::ErrorKind::Other, "test")); + assert!(matches!(from_io, SessionError::IOError(_))); + + let from_code = SessionError::from(ErrorCode::HardwareFailure); + assert!(matches!(from_code, SessionError::RandError(_))); + + let from_stack = SessionError::from(ErrorStack::get()); + assert!(matches!(from_stack, SessionError::OpenSSLStack(_))); + } +} diff --git a/src/firmware/guest/types/snp.rs b/src/firmware/guest/types/snp.rs index 3656e782..d8cc5a3e 100644 --- a/src/firmware/guest/types/snp.rs +++ b/src/firmware/guest/types/snp.rs @@ -561,7 +561,7 @@ bitfield! { /// Indicates that ciphertext hiding is enabled pub ciphertext_hiding_enabled, _: 4, 4; /// reserved - reserved, _: 5, 63; + reserved, _: 63, 5; } impl Display for PlatformInfo { @@ -613,7 +613,7 @@ bitfield! { /// (7) NONE pub signing_key, _: 4,2; /// reserved - reserved, _: 5, 31; + reserved, _: 31, 5; } impl Display for KeyInfo { @@ -642,6 +642,7 @@ Key Information: #[cfg(test)] mod tests { + use super::*; #[test] @@ -919,6 +920,24 @@ Signature: assert_eq!(expected, AttestationReport::default().to_string()) } + #[test] + fn test_attestation_report_clone() { + let expected: AttestationReport = AttestationReport::default(); + + let copy: AttestationReport = expected; + + assert_eq!(expected, copy); + } + + #[test] + fn test_attestation_report_copy() { + let expected: AttestationReport = AttestationReport::default(); + + let copy: AttestationReport = expected; + + assert_eq!(expected, copy); + } + #[test] fn test_guest_policy_zeroed() { let gp: GuestPolicy = GuestPolicy(0); @@ -1113,4 +1132,161 @@ Key Information: assert_eq!(expected, actual.to_string()); } + + #[test] + fn test_platform_info_serialization() { + let original = PlatformInfo(0b11111); + + // Test bincode + let binary = bincode::serialize(&original).unwrap(); + let from_binary: PlatformInfo = bincode::deserialize(&binary).unwrap(); + assert_eq!(original, from_binary); + } + + #[test] + fn test_key_info_serialization() { + let original = KeyInfo(0b11111); + + // Test bincode + let binary = bincode::serialize(&original).unwrap(); + let from_binary: KeyInfo = bincode::deserialize(&binary).unwrap(); + assert_eq!(original, from_binary); + assert!(from_binary.author_key_en()); + assert_eq!(from_binary.mask_chip_key(), 1); + assert_eq!(from_binary.signing_key(), 0b111); + } + + #[test] + fn test_guest_policy_serialization() { + let mut original = GuestPolicy::default(); + original.set_abi_major(2); + original.set_abi_minor(1); + original.set_smt_allowed(1); + original.set_debug_allowed(1); + + // Test bincode + let binary = bincode::serialize(&original).unwrap(); + let from_binary: GuestPolicy = bincode::deserialize(&binary).unwrap(); + assert_eq!(original, from_binary); + } + + #[test] + fn test_attestation_report_serialization() { + let original: AttestationReport = AttestationReport { + version: 2, + guest_svn: 1, + policy: GuestPolicy(3), + family_id: [1; 16], + image_id: [2; 16], + ..Default::default() + }; + + // Test bincode + let binary = bincode::serialize(&original).unwrap(); + let from_binary: AttestationReport = bincode::deserialize(&binary).unwrap(); + assert_eq!(original, from_binary); + } + + #[test] + fn test_boundary_value_serialization() { + // Test max values + let platform_info = PlatformInfo(u64::MAX); + let key_info = KeyInfo(u32::MAX); + let guest_policy = GuestPolicy(u64::MAX); + + // Verify serialization/deserialization preserves max values + assert_eq!( + platform_info, + bincode::deserialize(&bincode::serialize(&platform_info).unwrap()).unwrap() + ); + assert_eq!( + key_info, + bincode::deserialize(&bincode::serialize(&key_info).unwrap()).unwrap() + ); + assert_eq!( + guest_policy, + bincode::deserialize(&bincode::serialize(&guest_policy).unwrap()).unwrap() + ); + } + + #[test] + fn test_guest_field_select_operations() { + let mut field = GuestFieldSelect::default(); + + field.set_guest_policy(1); + assert_eq!(field.get_guest_policy(), 1); + + field.set_image_id(1); + assert_eq!(field.get_image_id(), 1); + + field.set_family_id(1); + assert_eq!(field.get_family_id(), 1); + + field.set_measurement(1); + assert_eq!(field.get_measurement(), 1); + } + + #[test] + fn test_derived_key_fields() { + let key = DerivedKey::new(true, GuestFieldSelect(0xFF), 2, 3, 0x1234); + assert_eq!(key.get_root_key_select(), 1); + assert_eq!(key.vmpl, 2); + assert_eq!(key.guest_svn, 3); + assert_eq!(key.tcb_version, 0x1234); + } + + #[test] + fn test_key_info_all_combinations() { + let mut info = KeyInfo(0); + + // Test VCEK + assert_eq!(info.signing_key(), 0); + assert!(!info.author_key_en()); + + // Test VLEK + info = KeyInfo(0b100); + assert_eq!(info.signing_key(), 1); + + // Test None + info = KeyInfo(0b11100); + assert_eq!(info.signing_key(), 7); + } + + #[test] + fn test_attestation_report_fields() { + let report: AttestationReport = AttestationReport { + version: 2, + guest_svn: 1, + vmpl: 3, + ..Default::default() + }; + assert_eq!(report.version, 2); + assert_eq!(report.guest_svn, 1); + assert_eq!(report.vmpl, 3); + assert_eq!(report.measurement, [0; 48]); + } + + #[test] + fn test_platform_info_reserved() { + let info = PlatformInfo(0xFF); + assert_eq!(info.reserved(), 0x7); + } + + #[test] + fn test_guest_policy_combined_fields() { + let mut policy = GuestPolicy::default(); + + policy.set_abi_major(2); + policy.set_abi_minor(1); + policy.set_smt_allowed(1); + policy.set_debug_allowed(1); + + assert_eq!(policy.abi_major(), 2); + assert_eq!(policy.abi_minor(), 1); + assert_eq!(policy.smt_allowed(), 1); + assert_eq!(policy.debug_allowed(), 1); + + let policy_u64: u64 = policy.into(); + assert_eq!(policy_u64 & (1 << 17), 1 << 17); // Reserved bit 17 must be 1 + } } diff --git a/src/firmware/host/types/mod.rs b/src/firmware/host/types/mod.rs index 9e3eec77..84e99702 100644 --- a/src/firmware/host/types/mod.rs +++ b/src/firmware/host/types/mod.rs @@ -41,3 +41,40 @@ impl std::fmt::Display for State { write!(f, "{state}") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_state_display() { + assert_eq!(State::Uninitialized.to_string(), "uninitialized"); + assert_eq!(State::Initialized.to_string(), "initialized"); + assert_eq!(State::Working.to_string(), "working"); + } + + #[test] + fn test_state_representation() { + assert_eq!(State::Uninitialized as u8, 0); + assert_eq!(State::Initialized as u8, 1); + assert_eq!(State::Working as u8, 2); + } + + #[test] + fn test_state_debug() { + assert_eq!(format!("{:?}", State::Uninitialized), "Uninitialized"); + assert_eq!(format!("{:?}", State::Initialized), "Initialized"); + assert_eq!(format!("{:?}", State::Working), "Working"); + } + + #[test] + fn test_state_equality() { + assert_eq!(State::Uninitialized, State::Uninitialized); + assert_eq!(State::Initialized, State::Initialized); + assert_eq!(State::Working, State::Working); + + assert_ne!(State::Uninitialized, State::Initialized); + assert_ne!(State::Initialized, State::Working); + assert_ne!(State::Working, State::Uninitialized); + } +} diff --git a/src/firmware/host/types/snp.rs b/src/firmware/host/types/snp.rs index f6a29f9b..6e9131bd 100644 --- a/src/firmware/host/types/snp.rs +++ b/src/firmware/host/types/snp.rs @@ -420,7 +420,7 @@ impl Display for MaskId { #[cfg(test)] mod tests { - use super::{CertType, SnpPlatformStatusFlags}; + use super::*; use uuid::Uuid; #[test] @@ -531,4 +531,771 @@ mod tests { assert_eq!(cert_type.to_string(), expected.to_string()); } + + #[test] + fn test_cert_table_entry_creation() { + let data = vec![1, 2, 3, 4]; + let entry = CertTableEntry::new(CertType::ARK, data.clone()); + + assert_eq!(entry.cert_type, CertType::ARK); + assert_eq!(entry.data(), &data); + assert_eq!(entry.guid_string(), "c0b406a4-a803-4952-9743-3fb6014cd0ae"); + } + + #[test] + fn test_cert_table_entry_from_guid() { + let guid = Uuid::parse_str("c0b406a4-a803-4952-9743-3fb6014cd0ae").unwrap(); + let data = vec![1, 2, 3, 4]; + let entry = CertTableEntry::from_guid(&guid, data.clone()).unwrap(); + + assert_eq!(entry.cert_type, CertType::ARK); + assert_eq!(entry.data(), &data); + } + + #[test] + fn test_cert_table_entry_invalid_guid() { + let guid = Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap(); + let data = vec![1, 2, 3, 4]; + let entry = CertTableEntry::from_guid(&guid, data.clone()).unwrap(); + + assert!(matches!(entry.cert_type, CertType::OTHER(_))); + } + + #[test] + fn test_cert_table_entry_empty() { + let entry = CertTableEntry::new(CertType::Empty, vec![]); + + assert_eq!(entry.cert_type, CertType::Empty); + assert!(entry.data().is_empty()); + assert_eq!(entry.guid_string(), "00000000-0000-0000-0000-000000000000"); + } + + #[test] + fn test_cert_table_entry_ordering() { + let entry1 = CertTableEntry::new(CertType::ARK, vec![1]); + let entry2 = CertTableEntry::new(CertType::ASK, vec![2]); + let entry3 = CertTableEntry::new(CertType::Empty, vec![3]); + + assert!(entry1 < entry2); + assert!(entry2 < entry3); + assert!(entry1 < entry3); + } + + #[test] + fn test_cert_table_entry_data_access() { + let large_data = vec![0u8; 1024]; + let entry = CertTableEntry::new(CertType::VCEK, large_data.clone()); + + assert_eq!(entry.data(), &large_data); + } + + #[cfg(target_os = "linux")] + #[test] + fn test_cert_table_conversion() { + let entries = vec![ + CertTableEntry::new(CertType::ARK, vec![1, 2, 3]), + CertTableEntry::new(CertType::ASK, vec![4, 5, 6]), + ]; + + let bytes = CertTableEntry::cert_table_to_vec_bytes(&entries).unwrap(); + let converted = CertTableEntry::vec_bytes_to_cert_table(&mut bytes.clone()).unwrap(); + + assert_eq!(entries.len(), converted.len()); + assert_eq!(entries[0].cert_type, converted[0].cert_type); + assert_eq!(entries[1].cert_type, converted[1].cert_type); + } + + #[test] + fn test_cert_type_conversion() { + let ark_guid = Uuid::parse_str("c0b406a4-a803-4952-9743-3fb6014cd0ae").unwrap(); + let cert_type = CertType::try_from(&ark_guid).unwrap(); + assert_eq!(cert_type, CertType::ARK); + + let uuid = Uuid::try_from(CertType::ARK).unwrap(); + assert_eq!(uuid, ark_guid); + } + + // Test TcbVersion struct and methods + #[test] + fn test_tcb_version() { + let tcb = TcbVersion::new(1, 2, 3, 4); + assert_eq!(tcb.bootloader, 1); + assert_eq!(tcb.tee, 2); + assert_eq!(tcb.snp, 3); + assert_eq!(tcb.microcode, 4); + + // Test Display implementation + let display_output = format!("{}", tcb); + assert!(display_output.contains("Microcode: 4")); + assert!(display_output.contains("SNP: 3")); + } + + // Test Config struct and conversions + #[test] + #[cfg(feature = "snp")] + fn test_config() { + let tcb = TcbVersion::new(1, 2, 3, 4); + let mask = MaskId(0x3); + let config = Config::new(tcb, mask); + + assert_eq!(config.reported_tcb, tcb); + let config_mask = config.mask_id; + assert_eq!(config_mask, mask); + + // Test conversion to FFI type + let snp_config: SnpSetConfig = config.try_into().unwrap(); + assert_eq!(snp_config.reported_tcb, tcb); + let snp_config_mask = snp_config.mask_id; + + assert_eq!(snp_config_mask, mask); + } + + // Test PlatformInit flags + #[test] + fn test_platform_init() { + let mut init = PlatformInit::empty(); + assert!(!init.contains(PlatformInit::IS_RMP_INIT)); + + init.insert(PlatformInit::IS_RMP_INIT); + assert!(init.contains(PlatformInit::IS_RMP_INIT)); + + init.insert(PlatformInit::IS_TIO_EN); + assert!(init.contains(PlatformInit::IS_TIO_EN)); + } + + // Test MaskId bitfield operations + #[test] + fn test_mask_id() { + let mut mask = MaskId(0); + assert_eq!(mask.mask_chip_id(), 0); + + mask.0 = 0x3; + assert_eq!(mask.mask_chip_id(), 1); + assert_eq!(mask.mask_chip_key(), 1); + + // Test Display implementation + let display_output = format!("{}", mask); + assert!(display_output.contains("MaskID (3)")); + } + + // Test Build struct + #[test] + fn test_build() { + let build = Build { + version: Version { major: 1, minor: 2 }, + build: 42, + }; + + assert_eq!(build.version.major, 1); + assert_eq!(build.version.minor, 2); + assert_eq!(build.build, 42); + } + + // Test SnpPlatformStatus + #[test] + fn test_platform_status() { + let status = SnpPlatformStatus::default(); + assert_eq!(status.state, 0); + assert_eq!(status.guest_count, 0); + + let init_status = SnpPlatformStatus { + is_rmp_init: PlatformInit::IS_RMP_INIT, + ..Default::default() + }; + assert!(init_status.is_rmp_init.contains(PlatformInit::IS_RMP_INIT)); + } + + #[test] + fn test_tcb_version_creation_and_display() { + let tcb = TcbVersion::new(1, 2, 3, 4); + assert_eq!(tcb.bootloader, 1); + assert_eq!(tcb.tee, 2); + assert_eq!(tcb.snp, 3); + assert_eq!(tcb.microcode, 4); + + let display = format!("{}", tcb); + assert!(display.contains("Microcode: 4")); + assert!(display.contains("SNP: 3")); + assert!(display.contains("TEE: 2")); + assert!(display.contains("Boot Loader: 1")); + } + + // Build Tests + #[test] + fn test_build_ordering_and_comparison() { + let build1 = Build { + version: Version { major: 1, minor: 0 }, + build: 1, + }; + let build2 = Build { + version: Version { major: 1, minor: 1 }, + build: 1, + }; + assert!(build1 < build2); + + let default_build = Build::default(); + assert_eq!(default_build.version.major, 0); + assert_eq!(default_build.build, 0); + } + + // PlatformInit Tests + #[test] + fn test_platform_init_flags() { + let mut flags = PlatformInit::empty(); + assert!(!flags.contains(PlatformInit::IS_RMP_INIT)); + + flags.insert(PlatformInit::IS_RMP_INIT | PlatformInit::IS_TIO_EN); + assert!(flags.contains(PlatformInit::IS_RMP_INIT)); + assert!(flags.contains(PlatformInit::IS_TIO_EN)); + assert!(!flags.contains(PlatformInit::ALIAS_CHECK_COMPLETE)); + } + + // MaskId Tests + #[test] + fn test_mask_id_operations() { + let mut mask = MaskId(0); + assert_eq!(mask.mask_chip_id(), 0); + assert_eq!(mask.mask_chip_key(), 0); + + mask.0 = 0x3; + assert_eq!(mask.mask_chip_id(), 1); + assert_eq!(mask.mask_chip_key(), 1); + + let display = format!("{}", mask); + assert!(display.contains("MaskID (3)")); + assert!(display.contains("Mask Chip ID: 1")); + } + + // Config Tests + #[test] + #[cfg(feature = "snp")] + fn test_config_conversions() { + let tcb = TcbVersion::new(1, 2, 3, 4); + let mask = MaskId(0x3); + let config = Config::new(tcb, mask); + + let ffi_config: SnpSetConfig = config.try_into().unwrap(); + assert_eq!(ffi_config.reported_tcb, tcb); + let ffi_config_mask = ffi_config.mask_id; + assert_eq!(ffi_config_mask, mask); + + let converted_config: Config = ffi_config.try_into().unwrap(); + assert_eq!(converted_config.reported_tcb, tcb); + let converted_config_mask = converted_config.mask_id; + assert_eq!(converted_config_mask, mask); + } + + // SnpPlatformStatus Tests + #[test] + fn test_platform_status_initialization() { + let mut status = SnpPlatformStatus::default(); + assert_eq!(status.state, 0); + assert_eq!(status.guest_count, 0); + + status.is_rmp_init = PlatformInit::IS_RMP_INIT; + assert!(status.is_rmp_init.contains(PlatformInit::IS_RMP_INIT)); + + status.platform_tcb_version = TcbVersion::new(1, 2, 3, 4); + assert_eq!(status.platform_tcb_version.snp, 3); + } + + #[test] + fn test_platform_status_flags_operations() { + let mut flags = SnpPlatformStatusFlags::empty(); + assert!(!flags.contains(SnpPlatformStatusFlags::OWNED)); + + flags.insert(SnpPlatformStatusFlags::OWNED); + assert!(flags.contains(SnpPlatformStatusFlags::OWNED)); + assert!(!flags.contains(SnpPlatformStatusFlags::ENCRYPTED_STATE)); + + flags.insert(SnpPlatformStatusFlags::ENCRYPTED_STATE); + assert!(flags.contains(SnpPlatformStatusFlags::ENCRYPTED_STATE)); + + flags.remove(SnpPlatformStatusFlags::OWNED); + assert!(!flags.contains(SnpPlatformStatusFlags::OWNED)); + } + + #[test] + fn test_tcb_status() { + let status = TcbStatus { + platform_version: TcbVersion::new(1, 2, 3, 4), + reported_version: TcbVersion::new(5, 6, 7, 8), + }; + + assert_eq!(status.platform_version.bootloader, 1); + assert_eq!(status.reported_version.bootloader, 5); + + let default_status = TcbStatus::default(); + assert_eq!(default_status.platform_version, TcbVersion::default()); + } + + #[test] + #[cfg(feature = "snp")] + fn test_config_error_cases() { + let tcb = TcbVersion::new(255, 255, 255, 255); + let mask = MaskId(u32::MAX); + let config = Config::new(tcb, mask); + + let ffi_result: Result = config.try_into(); + assert!(ffi_result.is_ok()); + + let default_config = Config::default(); + assert_eq!(default_config.reported_tcb, TcbVersion::default()); + let default_config_mask_id = default_config.mask_id; + assert_eq!(default_config_mask_id, MaskId::default()); + } + + #[test] + fn test_version_comparisons() { + let v1 = TcbVersion::new(1, 2, 3, 4); + let v2 = TcbVersion::new(1, 2, 3, 5); + let v3 = TcbVersion::new(1, 2, 3, 4); + + assert!(v1 < v2); + assert_eq!(v1, v3); + assert!(v2 > v1); + + assert!(v1.partial_cmp(&v2).unwrap().is_lt()); + } + + #[test] + fn test_build_version_comparisons() { + let b1 = Build { + version: Version { major: 1, minor: 0 }, + build: 100, + }; + let b2 = Build { + version: Version { major: 1, minor: 1 }, + build: 50, + }; + + assert!(b1 < b2); + assert_ne!(b1, b2); + } + + #[test] + fn test_platform_status_boundary() { + let status = SnpPlatformStatus { + guest_count: u32::MAX, + build_id: u32::MAX, + mask_chip_id: u32::MAX, + ..Default::default() + }; + + assert_eq!(status.guest_count, u32::MAX); + assert_eq!(status.build_id, u32::MAX); + } + + #[test] + fn test_mask_id_boundary() { + let mut mask = MaskId(u32::MAX); + assert_eq!(mask.mask_chip_id(), 1); + assert_eq!(mask.mask_chip_key(), 1); + + mask = MaskId(0); + assert_eq!(mask.mask_chip_id(), 0); + assert_eq!(mask.mask_chip_key(), 0); + } + + #[test] + fn test_platform_init_combinations() { + let mut init = PlatformInit::empty(); + init.insert(PlatformInit::IS_RMP_INIT | PlatformInit::IS_TIO_EN); + assert!(init.contains(PlatformInit::IS_RMP_INIT | PlatformInit::IS_TIO_EN)); + + init.remove(PlatformInit::IS_RMP_INIT); + assert!(!init.contains(PlatformInit::IS_RMP_INIT)); + assert!(init.contains(PlatformInit::IS_TIO_EN)); + } + + #[test] + fn test_tcb_version_reserved() { + let tcb = TcbVersion::new(1, 2, 3, 4); + assert_eq!(tcb._reserved, [0u8; 4]); + } + + #[test] + fn test_config_reserved() { + let config = Config::default(); + assert_eq!(config.reserved, [0u8; 52]); + } + + #[test] + fn test_platform_status_all_fields() { + let status: SnpPlatformStatus = SnpPlatformStatus { + version: Version { major: 1, minor: 2 }, + build_id: 0xDEADBEEF, + mask_chip_id: 0x1, + state: 0xFF, + ..Default::default() + }; + assert_eq!(status.version.major, 1); + assert_eq!(status.version.minor, 2); + assert_eq!(status.build_id, 0xDEADBEEF); + assert_eq!(status.mask_chip_id, 0x1); + assert_eq!(status.state, 0xFF); + } + + #[test] + fn test_cert_type_deserialization() { + let cert_types = vec![ + CertType::Empty, + CertType::ARK, + CertType::ASK, + CertType::VCEK, + CertType::VLEK, + CertType::CRL, + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + ]; + + for cert_type in cert_types { + let serialized = bincode::serialize(&cert_type).unwrap(); + let deserialized: CertType = bincode::deserialize(&serialized).unwrap(); + assert_eq!(cert_type, deserialized); + } + } + + #[test] + fn test_cert_type_try_from_uuid() { + // Test all valid UUIDs + let test_cases = vec![ + ("00000000-0000-0000-0000-000000000000", CertType::Empty), + ("c0b406a4-a803-4952-9743-3fb6014cd0ae", CertType::ARK), + ("4ab7b379-bbac-4fe4-a02f-05aef327c782", CertType::ASK), + ("63da758d-e664-4564-adc5-f4b93be8accd", CertType::VCEK), + ("a8074bc2-a25a-483e-aae6-39c045a0b8a1", CertType::VLEK), + ("92f81bc3-5811-4d3d-97ff-d19f88dc67ea", CertType::CRL), + ( + "11111111-1111-1111-1111-111111111111", + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + ), + ]; + + for (uuid_str, expected_type) in test_cases { + let uuid = Uuid::parse_str(uuid_str).unwrap(); + assert_eq!(CertType::try_from(&uuid).unwrap(), expected_type); + } + } + + #[test] + fn test_cert_type_cmp_complete() { + let mut cert_types = vec![ + CertType::ARK, + CertType::VCEK, + CertType::VLEK, + CertType::ASK, + CertType::CRL, + CertType::Empty, + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + ]; + + let expected = vec![ + CertType::ARK, + CertType::VCEK, + CertType::VLEK, + CertType::ASK, + CertType::CRL, + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + CertType::Empty, + ]; + + cert_types.sort(); + assert_eq!(cert_types, expected); + } + + #[test] + fn test_cert_table_entry_deserialization() { + let entry = CertTableEntry::new(CertType::ARK, vec![1, 2, 3, 4]); + + let serialized = bincode::serialize(&entry).unwrap(); + let deserialized: CertTableEntry = bincode::deserialize(&serialized).unwrap(); + + assert_eq!(entry.cert_type, deserialized.cert_type); + assert_eq!(entry.data, deserialized.data); + } + + #[test] + fn test_cert_table_entry_cmp_complete() { + let entries = vec![ + CertTableEntry::new(CertType::ARK, vec![1]), + CertTableEntry::new(CertType::VCEK, vec![2]), + CertTableEntry::new(CertType::Empty, vec![4]), + CertTableEntry::new(CertType::ASK, vec![3]), + ]; + + let mut sorted = entries.clone(); + sorted.sort(); + + assert_eq!(sorted[0].cert_type, CertType::ARK); + assert_eq!(sorted[1].cert_type, CertType::VCEK); + assert_eq!(sorted[2].cert_type, CertType::ASK); + assert_eq!(sorted[3].cert_type, CertType::Empty); + } + + #[test] + fn test_build_deserialization() { + let build = Build { + version: Version { major: 1, minor: 2 }, + build: 42, + }; + + let serialized = bincode::serialize(&build).unwrap(); + let deserialized: Build = bincode::deserialize(&serialized).unwrap(); + + assert_eq!(build, deserialized); + } + + #[test] + fn test_tcb_version_deserialization() { + let tcb = TcbVersion::new(1, 2, 3, 4); + + let serialized = bincode::serialize(&tcb).unwrap(); + let deserialized: TcbVersion = bincode::deserialize(&serialized).unwrap(); + + assert_eq!(tcb, deserialized); + } + + #[test] + fn test_mask_id_deserialization() { + let test_cases = vec![ + MaskId(0), // No bits set + MaskId(0x1), // chip_id only + MaskId(0x2), // chip_key only + MaskId(0x3), // Both bits + MaskId(u32::MAX), // All bits + ]; + + for mask in test_cases { + let serialized = bincode::serialize(&mask).unwrap(); + let deserialized: MaskId = bincode::deserialize(&serialized).unwrap(); + + assert_eq!(mask.0, deserialized.0); + assert_eq!(mask.mask_chip_id(), deserialized.mask_chip_id()); + assert_eq!(mask.mask_chip_key(), deserialized.mask_chip_key()); + } + } + #[test] + fn test_cert_table_entry_complete_ordering() { + let entries = vec![ + CertTableEntry::new(CertType::ARK, vec![1, 2, 3]), + CertTableEntry::new(CertType::ARK, vec![9, 9, 9]), // Same type, different data + CertTableEntry::new(CertType::VCEK, vec![1]), + CertTableEntry::new(CertType::ASK, vec![2]), + CertTableEntry::new(CertType::CRL, vec![3]), + CertTableEntry::new(CertType::Empty, vec![]), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + vec![4], + ), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("22222222-2222-2222-2222-222222222222").unwrap()), + vec![5], + ), + ]; + + // Test equality + assert_eq!(entries[0], entries[0]); + + // Test ordering + assert!(entries[0] < entries[2]); // ARK < VCEK + assert!(entries[2] < entries[3]); // VCEK < ASK + assert!(entries[3] < entries[4]); // ASK < CRL + assert!(entries[4] < entries[6]); // CRL < OTHER + assert!(entries[6] < entries[7]); // OTHER orders by UUID + assert!(entries[6] < entries[5]); // OTHER < Empty + + // Test transitivity + assert!(entries[0] < entries[3]); // ARK < ASK + assert!(entries[0] < entries[5]); // ARK < Empty + + // Verify reverse comparisons + assert!(entries[5] > entries[0]); // Empty > ARK + assert!(entries[4] > entries[3]); // CRL > ASK + } + + #[test] + fn test_cert_table_entry_sort_and_compare() { + let mut entries = vec![ + CertTableEntry::new(CertType::Empty, vec![]), + CertTableEntry::new(CertType::CRL, vec![1]), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap()), + vec![2], + ), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + vec![3], + ), + CertTableEntry::new(CertType::ARK, vec![4]), + CertTableEntry::new(CertType::ASK, vec![5]), + CertTableEntry::new(CertType::VCEK, vec![6]), + CertTableEntry::new(CertType::VLEK, vec![7]), + ]; + + let expected = vec![ + CertTableEntry::new(CertType::ARK, vec![4]), + CertTableEntry::new(CertType::VCEK, vec![6]), + CertTableEntry::new(CertType::VLEK, vec![7]), + CertTableEntry::new(CertType::ASK, vec![5]), + CertTableEntry::new(CertType::CRL, vec![1]), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap()), + vec![3], + ), + CertTableEntry::new( + CertType::OTHER(Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap()), + vec![2], + ), + CertTableEntry::new(CertType::Empty, vec![]), + ]; + + entries.sort(); + assert_eq!(entries, expected); + + // Verify stability with duplicate types + let mut duplicates = [ + CertTableEntry::new(CertType::ARK, vec![1]), + CertTableEntry::new(CertType::ARK, vec![2]), + ]; + duplicates.sort(); + assert_eq!(duplicates[0].data(), &[1]); + assert_eq!(duplicates[1].data(), &[2]); + } + + #[test] + fn test_cert_table_entry_direct_cmp() { + let entry1 = CertTableEntry::new(CertType::ARK, vec![1]); + let entry2 = CertTableEntry::new(CertType::VCEK, vec![2]); + + // Direct call to cmp() method to ensure coverage + let ordering = entry1.cmp(&entry2); + assert!(matches!(ordering, std::cmp::Ordering::Less)); + + // Reverse comparison + let ordering = entry2.cmp(&entry1); + assert!(matches!(ordering, std::cmp::Ordering::Greater)); + + // Equal comparison + let ordering = entry1.cmp(&entry1); + assert!(matches!(ordering, std::cmp::Ordering::Equal)); + } + + #[test] + fn test_cert_table_entry_direct_cmp_vlek() { + let entry1 = CertTableEntry::new(CertType::ARK, vec![1]); + let entry2 = CertTableEntry::new(CertType::VLEK, vec![2]); + + // Direct call to cmp() method to ensure coverage + let ordering = entry1.cmp(&entry2); + assert!(matches!(ordering, std::cmp::Ordering::Less)); + + // Reverse comparison + let ordering = entry2.cmp(&entry1); + assert!(matches!(ordering, std::cmp::Ordering::Greater)); + + // Equal comparison + let ordering = entry1.cmp(&entry1); + assert!(matches!(ordering, std::cmp::Ordering::Equal)); + } + #[test] + fn test_cert_table_entry_deserialize() { + use bincode::{deserialize, serialize}; + + // Create a test entry + let original = CertTableEntry::new(CertType::ARK, vec![0x41, 0x42, 0x43]); + + // Serialize and then deserialize + let serialized = serialize(&original).expect("Failed to serialize"); + let deserialized: CertTableEntry = deserialize(&serialized).expect("Failed to deserialize"); + + // Verify deserialized data matches original + assert_eq!(deserialized.cert_type, original.cert_type); + assert_eq!(deserialized.data(), original.data()); + } + + #[test] + fn test_cert_type_to_uuid_conversion() { + use uuid::Uuid; + + // Test successful conversions + assert_eq!( + Uuid::try_from(CertType::ARK).unwrap(), + Uuid::parse_str("c0b406a4-a803-4952-9743-3fb6014cd0ae").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::ASK).unwrap(), + Uuid::parse_str("4ab7b379-bbac-4fe4-a02f-05aef327c782").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::VCEK).unwrap(), + Uuid::parse_str("63da758d-e664-4564-adc5-f4b93be8accd").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::VLEK).unwrap(), + Uuid::parse_str("a8074bc2-a25a-483e-aae6-39c045a0b8a1").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::Empty).unwrap(), + Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::CRL).unwrap(), + Uuid::parse_str("92f81bc3-5811-4d3d-97ff-d19f88dc67ea").unwrap() + ); + assert_eq!( + Uuid::try_from(CertType::OTHER(uuid::Uuid::max())).unwrap(), + Uuid::parse_str("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap() + ); + } + + #[test] + fn test_build_deserialize() { + use bincode::{deserialize, serialize}; + + let original = Build { + version: 1.into(), + build: 2, + }; + + let serialized = serialize(&original).expect("Failed to serialize"); + let deserialized: Build = deserialize(&serialized).expect("Failed to deserialize"); + + assert_eq!(deserialized.version, original.version); + assert_eq!(deserialized.build, original.build); + } + + #[test] + fn test_chain_visitor_methods() { + use bincode::{deserialize, serialize}; + // Test sequence visiting + let chain_data = vec![ + CertTableEntry::new(CertType::ARK, vec![1]), + CertTableEntry::new(CertType::ASK, vec![2]), + ]; + let serialized = serialize(&chain_data).expect("Failed to serialize"); + let deserialized: Vec = + deserialize(&serialized).expect("Failed to deserialize"); + + assert_eq!(deserialized.len(), chain_data.len()); + assert_eq!(deserialized[0].cert_type, chain_data[0].cert_type); + } + + #[test] + fn test_field_visitor_methods() { + use bincode::{deserialize, serialize}; + + // Test various field types + let bytes = vec![1u8, 2u8, 3u8]; + let serialized = serialize(&bytes).expect("Failed to serialize"); + let deserialized: Vec = deserialize(&serialized).expect("Failed to deserialize"); + + assert_eq!(deserialized, bytes); + + // Test string field + let text = "test"; + let serialized = serialize(&text).expect("Failed to serialize"); + let deserialized: String = deserialize(&serialized).expect("Failed to deserialize"); + + assert_eq!(deserialized, text); + } } diff --git a/src/firmware/linux/guest/ioctl.rs b/src/firmware/linux/guest/ioctl.rs index b81cde0c..7480b567 100644 --- a/src/firmware/linux/guest/ioctl.rs +++ b/src/firmware/linux/guest/ioctl.rs @@ -61,3 +61,35 @@ impl<'a, 'b, Req, Rsp> GuestRequest<'a, 'b, Req, Rsp> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_guest_request_new() { + let mut req = ReportReq::default(); + let mut rsp = ReportRsp::default(); + + // Test with explicit version + let guest_req = GuestRequest::new(Some(2), &mut req, &mut rsp); + assert_eq!(guest_req.message_version, 2); + assert_ne!(guest_req.request_data, 0); + assert_ne!(guest_req.response_data, 0); + assert_eq!(guest_req.fw_err, 0); + + // Test with default version + let guest_req = GuestRequest::new(None, &mut req, &mut rsp); + assert_eq!(guest_req.message_version, 1); + assert_ne!(guest_req.request_data, 0); + assert_ne!(guest_req.response_data, 0); + assert_eq!(guest_req.fw_err, 0); + } + + #[test] + fn test_guest_ioctl_values() { + assert_eq!(GuestIoctl::GetReport as u8, 0x0); + assert_eq!(GuestIoctl::GetDerivedKey as u8, 0x1); + assert_eq!(GuestIoctl::GetExtReport as u8, 0x2); + } +} diff --git a/src/firmware/linux/guest/types.rs b/src/firmware/linux/guest/types.rs index ddf3ff87..4c31cc11 100644 --- a/src/firmware/linux/guest/types.rs +++ b/src/firmware/linux/guest/types.rs @@ -9,6 +9,7 @@ use static_assertions::const_assert; const MAX_VMPL: u32 = 3; #[repr(C)] +#[derive(Debug, Default)] pub struct DerivedKeyReq { /// Selects the root key to derive the key from. /// 0: Indicates VCEK. @@ -256,4 +257,91 @@ mod test { assert_eq!(expected, actual); } } + + use super::*; + + #[test] + fn test_derived_key_req_conversion() { + // Create a mock DerivedKey + let derived_key = DerivedKey::new(false, GuestFieldSelect(0x1234), 2, 1, 100); + + // Test From + let req: DerivedKeyReq = derived_key.into(); + assert_eq!(req.root_key_select, 0); + assert_eq!(req.reserved_0, 0); + assert_eq!(req.guest_field_select, 0x1234); + assert_eq!(req.vmpl, 2); + assert_eq!(req.guest_svn, 1); + assert_eq!(req.tcb_version, 100); + + // Test From<&mut DerivedKey> + let mut derived_key = derived_key; + let req: DerivedKeyReq = (&mut derived_key).into(); + assert_eq!(req.root_key_select, 0); + assert_eq!(req.reserved_0, 0); + assert_eq!(req.guest_field_select, 0x1234); + assert_eq!(req.vmpl, 2); + assert_eq!(req.guest_svn, 1); + assert_eq!(req.tcb_version, 100); + } + + #[test] + fn test_ext_report_req() { + let report_req = ReportReq::default(); + let ext_report = ExtReportReq::new(&report_req); + + assert_eq!(ext_report.data, report_req); + assert_eq!(ext_report.certs_address, u64::MAX); + assert_eq!(ext_report.certs_len, 0); + + // Test Default + let default_ext = ExtReportReq::default(); + assert_eq!(default_ext.certs_address, 0); + assert_eq!(default_ext.certs_len, 0); + } + + #[test] + fn test_report_req() { + // Test default values + let default_req = ReportReq::default(); + assert_eq!(default_req.report_data, [0; 64]); + assert_eq!(default_req.vmpl, 1); + assert_eq!(default_req._reserved, [0; 28]); + + // Test successful creation with Some values + let report_data = [42u8; 64]; + let req = ReportReq::new(Some(report_data), Some(2)).unwrap(); + assert_eq!(req.report_data, report_data); + assert_eq!(req.vmpl, 2); + + // Test successful creation with None values + let req = ReportReq::new(None, None).unwrap(); + assert_eq!(req.report_data, [0; 64]); + assert_eq!(req.vmpl, 1); + + // Test VMPL validation + assert!(ReportReq::new(None, Some(4)).is_err()); + assert!(ReportReq::new(None, Some(MAX_VMPL)).is_ok()); + } + + #[test] + fn test_report_rsp() { + let rsp = ReportRsp::default(); + + assert_eq!(rsp.status, 0); + assert_eq!(rsp.report_size, 0); + assert_eq!(rsp.reserved_0, [0; 24]); + + // Verify size is exactly 4000 bytes + assert_eq!(std::mem::size_of::(), 4000); + } + + #[test] + fn test_derived_key_rsp() { + let rsp = DerivedKeyRsp::default(); + + assert_eq!(rsp.status, 0); + assert_eq!(rsp.reserved_0, [0; 28]); + assert_eq!(rsp.key, [0; 32]); + } } diff --git a/src/firmware/linux/host/ioctl.rs b/src/firmware/linux/host/ioctl.rs index 64843c0b..8a666ed8 100644 --- a/src/firmware/linux/host/ioctl.rs +++ b/src/firmware/linux/host/ioctl.rs @@ -167,3 +167,53 @@ impl<'a, T: Id> Command<'a, T> { FirmwareError::from(self.error) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_command_get_id() { + let mut id = [0u8; 64]; + let mut data = GetId::new(&mut id); + let cmd = Command::::from_mut(&mut data); + let code = cmd.code; + let error = cmd.error; + assert_eq!(code, GetId::ID); + assert_eq!(error, 0); + } + + #[test] + fn test_command_platform_id() { + let mut data = PlatformStatus::default(); + let cmd = Command::::from_mut(&mut data); + let code = cmd.code; + let error = cmd.error; + assert_eq!(code, PlatformStatus::ID); + assert_eq!(error, 0); + } + + #[test] + fn test_command_platform_id_non_mut() { + let data = PlatformStatus::default(); + let cmd = Command::::from(&data); + let code = cmd.code; + let error = cmd.error; + assert_eq!(code, PlatformStatus::ID); + assert_eq!(error, 0); + } + + #[test] + fn test_command_error_encapsulation() { + // Test with success (0) + let cmd = Command:: { + code: PlatformStatus::ID, + error: 0, + data: 0, + _phantom: PhantomData, + }; + + let error = cmd.encapsulate(); + assert!(matches!(error, FirmwareError::IoError(_))); + } +} diff --git a/src/firmware/linux/host/types/mod.rs b/src/firmware/linux/host/types/mod.rs index 1d9de740..6a59e5f6 100644 --- a/src/firmware/linux/host/types/mod.rs +++ b/src/firmware/linux/host/types/mod.rs @@ -51,3 +51,38 @@ impl<'a> GetId<'a> { #[cfg(feature = "sev")] #[cfg(target_os = "linux")] pub struct PlatformReset; + +#[cfg(target_os = "linux")] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_id_new() { + let mut id = [0u8; 64]; + let get_id = GetId::new(&mut id); + + assert_eq!( + unsafe { std::ptr::addr_of!(get_id.id_len).read_unaligned() }, + 64 + ); + assert_eq!(get_id.id_addr as *const u8, id.as_ptr()); + } + + #[test] + fn test_get_id_slice() { + let mut id = [42u8; 64]; + let get_id = GetId::new(&mut id); + + assert_eq!(get_id.as_slice(), &[42u8; 64]); + } + + #[test] + fn test_get_id_phantom() { + let mut id = [0u8; 64]; + let get_id = GetId::new(&mut id); + + // Verify PhantomData is working as expected + assert_eq!(std::mem::size_of_val(&get_id._phantom), 0); + } +} diff --git a/src/firmware/linux/host/types/snp.rs b/src/firmware/linux/host/types/snp.rs index 9f81c41b..bcc7ad0f 100644 --- a/src/firmware/linux/host/types/snp.rs +++ b/src/firmware/linux/host/types/snp.rs @@ -316,6 +316,19 @@ impl<'a> std::convert::From<&WrappedVlekHashstick<'a>> for SnpVlekLoad { #[cfg(test)] mod test { + use crate::firmware::host::FFI::types::SnpSetConfig; + + #[test] + fn test_snp_set_config_default() { + let expected: SnpSetConfig = SnpSetConfig { + reported_tcb: Default::default(), + mask_id: Default::default(), + reserved: [0; 52], + }; + let actual: SnpSetConfig = Default::default(); + assert_eq!(expected, actual); + } + mod raw_data { use crate::firmware::linux::host::types::RawData; diff --git a/src/measurement/vmsa.rs b/src/measurement/vmsa.rs index 1aee43ff..5a5f05a7 100644 --- a/src/measurement/vmsa.rs +++ b/src/measurement/vmsa.rs @@ -171,7 +171,7 @@ bitfield! { ///SmtProtection pub smt_protection, _: 15,15; /// Reserved, SBZ - reserved_3, sbz: 16, 63; + reserved_3, sbz: 63, 16; } impl Default for GuestFeatures {