diff --git a/cryptoki/src/context/mod.rs b/cryptoki/src/context/mod.rs index 160b9cb..cfa965e 100644 --- a/cryptoki/src/context/mod.rs +++ b/cryptoki/src/context/mod.rs @@ -3,13 +3,18 @@ //! Pkcs11 context and initialization types /// Directly get the PKCS #11 operation from the context structure and check for null pointers. +/// Note that this macro depends on the get_pkcs11_func! macro. macro_rules! get_pkcs11 { ($pkcs11:expr, $func_name:ident) => { - ($pkcs11 - .impl_ - .function_list - .$func_name - .ok_or(crate::error::Error::NullFunctionPointer)?) + (get_pkcs11_func!($pkcs11, $func_name).ok_or(crate::error::Error::NullFunctionPointer)?) + }; +} + +/// Same as get_pkcs11! but does not attempt to apply '?' syntactic sugar. +/// Suitable only if the caller can't return a Result. +macro_rules! get_pkcs11_func { + ($pkcs11:expr, $func_name:ident) => { + ($pkcs11.impl_.function_list.$func_name) }; } diff --git a/cryptoki/src/session/mod.rs b/cryptoki/src/session/mod.rs index 77d9f79..7e3bbb8 100644 --- a/cryptoki/src/session/mod.rs +++ b/cryptoki/src/session/mod.rs @@ -19,6 +19,7 @@ mod session_management; mod signing_macing; mod slot_token_management; +pub use object_management::ObjectHandleIterator; pub use session_info::{SessionInfo, SessionState}; /// Type that identifies a session diff --git a/cryptoki/src/session/object_management.rs b/cryptoki/src/session/object_management.rs index 350fce6..9c56eca 100644 --- a/cryptoki/src/session/object_management.rs +++ b/cryptoki/src/session/object_management.rs @@ -3,24 +3,277 @@ //! Object management functions use crate::context::Function; -use crate::error::{Result, Rv, RvError}; +use crate::error::{Error, Result, Rv, RvError}; use crate::object::{Attribute, AttributeInfo, AttributeType, ObjectHandle}; use crate::session::Session; use cryptoki_sys::*; use std::collections::HashMap; use std::convert::TryInto; +use std::num::NonZeroUsize; // Search 10 elements at a time -const MAX_OBJECT_COUNT: usize = 10; +// Safety: the value provided (10) must be non-zero +const MAX_OBJECT_COUNT: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(10) }; + +/// Iterator over object handles, in an active session. +/// +/// Used to iterate over the object handles returned by underlying calls to `C_FindObjects`. +/// The iterator is created by calling the `iter_objects` and `iter_objects_with_cache_size` methods on a `Session` object. +/// +/// # Note +/// +/// The iterator `new()` method will call `C_FindObjectsInit`. It means that until the iterator is dropped, +/// creating another iterator will result in an error (typically `RvError::OperationActive` ). +/// +/// # Example +/// +/// ```no_run +/// use cryptoki::context::CInitializeArgs; +/// use cryptoki::context::Pkcs11; +/// use cryptoki::error::Error; +/// use cryptoki::object::Attribute; +/// use cryptoki::object::AttributeType; +/// use cryptoki::session::UserType; +/// use cryptoki::types::AuthPin; +/// use std::env; +/// +/// # fn main() -> testresult::TestResult { +/// # let pkcs11 = Pkcs11::new( +/// # env::var("PKCS11_SOFTHSM2_MODULE") +/// # .unwrap_or_else(|_| "/usr/local/lib/libsofthsm2.so".to_string()), +/// # )?; +/// # +/// # pkcs11.initialize(CInitializeArgs::OsThreads)?; +/// # let slot = pkcs11.get_slots_with_token()?.remove(0); +/// # +/// # let session = pkcs11.open_ro_session(slot).unwrap(); +/// # session.login(UserType::User, Some(&AuthPin::new("fedcba".into())))?; +/// +/// let token_object = vec![Attribute::Token(true)]; +/// let wanted_attr = vec![AttributeType::Label]; +/// +/// for (idx, obj) in session.iter_objects(&token_object)?.enumerate() { +/// let obj = obj?; // handle potential error condition +/// +/// let attributes = session.get_attributes(obj, &wanted_attr)?; +/// +/// match attributes.get(0) { +/// Some(Attribute::Label(l)) => { +/// println!( +/// "token object #{}: handle {}, label {}", +/// idx, +/// obj, +/// String::from_utf8(l.to_vec()) +/// .unwrap_or_else(|_| "*** not valid utf8 ***".to_string()) +/// ); +/// } +/// _ => { +/// println!("token object #{}: handle {}, label not found", idx, obj); +/// } +/// } +/// } +/// # Ok(()) +/// # } +/// +/// ``` +#[derive(Debug)] +pub struct ObjectHandleIterator<'a> { + session: &'a Session, + object_count: usize, + index: usize, + cache: Vec, +} + +impl<'a> ObjectHandleIterator<'a> { + /// Create a new iterator over object handles. + /// + /// # Arguments + /// + /// * `session` - The session to iterate over + /// * `template` - The template to match objects against + /// * `cache_size` - The number of objects to cache (type is [`NonZeroUsize`]) + /// + /// # Returns + /// + /// This function will return a [`Result`] that can be used to iterate over the objects + /// matching the template. The cache size corresponds to the size of the array provided to `C_FindObjects()`. + /// + /// # Errors + /// + /// This function will return an error if the call to `C_FindObjectsInit` fails. + /// + /// # Note + /// + /// The iterator `new()` method will call `C_FindObjectsInit`. It means that until the iterator is dropped, + /// creating another iterator will result in an error (typically `RvError::OperationActive` ). + /// + fn new( + session: &'a Session, + mut template: Vec, + cache_size: NonZeroUsize, + ) -> Result { + unsafe { + Rv::from(get_pkcs11!(session.client(), C_FindObjectsInit)( + session.handle(), + template.as_mut_ptr(), + template.len().try_into()?, + )) + .into_result(Function::FindObjectsInit)?; + } + + let cache: Vec = vec![0; cache_size.get()]; + Ok(ObjectHandleIterator { + session, + object_count: cache_size.get(), + index: cache_size.get(), + cache, + }) + } +} + +// In this implementation, we use object_count to keep track of the number of objects +// returned by the last C_FindObjects call; the index is used to keep track of +// the next object in the cache to be returned. The size of cache is never changed. +// In order to enter the loop for the first time, we set object_count to cache_size +// and index to cache_size. That allows to jump directly to the C_FindObjects call +// and start filling the cache. + +impl<'a> Iterator for ObjectHandleIterator<'a> { + type Item = Result; + + fn next(&mut self) -> Option { + // since the iterator is initialized with object_count and index both equal and > 0, + // we are guaranteed to enter the loop at least once + while self.object_count > 0 { + // if index unsafe { + f( + self.session.handle(), + self.cache.as_mut_ptr(), + self.cache.len() as CK_ULONG, + &mut self.object_count as *mut usize as CK_ULONG_PTR, + ) + }, + None => { + // C_FindObjects() is not implemented,, bark and return an error + log::error!("C_FindObjects() is not implemented on this library"); + return Some(Err(Error::NullFunctionPointer) as Result); + } + }; + + if let Rv::Error(error) = Rv::from(p11rv) { + return Some( + Err(Error::Pkcs11(error, Function::FindObjects)) as Result + ); + } + } + None + } +} + +impl Drop for ObjectHandleIterator<'_> { + fn drop(&mut self) { + if let Some(f) = get_pkcs11_func!(self.session.client(), C_FindObjectsFinal) { + // swallow the return value, as we can't do anything about it, + // but log the error + if let Rv::Error(error) = Rv::from(unsafe { f(self.session.handle()) }) { + log::error!("C_FindObjectsFinal() failed with error: {:?}", error); + } + } else { + // bark but pass if C_FindObjectsFinal() is not implemented + log::error!("C_FindObjectsFinal() is not implemented on this library"); + } + } +} impl Session { + /// Iterate over session objects matching a template. + /// + /// # Arguments + /// + /// * `template` - The template to match objects against + /// + /// # Returns + /// + /// This function will return a [`Result`] that can be used to iterate over the objects + /// matching the template. Note that the cache size is managed internally and set to a default value (10) + /// + /// # See also + /// + /// * [`ObjectHandleIterator`] for more information on how to use the iterator + /// * [`Session::iter_objects_with_cache_size`] for a way to specify the cache size + #[inline(always)] + pub fn iter_objects(&self, template: &[Attribute]) -> Result { + self.iter_objects_with_cache_size(template, MAX_OBJECT_COUNT) + } + + /// Iterate over session objects matching a template, with cache size + /// + /// # Arguments + /// + /// * `template` - The template to match objects against + /// * `cache_size` - The number of objects to cache (type is [`NonZeroUsize`]) + /// + /// # Returns + /// + /// This function will return a [`Result`] that can be used to iterate over the objects + /// matching the template. The cache size corresponds to the size of the array provided to `C_FindObjects()`. + /// + /// # See also + /// + /// * [`ObjectHandleIterator`] for more information on how to use the iterator + /// * [`Session::iter_objects`] for a simpler way to iterate over objects + pub fn iter_objects_with_cache_size( + &self, + template: &[Attribute], + cache_size: NonZeroUsize, + ) -> Result { + let template: Vec = template.iter().map(Into::into).collect(); + ObjectHandleIterator::new(self, template, cache_size) + } + /// Search for session objects matching a template /// /// # Arguments - /// * `template` - A [Attribute] of search parameters that will be used - /// to find objects. /// - /// # Examples + /// * `template` - A reference to [Attribute] of search parameters that will be used + /// to find objects. + /// + /// # Returns + /// + /// Upon success, a vector of [`ObjectHandle`] wrapped in a Result. + /// Upon failure, the first error encountered. + /// + /// # Note + /// + /// It is a convenience method that will call [`Session::iter_objects`] and collect the results. + /// + /// # See also + /// + /// * [`Session::iter_objects`] for a way to specify the cache size + + /// # Example /// /// ```rust /// # fn main() -> testresult::TestResult { @@ -50,54 +303,10 @@ impl Session { /// } /// # Ok(()) } /// ``` + /// + #[inline(always)] pub fn find_objects(&self, template: &[Attribute]) -> Result> { - let mut template: Vec = template.iter().map(|attr| attr.into()).collect(); - - unsafe { - Rv::from(get_pkcs11!(self.client(), C_FindObjectsInit)( - self.handle(), - template.as_mut_ptr(), - template.len().try_into()?, - )) - .into_result(Function::FindObjectsInit)?; - } - - let mut object_handles = [0; MAX_OBJECT_COUNT]; - let mut object_count = MAX_OBJECT_COUNT as CK_ULONG; // set to MAX_OBJECT_COUNT to enter loop - let mut objects = Vec::new(); - - // as long as the number of objects returned equals the maximum number - // of objects that can be returned, we keep calling C_FindObjects - while object_count == MAX_OBJECT_COUNT as CK_ULONG { - unsafe { - Rv::from(get_pkcs11!(self.client(), C_FindObjects)( - self.handle(), - object_handles.as_mut_ptr() as CK_OBJECT_HANDLE_PTR, - MAX_OBJECT_COUNT.try_into()?, - &mut object_count, - )) - .into_result(Function::FindObjects)?; - } - - // exit loop, no more objects to be returned, no need to extend the objects vector - if object_count == 0 { - break; - } - - // extend the objects vector with the new objects - objects.extend_from_slice(&object_handles[..object_count.try_into()?]); - } - - unsafe { - Rv::from(get_pkcs11!(self.client(), C_FindObjectsFinal)( - self.handle(), - )) - .into_result(Function::FindObjectsFinal)?; - } - - let objects = objects.into_iter().map(ObjectHandle::new).collect(); - - Ok(objects) + self.iter_objects(template)?.collect() } /// Create a new object diff --git a/cryptoki/tests/basic.rs b/cryptoki/tests/basic.rs index 02a83a6..8368a54 100644 --- a/cryptoki/tests/basic.rs +++ b/cryptoki/tests/basic.rs @@ -9,11 +9,14 @@ use cryptoki::error::{Error, RvError}; use cryptoki::mechanism::aead::GcmParams; use cryptoki::mechanism::rsa::{PkcsMgfType, PkcsOaepParams, PkcsOaepSource}; use cryptoki::mechanism::{Mechanism, MechanismType}; -use cryptoki::object::{Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass}; +use cryptoki::object::{ + Attribute, AttributeInfo, AttributeType, KeyType, ObjectClass, ObjectHandle, +}; use cryptoki::session::{SessionState, UserType}; use cryptoki::types::AuthPin; use serial_test::serial; use std::collections::HashMap; +use std::num::NonZeroUsize; use std::thread; use cryptoki::mechanism::ekdf::AesCbcDeriveParams; @@ -311,15 +314,13 @@ fn get_token_info() -> TestResult { #[test] #[serial] -fn session_find_objects() { +fn session_find_objects() -> testresult::TestResult { let (pkcs11, slot) = init_pins(); // open a session - let session = pkcs11.open_rw_session(slot).unwrap(); + let session = pkcs11.open_rw_session(slot)?; // log in the session - session - .login(UserType::User, Some(&AuthPin::new(USER_PIN.into()))) - .unwrap(); + session.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))?; // we generate 11 keys with the same CKA_ID // we will check 3 different use cases, this will cover all cases for Session.find_objects @@ -349,19 +350,111 @@ fn session_find_objects() { Attribute::KeyType(KeyType::DES3), ]; - let mut found_keys = session.find_objects(&key_search_template).unwrap(); + let mut found_keys = session.find_objects(&key_search_template)?; assert_eq!(found_keys.len(), 11); // destroy one key - session.destroy_object(found_keys.pop().unwrap()).unwrap(); + session.destroy_object(found_keys.pop().unwrap())?; - let mut found_keys = session.find_objects(&key_search_template).unwrap(); + let mut found_keys = session.find_objects(&key_search_template)?; assert_eq!(found_keys.len(), 10); // destroy another key - session.destroy_object(found_keys.pop().unwrap()).unwrap(); - let found_keys = session.find_objects(&key_search_template).unwrap(); + session.destroy_object(found_keys.pop().unwrap())?; + let found_keys = session.find_objects(&key_search_template)?; assert_eq!(found_keys.len(), 9); + Ok(()) +} + +#[test] +#[serial] +fn session_objecthandle_iterator() -> testresult::TestResult { + let (pkcs11, slot) = init_pins(); + // open a session + let session = pkcs11.open_rw_session(slot)?; + + // log in the session + session.login(UserType::User, Some(&AuthPin::new(USER_PIN.into())))?; + + // we generate 11 keys with the same CKA_ID + + for i in 1..=11 { + let key_template = vec![ + Attribute::Token(true), + Attribute::Encrypt(true), + Attribute::Label(format!("key_{}", i).as_bytes().to_vec()), + Attribute::Id("12345678".as_bytes().to_vec()), // reusing the same CKA_ID + ]; + + // generate a secret key + session.generate_key(&Mechanism::Des3KeyGen, &key_template)?; + } + + // retrieve these keys using this template + let key_search_template = vec![ + Attribute::Token(true), + Attribute::Id("12345678".as_bytes().to_vec()), + Attribute::Class(ObjectClass::SECRET_KEY), + Attribute::KeyType(KeyType::DES3), + ]; + + // test iter_objects_with_cache_size() + // count keys with cache size of 20 + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(20).unwrap())?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 11); + + // count keys with cache size of 1 + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(1).unwrap())?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 11); + + // count keys with cache size of 10 + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 11); + + // fetch keys into a vector + let found_keys: Vec = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())? + .map_while(|key| key.ok()) + .collect(); + assert_eq!(found_keys.len(), 11); + + let key0 = found_keys[0]; + let key1 = found_keys[1]; + + session.destroy_object(key0).unwrap(); + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 10); + + // destroy another key + session.destroy_object(key1).unwrap(); + let found_keys = session + .iter_objects_with_cache_size(&key_search_template, NonZeroUsize::new(10).unwrap())?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 9); + + // test iter_objects() + let found_keys = session.iter_objects(&key_search_template)?; + let found_keys = found_keys.map_while(|key| key.ok()).count(); + assert_eq!(found_keys, 9); + + // test interleaved iterators - the second iterator should fail + let iter = session.iter_objects(&key_search_template); + let iter2 = session.iter_objects(&key_search_template); + + assert!(iter.is_ok()); + assert!(matches!( + iter2, + Err(Error::Pkcs11(RvError::OperationActive, _)) + )); + Ok(()) } #[test]