From d8473bb98d432b8bcbd11081b355dc8e32a7cb40 Mon Sep 17 00:00:00 2001 From: Mathieu <60658558+enitrat@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:08:31 +0300 Subject: [PATCH] refactor: List (#209) ## Pull Request type Please check the type of change your PR introduces: - [ ] Bugfix - [x] Feature - [x] Refactor with API Changes Issue Number: N/A ## What is the new behavior? Refactored the List type to: - Return SyscallResults instead of panicking in case of storage interaction errors Added features to make it easier to interact with low-level storage: - ListTrait::new() to instanciate a new list at a given storage address - ListTrait::fetch() to fetch a list stored at a given storage address - Tested these features Added documentation on List functions Optimized some functions with inlining ## Does this introduce a breaking change? - [x] Yes - [ ] No ## Other information --- src/storage/src/list.cairo | 239 ++++++++++++++++++++------ src/storage/src/tests/list_test.cairo | 189 +++++++++++++++----- 2 files changed, 340 insertions(+), 88 deletions(-) diff --git a/src/storage/src/list.cairo b/src/storage/src/list.cairo index 4379c899..3876206f 100644 --- a/src/storage/src/list.cairo +++ b/src/storage/src/list.cairo @@ -17,119 +17,260 @@ struct List { } trait ListTrait { + /// Instantiates a new List with the given base address. + /// + /// + /// # Arguments + /// + /// * `address_domain` - The domain of the address. Only address_domain 0 is + /// currently supported, in the future it will enable access to address + /// spaces with different data availability + /// * `base` - The base address of the List. This corresponds to the + /// location in storage of the List's first element. + /// + /// # Returns + /// + /// A new List. + fn new(address_domain: u32, base: StorageBaseAddress) -> List; + + /// Fetches an existing List stored at the given base address. + /// Returns an error if the storage read fails. + /// + /// # Arguments + /// + /// * `address_domain` - The domain of the address. Only address_domain 0 is + /// currently supported, in the future it will enable access to address + /// spaces with different data availability + /// * `base` - The base address of the List. This corresponds to the + /// location in storage of the List's first element. + /// + /// # Returns + /// + /// An instance of the List fetched from storage, or an error in + /// `SyscallResult`. + fn fetch(address_domain: u32, base: StorageBaseAddress) -> SyscallResult>; + + /// Appends an existing Span to a List. Returns an error if the span + /// cannot be appended to the a list due to storage errors + /// + /// # Arguments + /// + /// * `self` - The List to add the span to. + /// * `span` - A Span to append to the List. + /// + /// # Returns + /// + /// A List constructed from the span or an error in `SyscallResult`. + fn append_span(ref self: List, span: Span) -> SyscallResult<()>; + + /// Gets the length of the List. + /// + /// # Returns + /// + /// The number of elements in the List. fn len(self: @List) -> u32; + + /// Checks if the List is empty. + /// + /// # Returns + /// + /// `true` if the List is empty, `false` otherwise. fn is_empty(self: @List) -> bool; - fn append(ref self: List, value: T) -> u32; - fn get(self: @List, index: u32) -> Option; - fn set(ref self: List, index: u32, value: T); + + /// Appends a value to the end of the List. Returns an error if the append + /// operation fails due to reasons such as storage issues. + /// + /// # Arguments + /// + /// * `value` - The value to append. + /// + /// # Returns + /// + /// The index at which the value was appended or an error in `SyscallResult`. + fn append(ref self: List, value: T) -> SyscallResult; + + /// Retrieves an element by index from the List. Returns an error if there + /// is a retrieval issue. + /// + /// # Arguments + /// + /// * `index` - The index of the element to retrieve. + /// + /// # Returns + /// + /// An `Option` which is `None` if the list is empty, or + /// `Some(value)` if an element was found, encapsulated + /// in `SyscallResult`. + fn get(self: @List, index: u32) -> SyscallResult>; + + /// Sets the value of an element at a given index. + /// + /// # Arguments + /// + /// * `index` - The index of the element to modify. + /// * `value` - The value to set at the given index. + /// + /// # Returns + /// + /// A result indicating success or encapsulating the error in `SyscallResult`. + /// + /// # Panics + /// + /// Panics if the index is out of bounds. + fn set(ref self: List, index: u32, value: T) -> SyscallResult<()>; + + /// Clears the List by setting its length to 0. + /// + /// The storage is not actually cleared, only the length is set to 0. + /// The values can still be accessible using low-level syscalls, but cannot + /// be accessed through the list interface. fn clean(ref self: List); - fn pop_front(ref self: List) -> Option; - fn array(self: @List) -> Array; - fn from_array(ref self: List, array: @Array); - fn from_span(ref self: List, span: Span); + + /// Removes and returns the first element of the List. + /// + /// The storage is not actually cleared, only the length is decreased by + /// one. + /// The value popped can still be accessible using low-level syscalls, but + /// cannot be accessed through the list interface. + /// # Returns + /// + /// An `Option` which is `None` if the index is out of bounds, or + /// `Some(value)` if an element was found at the given index, encapsulated + /// in `SyscallResult`. + fn pop_front(ref self: List) -> SyscallResult>; + + /// Converts the List into an Array. If the list cannot be converted + /// to an array due storage errors, an error is returned. + /// + /// # Returns + /// + /// An `Array` containing all the elements of the List, encapsulated + /// in `SyscallResult`. + fn array(self: @List) -> SyscallResult>; } impl ListImpl, +Drop, +Store> of ListTrait { + #[inline(always)] + fn new(address_domain: u32, base: StorageBaseAddress) -> List { + let storage_size: u8 = Store::::size(); + List { address_domain, base, len: 0, storage_size } + } + + #[inline(always)] + fn fetch(address_domain: u32, base: StorageBaseAddress) -> SyscallResult> { + ListStore::read(address_domain, base) + } + + fn append_span(ref self: List, mut span: Span) -> SyscallResult<()> { + let mut index = self.len; + self.len += span.len(); + + loop { + match span.pop_front() { + Option::Some(v) => { + let (base, offset) = calculate_base_and_offset_for_index( + self.base, index, self.storage_size + ); + match Store::write_at_offset(self.address_domain, base, offset, *v) { + Result::Ok(_) => {}, + Result::Err(e) => { break Result::Err(e); } + } + index += 1; + }, + Option::None => { break Store::write(self.address_domain, self.base, self.len); } + }; + } + } + + #[inline(always)] fn len(self: @List) -> u32 { *self.len } + #[inline(always)] fn is_empty(self: @List) -> bool { *self.len == 0 } - fn append(ref self: List, value: T) -> u32 { + fn append(ref self: List, value: T) -> SyscallResult { let (base, offset) = calculate_base_and_offset_for_index( self.base, self.len, self.storage_size ); - Store::write_at_offset(self.address_domain, base, offset, value).unwrap_syscall(); + Store::write_at_offset(self.address_domain, base, offset, value)?; let append_at = self.len; self.len += 1; - Store::write(self.address_domain, self.base, self.len); + Store::write(self.address_domain, self.base, self.len)?; - append_at + Result::Ok(append_at) } - fn get(self: @List, index: u32) -> Option { + fn get(self: @List, index: u32) -> SyscallResult> { if (index >= *self.len) { - return Option::None; + return Result::Ok(Option::None); } let (base, offset) = calculate_base_and_offset_for_index( *self.base, index, *self.storage_size ); - let t = Store::read_at_offset(*self.address_domain, base, offset).unwrap_syscall(); - Option::Some(t) + let t = Store::read_at_offset(*self.address_domain, base, offset)?; + Result::Ok(Option::Some(t)) } - fn set(ref self: List, index: u32, value: T) { + fn set(ref self: List, index: u32, value: T) -> SyscallResult<()> { assert(index < self.len, 'List index out of bounds'); let (base, offset) = calculate_base_and_offset_for_index( self.base, index, self.storage_size ); - Store::write_at_offset(self.address_domain, base, offset, value).unwrap_syscall(); + Store::write_at_offset(self.address_domain, base, offset, value) } + #[inline(always)] fn clean(ref self: List) { self.len = 0; Store::write(self.address_domain, self.base, self.len); } - fn pop_front(ref self: List) -> Option { + fn pop_front(ref self: List) -> SyscallResult> { if self.len == 0 { - return Option::None; + return Result::Ok(Option::None); } - let popped = self.get(self.len - 1); + let popped = self.get(self.len - 1)?; // not clearing the popped value to save a storage write, // only decrementing the len - makes it unaccessible through // the interfaces, next append will overwrite the values self.len -= 1; - Store::write(self.address_domain, self.base, self.len); + Store::write(self.address_domain, self.base, self.len)?; - popped + Result::Ok(popped) } - fn array(self: @List) -> Array { + fn array(self: @List) -> SyscallResult> { let mut array = array![]; let mut index = 0; - loop { + let result: SyscallResult<()> = loop { if index == *self.len { - break; + break Result::Ok(()); } - array.append(self.get(index).expect('List index out of bounds')); + let value = match self.get(index) { + Result::Ok(v) => v, + Result::Err(e) => { break Result::Err(e); } + }.expect('List index out of bounds'); + array.append(value); index += 1; }; - array - } - - fn from_array(ref self: List, array: @Array) { - self.from_span(array.span()); - } - fn from_span(ref self: List, mut span: Span) { - let mut index = 0; - self.len = span.len(); - loop { - match span.pop_front() { - Option::Some(v) => { - let (base, offset) = calculate_base_and_offset_for_index( - self.base, index, self.storage_size - ); - Store::write_at_offset(self.address_domain, base, offset, *v).unwrap_syscall(); - index += 1; - }, - Option::None => { break; } - }; - }; - Store::write(self.address_domain, self.base, self.len); + match result { + Result::Ok(_) => Result::Ok(array), + Result::Err(e) => Result::Err(e) + } } } impl AListIndexViewImpl, +Drop, +Store> of IndexView, u32, T> { fn index(self: @List, index: u32) -> T { - self.get(index).expect('List index out of bounds') + self.get(index).expect('read syscall failed').expect('List index out of bounds') } } diff --git a/src/storage/src/tests/list_test.cairo b/src/storage/src/tests/list_test.cairo index efc7595b..291a98b5 100644 --- a/src/storage/src/tests/list_test.cairo +++ b/src/storage/src/tests/list_test.cairo @@ -15,7 +15,7 @@ trait IAListHolder { fn do_clean(ref self: TContractState); fn do_pop_front(ref self: TContractState) -> (Option, Option); fn do_array(self: @TContractState) -> (Array, Array); - fn do_from_array( + fn do_append_span( ref self: TContractState, addrs_array: Array, numbers_array: Array ); } @@ -29,7 +29,7 @@ mod AListHolder { struct Storage { // to test a corelib type that has Store and // Into - addrs: List, + addresses: List, // to test a corelib compound struct numbers: List } @@ -37,73 +37,101 @@ mod AListHolder { #[external(v0)] impl Holder of super::IAListHolder { fn do_get_len(self: @ContractState) -> (u32, u32) { - (self.addrs.read().len(), self.numbers.read().len()) + (self.addresses.read().len(), self.numbers.read().len()) } fn do_is_empty(self: @ContractState) -> (bool, bool) { - (self.addrs.read().is_empty(), self.numbers.read().is_empty()) + (self.addresses.read().is_empty(), self.numbers.read().is_empty()) } fn do_append( ref self: ContractState, addrs_value: ContractAddress, numbers_value: u256 ) -> (u32, u32) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); - (a.append(addrs_value), n.append(numbers_value)) + ( + a.append(addrs_value).expect('syscallresult error'), + n.append(numbers_value).expect('syscallresult error') + ) } fn do_get(self: @ContractState, index: u32) -> (Option, Option) { - (self.addrs.read().get(index), self.numbers.read().get(index)) + ( + self.addresses.read().get(index).expect('syscallresult error'), + self.numbers.read().get(index).expect('syscallresult error') + ) } fn do_get_index(self: @ContractState, index: u32) -> (ContractAddress, u256) { - (self.addrs.read()[index], self.numbers.read()[index]) + (self.addresses.read()[index], self.numbers.read()[index]) } fn do_set( ref self: ContractState, index: u32, addrs_value: ContractAddress, numbers_value: u256 ) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); a.set(index, addrs_value); n.set(index, numbers_value); } fn do_clean(ref self: ContractState) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); a.clean(); n.clean(); } fn do_pop_front(ref self: ContractState) -> (Option, Option) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); - (a.pop_front(), n.pop_front()) + ( + a.pop_front().expect('syscallresult error'), + n.pop_front().expect('syscallresult error') + ) } fn do_array(self: @ContractState) -> (Array, Array) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); - (a.array(), n.array()) + (a.array().expect('syscallresult error'), n.array().expect('syscallresult error')) } - fn do_from_array( + fn do_append_span( ref self: ContractState, addrs_array: Array, numbers_array: Array ) { - let mut a = self.addrs.read(); + let mut a = self.addresses.read(); let mut n = self.numbers.read(); - a.from_array(@addrs_array); - n.from_array(@numbers_array); + a.append_span(addrs_array.span()).expect('syscallresult error'); + n.append_span(numbers_array.span()).expect('syscallresult error'); } } } #[cfg(test)] mod tests { - use starknet::{ClassHash, ContractAddress, deploy_syscall, SyscallResultTrait}; + use AListHolder::{addressesContractMemberStateTrait, numbersContractMemberStateTrait}; + use alexandria_storage::list::{List, ListTrait}; + use debug::PrintTrait; + use starknet::{ + ClassHash, ContractAddress, deploy_syscall, SyscallResultTrait, + testing::set_contract_address, storage_address_from_base, storage_address_to_felt252, + storage_base_address_from_felt252, StorageBaseAddress + }; use super::{AListHolder, IAListHolderDispatcher, IAListHolderDispatcherTrait}; + impl StorageBaseAddressPartialEq of PartialEq { + fn eq(lhs: @StorageBaseAddress, rhs: @StorageBaseAddress) -> bool { + storage_address_to_felt252( + storage_address_from_base(*lhs) + ) == storage_address_to_felt252(storage_address_from_base(*rhs)) + } + + fn ne(lhs: @StorageBaseAddress, rhs: @StorageBaseAddress) -> bool { + !StorageBaseAddressPartialEq::eq(lhs, rhs) + } + } + fn deploy_mock() -> IAListHolderDispatcher { let class_hash: ClassHash = AListHolder::TEST_CLASS_HASH.try_into().unwrap(); let ctor_data: Array = Default::default(); @@ -122,6 +150,100 @@ mod tests { assert(contract.do_get_len() == (0, 0), 'do_get_len'); } + #[test] + #[available_gas(100000000)] + fn test_new_initializes_empty_list() { + let contract = deploy_mock(); + set_contract_address(contract.contract_address); + let mut contract_state = AListHolder::unsafe_new_contract_state(); + + let addresses_address = contract_state.addresses.address(); + let addresses_list = ListTrait::::new(0, addresses_address); + assert(addresses_list.address_domain == 0, 'Address domain should be 0'); + assert(addresses_list.len() == 0, 'Initial length should be 0'); + assert(addresses_list.base.into() == addresses_address, 'Base address mismatch'); + assert(addresses_list.storage_size == 1, 'Storage size should be 1'); + + let numbers_address = contract_state.numbers.address(); + let numbers_list = ListTrait::::new(0, numbers_address); + assert(numbers_list.address_domain == 0, 'Address domain should be 0'); + assert(numbers_list.len() == 0, 'Initial length should be 0'); + assert(numbers_list.base.into() == numbers_address, 'Base address mismatch'); + assert(numbers_list.storage_size == 2, 'Storage size should be 2'); + + // Check if both addresses and numbers lists are initialized to be empty + assert(contract.do_get_len() == (0, 0), 'Initial lengths should be 0'); + assert(contract.do_is_empty() == (true, true), 'Lists should be empty'); + } + + #[test] + #[available_gas(100000000)] + fn test_new_then_fill_list() { + let contract = deploy_mock(); + set_contract_address(contract.contract_address); + let mut contract_state = AListHolder::unsafe_new_contract_state(); + + let addresses_address = contract_state.addresses.address(); + let mut addresses_list = ListTrait::::new(0, addresses_address); + + let numbers_address = contract_state.numbers.address(); + let mut numbers_list = ListTrait::::new(0, numbers_address); + + addresses_list.append(mock_addr()); + numbers_list.append(1); + numbers_list.append(2); + + assert(addresses_list.len() == 1, 'Addresses length should be 1'); + assert(numbers_list.len() == 2, 'Numbers length should be 2'); + + assert(contract.do_get_len() == (1, 2), 'Lengths should be (1,2)'); + assert(contract.do_is_empty() == (false, false), 'Lists should not be empty'); + } + + #[test] + #[available_gas(100000000)] + fn test_fetch_empty_list() { + let contract = deploy_mock(); + set_contract_address(contract.contract_address); + let mut contract_state = AListHolder::unsafe_new_contract_state(); + let storage_address = storage_base_address_from_felt252('empty_address'); + let contract = deploy_mock(); + + let empty_list = ListTrait::::fetch(0, storage_address).expect('List fetch failed'); + + assert(empty_list.address_domain == 0, 'Address domain should be 0'); + assert(empty_list.len() == 0, 'Length should be 0'); + assert(empty_list.base.into() == storage_address, 'Base address mismatch'); + assert(empty_list.storage_size == 1, 'Storage size should be 1'); + } + + + #[test] + #[available_gas(100000000)] + fn test_fetch_existing_list() { + let contract = deploy_mock(); + set_contract_address(contract.contract_address); + let mut contract_state = AListHolder::unsafe_new_contract_state(); + let mock_addr = mock_addr(); + + assert(contract.do_append(mock_addr, 10) == (0, 0), '1st append idx'); + assert(contract.do_append(mock_addr, 20) == (1, 1), '2nd append idx'); + + let addresses_address = contract_state.addresses.address(); + let addresses_list = ListTrait::::fetch(0, addresses_address) + .expect('List fetch failed'); + assert(addresses_list.address_domain == 0, 'Address domain should be 0'); + assert(addresses_list.len() == 2, 'Length should be 2'); + assert(addresses_list.base.into() == addresses_address, 'Base address mismatch'); + assert(addresses_list.storage_size == 1, 'Storage size should be 1'); + + let numbers_address = contract_state.numbers.address(); + let numbers_list = ListTrait::::fetch(0, numbers_address).expect('List fetch failed'); + assert(numbers_list.address_domain == 0, 'Address domain should be 0'); + assert(numbers_list.len() == 2, 'Length should be 2'); + assert(numbers_list.base.into() == numbers_address, 'Base address mismatch'); + } + #[test] #[available_gas(100000000)] fn test_is_empty() { @@ -423,31 +545,16 @@ mod tests { #[test] #[available_gas(100000000)] - fn test_from_array() { - let contract = deploy_mock(); - let mock_addr = mock_addr(); - - let addrs_array = array![mock_addr, mock_addr, mock_addr]; - let numbers_array = array![200, 300, 100]; - contract.do_from_array(addrs_array, numbers_array); - assert(contract.do_get_len() == (3, 3), 'len should be 3'); - assert(contract.do_get_index(0) == (mock_addr, 200), 'idx 0'); - assert(contract.do_get_index(1) == (mock_addr, 300), 'idx 1'); - assert(contract.do_get_index(2) == (mock_addr, 100), 'idx 2'); - } - - #[test] - #[available_gas(100000000)] - fn test_from_array_empty() { + fn test_append_array_empty() { let contract = deploy_mock(); - contract.do_from_array(array![], array![]); + contract.do_append_span(array![], array![]); assert(contract.do_is_empty() == (true, true), 'should be empty'); } #[test] #[available_gas(100000000)] - fn test_from_array_remove_elements() { + fn test_append_span_existing_list() { let contract = deploy_mock(); let mock_addr = mock_addr(); @@ -457,8 +564,12 @@ mod tests { assert(contract.do_get_index(0) == (mock_addr, 10), 'idx 0'); assert(contract.do_get_index(1) == (mock_addr, 20), 'idx 1'); - contract.do_from_array(array![], array![]); + contract.do_append_span(array![mock_addr], array![30]); let (a, b) = contract.do_get_len(); - assert(contract.do_is_empty() == (true, true), 'should be empty'); + assert((a, b) == (3, 3), 'len'); + + assert(contract.do_get_index(0) == (mock_addr, 10), 'idx 0'); + assert(contract.do_get_index(1) == (mock_addr, 20), 'idx 1'); + assert(contract.do_get_index(2) == (mock_addr, 30), 'idx 2'); } }