From 76d48b767bfe57bf0975694ab4471a968ad09536 Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Sun, 15 Aug 2021 21:37:37 +0300 Subject: [PATCH] Cairo v0.3.1. --- README.md | 4 +- src/demo/amm_demo/demo.py | 6 +- src/starkware/cairo/common/CMakeLists.txt | 2 + src/starkware/cairo/common/keccak.cairo | 33 +++ src/starkware/cairo/common/pow.cairo | 51 ++++ src/starkware/cairo/common/uint256.cairo | 248 +++++++++++++++++- src/starkware/cairo/lang/VERSION | 2 +- .../cairo/lang/ide/vscode-cairo/package.json | 2 +- src/starkware/cairo/lang/vm/cairo_runner.py | 15 +- .../cairo/lang/vm/cairo_runner_test.py | 55 ++++ .../cairo/lang/vm/memory_segments.py | 23 +- .../cairo/lang/vm/memory_segments_test.py | 17 ++ src/starkware/cairo/lang/vm/vm.py | 9 +- src/starkware/cairo/lang/vm/vm_test.py | 4 + src/starkware/cairo/sharp/config.json | 2 +- src/starkware/cairo/sharp/sharp_client.py | 2 +- src/starkware/starknet/cli/starknet_cli.py | 18 +- .../starknet/compiler/calldata_parser.py | 2 +- .../starknet/compiler/calldata_parser_test.py | 2 +- .../starknet/security/secure_hints.py | 3 + .../starknet/security/starknet_common.cairo | 5 + .../starknet/security/whitelists/latest.json | 235 +++++++++++++++-- 22 files changed, 692 insertions(+), 48 deletions(-) create mode 100644 src/starkware/cairo/common/keccak.cairo create mode 100644 src/starkware/cairo/common/pow.cairo diff --git a/README.md b/README.md index 59a03cc0..6cc354f6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.3.0) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.3.1) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.3.0.zip . +> docker cp ${container_id}:/app/cairo-lang-0.3.1.zip . > docker rm -v ${container_id} ``` diff --git a/src/demo/amm_demo/demo.py b/src/demo/amm_demo/demo.py index d273bdcb..ea41ba82 100644 --- a/src/demo/amm_demo/demo.py +++ b/src/demo/amm_demo/demo.py @@ -81,7 +81,7 @@ def deploy_contract(batch_prover: BatchProver, w3: Web3, operator: eth.Account) input( f'AMM demo smart contract successfully deployed to address {contract_address}. ' 'You can track the contract state through this link ' - f'https://ropsten.etherscan.io/address/{contract_address} .' + f'https://goerli.etherscan.io/address/{contract_address} .' 'Press enter to continue.') return w3.eth.contract(abi=abi, address=contract_address) @@ -102,7 +102,7 @@ def main(): # Connect to an Ethereum node. node_rpc_url = input( - 'Please provide an RPC URL to communicate with an Ethereum node on Ropsten: ') + 'Please provide an RPC URL to communicate with an Ethereum node on Goerli: ') w3 = Web3(HTTPProvider(node_rpc_url)) if not w3.isConnected(): print('Error: could not connect to the Ethereum node.') @@ -123,7 +123,7 @@ def main(): # Ask for funds to be transferred to the operator account id its balance is too low. if w3.eth.getBalance(operator.address) < MIN_OPERATOR_BALANCE: input( - f'Please send funds (at least {MIN_OPERATOR_BALANCE * 10**-18} Ropsten ETH) ' + f'Please send funds (at least {MIN_OPERATOR_BALANCE * 10**-18} Goerli ETH) ' f'to {operator.address} and press enter.') while w3.eth.getBalance(operator.address) < MIN_OPERATOR_BALANCE: print('Funds not received yet...') diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 3fcf6ac0..b4e35d0d 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -14,12 +14,14 @@ python_lib(cairo_common_lib hash_state.cairo hash.cairo invoke.cairo + keccak.cairo math_cmp.cairo math_utils.py math.cairo memcpy.cairo merkle_multi_update.cairo merkle_update.cairo + pow.cairo registers.cairo segments.cairo serialize.cairo diff --git a/src/starkware/cairo/common/keccak.cairo b/src/starkware/cairo/common/keccak.cairo new file mode 100644 index 00000000..da1e2f92 --- /dev/null +++ b/src/starkware/cairo/common/keccak.cairo @@ -0,0 +1,33 @@ +# Computes the keccak hash. +# This function is unsafe (not sound): there is no validity enforcement that the result is indeed +# keccak, but an honest prover will compute the keccak. +# Args: +# data - an array of words representing the input data. Each word in the array is 16 bytes of the +# input data, except the last word, which may be less. +# length - the number of bytes in the input. +func unsafe_keccak(data : felt*, length : felt) -> (low, high): + alloc_locals + local low + local high + %{ + from eth_hash.auto import keccak + data, length = ids.data, ids.length + + if '__keccak_max_size' in globals(): + assert length <= __keccak_max_size, \ + f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \ + f'Got: length={length}.' + + keccak_input = bytearray() + for word_i, byte_i in enumerate(range(0, length, 16)): + word = memory[data + word_i] + n_bytes = min(16, length - byte_i) + assert 0 <= word < 2 ** (8 * n_bytes) + keccak_input += word.to_bytes(n_bytes, 'big') + + hashed = keccak(keccak_input) + ids.high = int.from_bytes(hashed[:16], 'big') + ids.low = int.from_bytes(hashed[16:32], 'big') + %} + return (low=low, high=high) +end diff --git a/src/starkware/cairo/common/pow.cairo b/src/starkware/cairo/common/pow.cairo new file mode 100644 index 00000000..dc1d8359 --- /dev/null +++ b/src/starkware/cairo/common/pow.cairo @@ -0,0 +1,51 @@ +from starkware.cairo.common.math import assert_le +from starkware.cairo.common.registers import get_ap, get_fp_and_pc + +# Returns base ** exp, for 0 <= exp < 2**251. +func pow{range_check_ptr}(base, exp) -> (res): + struct LoopLocals: + member bit : felt + member temp0 : felt + + member res : felt + member base : felt + member exp : felt + end + + if exp == 0: + return (1) + end + + let initial_locs : LoopLocals* = cast(fp - 2, LoopLocals*) + initial_locs.res = 1; ap++ + initial_locs.base = base; ap++ + initial_locs.exp = exp; ap++ + + loop: + let prev_locs : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*) + let locs : LoopLocals* = cast(ap, LoopLocals*) + locs.base = prev_locs.base * prev_locs.base; ap++ + %{ ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1 %} + jmp odd if locs.bit != 0; ap++ + + even: + locs.exp = prev_locs.exp / 2; ap++ + locs.res = prev_locs.res; ap++ + # exp cannot be 0 here. + static_assert ap + 1 == locs + LoopLocals.SIZE + jmp loop; ap++ + + odd: + locs.temp0 = prev_locs.exp - 1 + locs.exp = locs.temp0 / 2; ap++ + locs.res = prev_locs.res * prev_locs.base; ap++ + static_assert ap + 1 == locs + LoopLocals.SIZE + jmp loop if locs.exp != 0; ap++ + + # Cap the number of steps. + let (__ap__) = get_ap() + let (__fp__, _) = get_fp_and_pc() + let n_steps = (__ap__ - cast(initial_locs, felt)) / LoopLocals.SIZE - 1 + assert_le(n_steps, 251) + return (res=locs.res) +end diff --git a/src/starkware/cairo/common/uint256.cairo b/src/starkware/cairo/common/uint256.cairo index c342dc93..bddb51b0 100644 --- a/src/starkware/cairo/common/uint256.cairo +++ b/src/starkware/cairo/common/uint256.cairo @@ -1,5 +1,7 @@ -from starkware.cairo.common.math import assert_nn_le, assert_not_zero +from starkware.cairo.common.math import assert_le, assert_nn_le, assert_not_zero from starkware.cairo.common.math_cmp import is_le +from starkware.cairo.common.pow import pow +from starkware.cairo.common.registers import get_ap, get_fp_and_pc # Represents an integer in the range [0, 2^256). struct Uint256: @@ -10,6 +12,7 @@ struct Uint256: end const SHIFT = 2 ** 128 +const ALL_ONES = 2 ** 128 - 1 const HALF_SHIFT = 2 ** 64 # Verifies that the given integer is valid. @@ -81,3 +84,246 @@ func uint256_mul{range_check_ptr}(a : Uint256, b : Uint256) -> (low : Uint256, h low=Uint256(low=res0 + HALF_SHIFT * res1, high=res2 + HALF_SHIFT * res3), high=Uint256(low=res4 + HALF_SHIFT * res5, high=res6 + HALF_SHIFT * carry)) end + +# Returns 1 if the first unsigned integer is less than the second unsigned integer. +func uint256_lt{range_check_ptr}(a : Uint256, b : Uint256) -> (res): + if a.high == b.high: + return is_le(a.low + 1, b.low) + end + return is_le(a.high + 1, b.high) +end + +# Returns 1 if the first signed integer is less than the second signed integer. +func uint256_signed_lt{range_check_ptr}(a : Uint256, b : Uint256) -> (res): + let (a, _) = uint256_add(a, cast((low=0, high=2 ** 127), Uint256)) + let (b, _) = uint256_add(b, cast((low=0, high=2 ** 127), Uint256)) + return uint256_lt(a, b) +end + +# Unsigned integer division between two integers. Returns the quotient and the remainder. +# Conforms to EVM specifications: division by 0 yields 0. +func uint256_unsigned_div_rem{range_check_ptr}(a : Uint256, div : Uint256) -> ( + quotient : Uint256, remainder : Uint256): + alloc_locals + local quotient : Uint256 + local remainder : Uint256 + + # If div == 0, return (0, 0). + if div.low + div.high == 0: + return (quotient=Uint256(0, 0), remainder=Uint256(0, 0)) + end + + %{ + a = (ids.a.high << 128) + ids.a.low + div = (ids.div.high << 128) + ids.div.low + quotient, remainder = divmod(a, div) + + ids.quotient.low = quotient & ((1 << 128) - 1) + ids.quotient.high = quotient >> 128 + ids.remainder.low = remainder & ((1 << 128) - 1) + ids.remainder.high = remainder >> 128 + %} + let (res_mul, carry) = uint256_mul(quotient, div) + assert carry = Uint256(0, 0) + + let (check_val, add_carry) = uint256_add(res_mul, remainder) + assert check_val = a + assert add_carry = 0 + + let (is_valid) = uint256_lt(remainder, div) + assert is_valid = 1 + return (quotient=quotient, remainder=remainder) +end + +# Returns the bitwise NOT of an integer. +func uint256_not{range_check_ptr}(a : Uint256) -> (res : Uint256): + return (Uint256(low=ALL_ONES - a.low, high=ALL_ONES - a.high)) +end + +# Returns the negation of an integer. +# Note that the negation of -2**255 is -2**255. +func uint256_neg{range_check_ptr}(a : Uint256) -> (res : Uint256): + let (not_num) = uint256_not(a) + let (res, _) = uint256_add(not_num, Uint256(low=1, high=0)) + return (res) +end + +# Conditionally negates an integer. +func uint256_cond_neg{range_check_ptr}(a : Uint256, should_neg) -> (res : Uint256): + if should_neg != 0: + return uint256_neg(a) + else: + return (res=a) + end +end + +# Signed integer division between two integers. Returns the quotient and the remainder. +# Conforms to EVM specifications. +# See ethereum yellow paper (https://ethereum.github.io/yellowpaper/paper.pdf, page 29). +# Note that the remainder may be negative if one of the inputs is negative and that +# (-2**255) / (-1) = -2**255 because 2*255 is out of range. +func uint256_signed_div_rem{range_check_ptr}(a : Uint256, div : Uint256) -> ( + quot : Uint256, rem : Uint256): + alloc_locals + + # When div=-1, simply return -a. + if div.low == SHIFT - 1: + if div.high == SHIFT - 1: + let (quot) = uint256_neg(a) + return (quot, cast((0, 0), Uint256)) + end + end + + # Take the absolute value of a. + let (local a_sign) = is_le(2 ** 127, a.high) + local range_check_ptr = range_check_ptr + let (local a) = uint256_cond_neg(a, should_neg=a_sign) + + # Take the absolute value of div. + let (local div_sign) = is_le(2 ** 127, div.high) + local range_check_ptr = range_check_ptr + let (div) = uint256_cond_neg(div, should_neg=div_sign) + + # Unsigned division. + let (local quot, local rem) = uint256_unsigned_div_rem(a, div) + local range_check_ptr = range_check_ptr + + # Fix the remainder according to the sign of a. + let (rem) = uint256_cond_neg(rem, should_neg=a_sign) + + # Fix the quotient according to the signs of a and div. + if a_sign == div_sign: + return (quot=quot, rem=rem) + end + let (local quot_neg) = uint256_neg(quot) + + return (quot=quot_neg, rem=rem) +end + +# Subtracts two integers. Returns the result as a 256-bit integer. +func uint256_sub{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + let (b_neg) = uint256_neg(b) + let (res, _) = uint256_add(a, b_neg) + return (res) +end + +# Bitwise. + +# Computes the bitwise XOR of 2 n-bit words. +# This is an inefficient implementation, and will be replaced with a builtin in the future. +func felt_xor{range_check_ptr}(a, b, n) -> (res : felt): + alloc_locals + local a_lsb + local b_lsb + + if n == 0: + assert a = 0 + assert b = 0 + return (0) + end + + %{ + ids.a_lsb = ids.a & 1 + ids.b_lsb = ids.b & 1 + %} + assert a_lsb * a_lsb = a_lsb + assert b_lsb * b_lsb = b_lsb + + local res_bit = a_lsb + b_lsb - 2 * a_lsb * b_lsb + + let (res) = felt_xor((a - a_lsb) / 2, (b - b_lsb) / 2, n - 1) + return (res=res * 2 + res_bit) +end + +# Return true if both integers are equal. +func uint256_eq{range_check_ptr}(a : Uint256, b : Uint256) -> (res): + if a.high != b.high: + return (0) + end + if a.low != b.low: + return (0) + end + return (1) +end + +# Computes the bitwise AND of 2 n-bit words. +# This is an inefficient implementation, and will be replaced with a builtin in the future. +func felt_and{range_check_ptr}(a, b, n) -> (res : felt): + alloc_locals + local a_lsb + local b_lsb + + if n == 0: + assert a = 0 + assert b = 0 + return (res=0) + end + + %{ + ids.a_lsb = ids.a & 1 + ids.b_lsb = ids.b & 1 + %} + assert a_lsb * a_lsb = a_lsb + assert b_lsb * b_lsb = b_lsb + + local res_bit = a_lsb * b_lsb + + let (res) = felt_and((a - a_lsb) / 2, (b - b_lsb) / 2, n - 1) + return (res=res * 2 + res_bit) +end + +# Computes the bitwise XOR of 2 uint256 integers. +func uint256_xor{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + alloc_locals + let (local low) = felt_xor(a.low, b.low, 128) + let (high) = felt_xor(a.high, b.high, 128) + return (Uint256(low, high)) +end + +# Computes the bitwise AND of 2 uint256 integers. +func uint256_and{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + alloc_locals + let (local low) = felt_and(a.low, b.low, 128) + let (high) = felt_and(a.high, b.high, 128) + return (Uint256(low, high)) +end + +# Computes the bitwise OR of 2 uint256 integers. +func uint256_or{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + let (a) = uint256_not(a) + let (b) = uint256_not(b) + let (res) = uint256_and(a, b) + return uint256_not(res) +end + +# Computes 2**exp % 2**256 as a uint256 integer. +func uint256_pow2{range_check_ptr}(exp : Uint256) -> (res : Uint256): + # If exp >= 256, the result will be zero modulo 2**256. + let (res) = uint256_lt(exp, Uint256(256, 0)) + if res == 0: + return (Uint256(0, 0)) + end + + let (res) = is_le(exp.low, 127) + if res != 0: + let (x) = pow(2, exp.low) + return (Uint256(x, 0)) + else: + let (x) = pow(2, exp.low - 128) + return (Uint256(0, x)) + end +end + +# Computes the logical left shift of a uint256 integer. +func uint256_shl{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + let (c) = uint256_pow2(b) + let (res, _) = uint256_mul(a, c) + return (res) +end + +# Computes the logical right shift of a uint256 integer. +func uint256_shr{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): + let (c) = uint256_pow2(b) + let (res, _) = uint256_unsigned_div_rem(a, c) + return (res) +end diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 0d91a54c..9e11b32f 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.3.0 +0.3.1 diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index 4244a1cc..a2bebabf 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.3.0", + "version": "0.3.1", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index 69f9d705..6c811fd2 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner @@ -106,6 +106,9 @@ def __init__( # Flags used to ensure a safe use. self._run_ended: bool = False self._segments_finalized: bool = False + # A set of memory addresses accessed by the VM, after relocation of temporary segments into + # real ones. + self.accessed_addresses: Optional[Set[RelocatableValue]] = None @classmethod def from_file( @@ -268,6 +271,8 @@ def run_until_next_power_of_2(self): def end_run(self, disable_trace_padding: bool = True, disable_finalize_all: bool = False): assert not self._run_ended, 'end_run called twice.' + self.accessed_addresses = { + self.vm_memory.relocate_value(addr) for addr in self.vm.accessed_addresses} self.vm_memory.relocate_memory() self.vm.end_run() @@ -380,6 +385,10 @@ def check_range_check_usage(self): f'There are only {unused_rc_units} cells to fill the range checks holes, but ' f'potentially {rc_usage_upper_bound} are required.') + def get_memory_holes(self): + assert self.accessed_addresses is not None + return self.segments.get_memory_holes(accessed_addresses=self.accessed_addresses) + def check_memory_usage(self): """ Checks that there are enough trace cells to fill the entire memory range. @@ -395,7 +404,7 @@ def check_memory_usage(self): instruction_memory_units = 4 * self.vm.current_step unused_memory_units = total_memory_units - \ (public_memory_units + instruction_memory_units + builtins_memory_units) - memory_address_holes = self.segments.get_memory_holes() + memory_address_holes = self.get_memory_holes() if unused_memory_units < memory_address_holes: raise InsufficientAllocatedCells( f'There are only {unused_memory_units} cells to fill the memory address holes, but ' @@ -577,7 +586,7 @@ def get_builtin_segments_info(self): def get_execution_resources(self) -> ExecutionResources: n_steps = len(self.vm.trace) if self.original_steps is None else self.original_steps - n_memory_holes = self.segments.get_memory_holes() + n_memory_holes = self.get_memory_holes() builtin_instance_counter = { builtin_name: builtin_runner.get_used_instances(self) for builtin_name, builtin_runner in self.builtin_runners.items() diff --git a/src/starkware/cairo/lang/vm/cairo_runner_test.py b/src/starkware/cairo/lang/vm/cairo_runner_test.py index 64a68d9a..fba478c1 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner_test.py +++ b/src/starkware/cairo/lang/vm/cairo_runner_test.py @@ -144,3 +144,58 @@ def test_memory_hole_insufficient(): match=re.escape( 'There are only 8 cells to fill the memory address holes, but 999 are required.')): runner.check_memory_usage() + + +def test_hint_memory_holes(): + code_base_format = """\ +func main(): + [ap] = 0 + %{{ + memory[fp + 1] = segments.add_temp_segment() + %}} + [[fp + 1]] = [ap] + ap += 7 + {} + ap += 1 + [ap] = 0 + %{{ + memory.add_relocation_rule(memory[fp + 1], fp + 3) + %}} + ret +end +""" + code_no_hint, code_untouched_hint, code_touched_hint = [ + code_base_format.format(extra_code) + for extra_code in ['', '%{ memory[ap] = 7 %}', '%{ memory[ap] = 7 %}\n [ap]=[ap]']] + + runner_no_hint, runner_untouched_hint, runner_touched_hint = [ + get_runner_from_code(code, layout='plain', prime=PRIME) + for code in (code_no_hint, code_untouched_hint, code_touched_hint)] + + def filter_program_segment(addr_lst): + return {addr for addr in addr_lst if addr.segment_index != 0} + + initial_ap = runner_no_hint.initial_ap + accessed_addresses = { + # Return fp and pc. + initial_ap - 2, + initial_ap - 1, + # Values set in the function. + initial_ap, + initial_ap + 1, + initial_ap + 3, + initial_ap + 8, + } + assert filter_program_segment(runner_no_hint.vm_memory.keys()) == accessed_addresses + assert filter_program_segment(runner_no_hint.accessed_addresses) == accessed_addresses + assert filter_program_segment(runner_untouched_hint.vm_memory.keys()) == \ + accessed_addresses | {initial_ap + 7} + assert filter_program_segment(runner_untouched_hint.accessed_addresses) == accessed_addresses + assert filter_program_segment(runner_touched_hint.vm_memory.keys()) == \ + accessed_addresses | {initial_ap + 7} + assert filter_program_segment(runner_touched_hint.accessed_addresses) == \ + accessed_addresses | {initial_ap + 7} + + assert runner_no_hint.get_memory_holes() == \ + runner_untouched_hint.get_memory_holes() == \ + runner_touched_hint.get_memory_holes() + 1 == 5 diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index 8b4000d7..3c573818 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -152,20 +152,27 @@ def write_arg(self, ptr, arg, apply_modulo_to_args=True): data = [self.gen_arg(arg=x, apply_modulo_to_args=apply_modulo_to_args) for x in arg] return self.load_data(ptr, data) - def get_memory_holes(self) -> int: + def get_memory_holes(self, accessed_addresses: Set[MaybeRelocatable]) -> int: """ Returns the total number of memory holes in all segments. """ - used_offsets_sets: Dict[int, Set] = defaultdict(set) - for addr in self.memory.keys(): + # A map from segment index to the set of accessed offsets. + accessed_offsets_sets: Dict[int, Set] = defaultdict(set) + for addr in accessed_addresses: assert isinstance(addr, RelocatableValue), \ f'Expected memory address to be relocatable value. Found: {addr}.' - assert addr.offset >= 0, \ - f'Address offsets must be non-negative. Found: {addr.offset}.' - used_offsets_sets[addr.segment_index].add(addr.offset) + index, offset = addr.segment_index, addr.offset + assert offset >= 0, f'Address offsets must be non-negative. Found: {offset}.' + assert offset <= self.get_segment_size(segment_index=index), \ + f'Accessed address {addr} has higher offset than the maximal offset ' \ + f'{self.get_segment_size(segment_index=index)} encountered in the memory segment.' + accessed_offsets_sets[index].add(offset) + + assert self._segment_used_sizes is not None, \ + 'compute_effective_sizes must be called before get_memory_holes.' return sum( - max(used_offsets) + 1 - len(used_offsets) - for used_offsets in used_offsets_sets.values()) + self.get_segment_size(segment_index=index) - len(accessed_offsets_sets[index]) + for index in self._segment_sizes.keys() | self._segment_used_sizes.keys()) def get_segment_used_size(self, segment_index: int) -> int: assert self._segment_used_sizes is not None, \ diff --git a/src/starkware/cairo/lang/vm/memory_segments_test.py b/src/starkware/cairo/lang/vm/memory_segments_test.py index 41386ddd..b15ad3d4 100644 --- a/src/starkware/cairo/lang/vm/memory_segments_test.py +++ b/src/starkware/cairo/lang/vm/memory_segments_test.py @@ -78,3 +78,20 @@ def test_gen_args(): assert memory[ptr] == 1 memory.get_range(memory[ptr + 1], len(test_array)) == test_array memory.get_range(memory[ptr + 2], 2) == [4, PRIME - 1] + + +def test_get_memory_holes(): + segments = MemorySegmentManager(memory=MemoryDict({}), prime=PRIME) + seg0 = segments.add(size=10) + seg1 = segments.add() + + accessed_addresses = {seg0, seg1, seg0 + 1, seg1 + 5} + # Since segment 1 has no specified size, we must set a memory entry directly. + segments.memory[seg1 + 5] = 0 + + segments.memory.relocate_memory() + segments.memory.freeze() + segments.compute_effective_sizes() + seg0_holes = 10 - 2 + seg1_holes = 6 - 2 + assert segments.get_memory_holes(accessed_addresses) == seg0_holes + seg1_holes diff --git a/src/starkware/cairo/lang/vm/vm.py b/src/starkware/cairo/lang/vm/vm.py index b90c9e9a..c64a4387 100644 --- a/src/starkware/cairo/lang/vm/vm.py +++ b/src/starkware/cairo/lang/vm/vm.py @@ -4,7 +4,7 @@ import sys import traceback from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple from starkware.cairo.lang.compiler.debug_info import DebugInfo, InstructionLocation from starkware.cairo.lang.compiler.encode import decode_instruction, is_call_instruction @@ -251,6 +251,10 @@ def __init__( self.program = program self.program_base = program_base if program_base is not None else self.run_context.pc self.validated_memory = ValidatedMemoryDict(memory=self.run_context.memory) + # A set to track the memory addresses accessed by actual Cairo instructions (as opposed to + # hints), necessary for accurate counting of memory holes. + self.accessed_addresses: Set[MaybeRelocatable] = { + self.program_base + i for i in range(len(self.program.data))} # If program is a StrippedProgram, there are no hints or debug information to load. if isinstance(program, Program): @@ -662,6 +666,9 @@ def run_instruction(self, instruction, instruction_encoding): fp=self.run_context.fp, )) + self.accessed_addresses.update(operands_mem_addresses) + self.accessed_addresses.add(self.run_context.pc) + try: # Update registers. self.update_registers(instruction, operands) diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index 9b85a402..4f1adc04 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -68,6 +68,8 @@ def test_simple(): vm = run_single(code, 9, pc=10, ap=102, extra_mem={101: 1}) assert [vm.run_context.memory[101 + i] for i in range(7)] == [1, 3, 9, 10, 16, 48, 10] + assert vm.accessed_addresses == \ + set(vm.run_context.memory.keys()) == {*range(10, 28), 99, *range(101, 108)} def test_jnz(): @@ -219,6 +221,8 @@ def test_hints(): vm.step() assert [vm.run_context.memory[202 + i] for i in range(3)] == [2000, 1000, 1234] + # Check that address fp + 2, whose value was only set in a hint, is not counted as accessed. + assert [202 + i in vm.accessed_addresses for i in range(3)] == [True, True, False] def test_hint_between_references(): diff --git a/src/starkware/cairo/sharp/config.json b/src/starkware/cairo/sharp/config.json index a239ed49..c16073d1 100644 --- a/src/starkware/cairo/sharp/config.json +++ b/src/starkware/cairo/sharp/config.json @@ -1,5 +1,5 @@ { "prover_url": "https://ropsten-v1.provingservice.io", - "verifier_address": "0x2886D2A190f00aA324Ac5BF5a5b90217121D5756", + "verifier_address": "0xAB43bA48c9edF4C2C4bB01237348D1D7B28ef168", "steps_limit": 1000000 } diff --git a/src/starkware/cairo/sharp/sharp_client.py b/src/starkware/cairo/sharp/sharp_client.py index fc4de2af..03c21147 100755 --- a/src/starkware/cairo/sharp/sharp_client.py +++ b/src/starkware/cairo/sharp/sharp_client.py @@ -221,7 +221,7 @@ def is_verified(args, command_args): description='Verify a fact is registered on the SHARP fact-registry.') parser.add_argument('fact', type=str, help='The fact to verify if registered.') parser.add_argument( - '--node_url', required=True, type=str, help='URL for a Ropsten Ethereum node RPC API.') + '--node_url', required=True, type=str, help='URL for a Goerli Ethereum node RPC API.') parser.parse_args(command_args, namespace=args) diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 4aac71b2..27f401dd 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -110,25 +110,27 @@ async def invoke_or_call(args, command_args, call: bool): current_inputs_ptr = 0 for input_desc in abi_entry['inputs']: if input_desc['type'] == 'felt': - assert current_inputs_ptr < len(args.inputs), \ - f'Expected at least {current_inputs_ptr + 1} inputs, got {len(args.inputs)}' + assert current_inputs_ptr < len(args.inputs), ( + f'Expected at least {current_inputs_ptr + 1} inputs, ' + f'got {len(args.inputs)}.') + previous_felt_input = args.inputs[current_inputs_ptr] current_inputs_ptr += 1 elif input_desc['type'] == 'felt*': - assert previous_felt_input is not None, \ - f'The array argument {input_desc["name"]} of type felt* must be preceded ' \ - 'by a length argument of type felt.' + assert previous_felt_input is not None, ( + f'The array argument {input_desc["name"]} of type felt* must be preceded ' + 'by a length argument of type felt.') current_inputs_ptr += previous_felt_input previous_felt_input = None else: - raise Exception(f'Unsupported type {input_desc["type"]}') + raise Exception(f'Type {input_desc["type"]} is not supported.') break else: raise Exception(f'Function {args.function} not found.') selector = get_selector_from_name(args.function) - assert len(args.inputs) == current_inputs_ptr, \ - f'Wrong number of arguments. Expected {current_inputs_ptr}, got {len(args.inputs)}.' + assert len(args.inputs) == current_inputs_ptr, ( + f'Wrong number of arguments. Expected {current_inputs_ptr}, got {len(args.inputs)}.') calldata = args.inputs tx = InvokeFunction(contract_address=address, entry_point_selector=selector, calldata=calldata) diff --git a/src/starkware/starknet/compiler/calldata_parser.py b/src/starkware/starknet/compiler/calldata_parser.py index cdbb272d..ffba7ba8 100644 --- a/src/starkware/starknet/compiler/calldata_parser.py +++ b/src/starkware/starknet/compiler/calldata_parser.py @@ -48,7 +48,7 @@ def parse_code_element(code: str, parent_location: ParentLocation): isinstance(prev_member[1].cairo_type, TypeFelt) if not has_len: raise PreprocessorError( - f'Array argument "{member_name}" must be preceeded by a length argument ' + f'Array argument "{member_name}" must be preceded by a length argument ' f'named "{member_name}_len" of type felt.', location=member_location) if not has_range_check_builtin: diff --git a/src/starkware/starknet/compiler/calldata_parser_test.py b/src/starkware/starknet/compiler/calldata_parser_test.py index 3dd84763..8c24beb3 100644 --- a/src/starkware/starknet/compiler/calldata_parser_test.py +++ b/src/starkware/starknet/compiler/calldata_parser_test.py @@ -80,7 +80,7 @@ def test_process_calldata_failure(): 'arg_b': MemberDefinition(offset=1, cairo_type=TypeFelt(), location=location), }) with pytest.raises( - PreprocessorError, match='Array argument "arg_a" must be preceeded by a length ' + PreprocessorError, match='Array argument "arg_a" must be preceded by a length ' 'argument named "arg_a_len" of type felt.'): process_test_calldata(members={ 'arg_a': MemberDefinition(offset=0, cairo_type=FELT_STAR, location=location), diff --git a/src/starkware/starknet/security/secure_hints.py b/src/starkware/starknet/security/secure_hints.py index 4d905db2..57b765d0 100644 --- a/src/starkware/starknet/security/secure_hints.py +++ b/src/starkware/starknet/security/secure_hints.py @@ -1,3 +1,4 @@ +import re from dataclasses import field from typing import ClassVar, Dict, List, Set, Type @@ -125,6 +126,8 @@ def _get_hint_reference_expressions( Set[NamedExpression]: ref_exprs: Set[NamedExpression] = set() for ref_name, ref_id in hint.flow_tracking_data.reference_ids.items(): + if re.match('^__temp[0-9]+$', ref_name.path[-1]): + continue ref = reference_manager.get_ref(ref_id) ref_exprs.add(NamedExpression(name=str(ref_name), expr=ref.value.format())) return ref_exprs diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index 6f459745..531239e3 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -2,6 +2,7 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.default_dict import default_dict_finalize, default_dict_new from starkware.cairo.common.dict import dict_read, dict_squash, dict_update, dict_write from starkware.cairo.common.find_element import find_element, search_sorted, search_sorted_lower +from starkware.cairo.common.keccak import unsafe_keccak from starkware.cairo.common.math import ( abs_value, assert_250_bit, assert_in_range, assert_le, assert_le_felt, assert_lt, assert_lt_felt, assert_nn, assert_nn_le, assert_not_equal, assert_not_zero, sign, @@ -11,4 +12,8 @@ from starkware.cairo.common.math_cmp import ( from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.signature import verify_ecdsa_signature from starkware.cairo.common.squash_dict import squash_dict +from starkware.cairo.common.uint256 import ( + uint256_add, uint256_and, uint256_cond_neg, uint256_eq, uint256_lt, uint256_mul, uint256_neg, + uint256_not, uint256_or, uint256_shl, uint256_shr, uint256_signed_div_rem, uint256_signed_lt, + uint256_sub, uint256_unsigned_div_rem, uint256_xor) from starkware.starknet.common.storage import normalize_address, storage_read, storage_write diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index d57cadd1..2c6562af 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -28,10 +28,6 @@ }, { "allowed_expressions": [ - { - "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", - "name": "starkware.cairo.common.dict.dict_squash.__temp31" - }, { "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_end" @@ -114,6 +110,40 @@ "ids.is_small = 1 if ids.addr < ADDR_BOUND else 0" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-6), starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_unsigned_div_rem.a" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_unsigned_div_rem.div" + }, + { + "expr": "[cast(fp, starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_unsigned_div_rem.quotient" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.uint256.uint256_unsigned_div_rem.range_check_ptr" + }, + { + "expr": "[cast(fp + 2, starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_unsigned_div_rem.remainder" + } + ], + "hint_lines": [ + "a = (ids.a.high << 128) + ids.a.low", + "div = (ids.div.high << 128) + ids.div.low", + "quotient, remainder = divmod(a, div)", + "", + "ids.quotient.low = quotient & ((1 << 128) - 1)", + "ids.quotient.high = quotient >> 128", + "ids.remainder.low = remainder & ((1 << 128) - 1)", + "ids.remainder.high = remainder >> 128" + ] + }, { "allowed_expressions": [ { @@ -611,6 +641,46 @@ "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "starkware.cairo.common.keccak.unsafe_keccak.data" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.keccak.unsafe_keccak.high" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.keccak.unsafe_keccak.length" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.keccak.unsafe_keccak.low" + } + ], + "hint_lines": [ + "from eth_hash.auto import keccak", + "data, length = ids.data, ids.length", + "", + "if '__keccak_max_size' in globals():", + " assert length <= __keccak_max_size, \\", + " f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\", + " f'Got: length={length}.'", + "", + "keccak_input = bytearray()", + "for word_i, byte_i in enumerate(range(0, length, 16)):", + " word = memory[data + word_i]", + " n_bytes = min(16, length - byte_i)", + " assert 0 <= word < 2 ** (8 * n_bytes)", + " keccak_input += word.to_bytes(n_bytes, 'big')", + "", + "hashed = keccak(keccak_input)", + "ids.high = int.from_bytes(hashed[:16], 'big')", + "ids.low = int.from_bytes(hashed[16:32], 'big')" + ] + }, { "allowed_expressions": [ { @@ -877,6 +947,62 @@ "assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.'" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.a" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.a_lsb" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.b" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.b_lsb" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.n" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.uint256.felt_and.range_check_ptr" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.a" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.a_lsb" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.b" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.b_lsb" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.n" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.uint256.felt_xor.range_check_ptr" + } + ], + "hint_lines": [ + "ids.a_lsb = ids.a & 1", + "ids.b_lsb = ids.b & 1" + ] + }, { "allowed_expressions": [ { @@ -943,6 +1069,37 @@ "ids.is_250 = 1 if ids.addr < 2**250 else 0" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.pow.pow.base" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.pow.pow.exp" + }, + { + "expr": "cast(fp + (-2), starkware.cairo.common.pow.pow.LoopLocals*)", + "name": "starkware.cairo.common.pow.pow.initial_locs" + }, + { + "expr": "cast(ap, starkware.cairo.common.pow.pow.LoopLocals*)", + "name": "starkware.cairo.common.pow.pow.locs" + }, + { + "expr": "cast(ap + (-5), starkware.cairo.common.pow.pow.LoopLocals*)", + "name": "starkware.cairo.common.pow.pow.prev_locs" + }, + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.pow.pow.range_check_ptr" + } + ], + "hint_lines": [ + "ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1" + ] + }, { "allowed_expressions": [ { @@ -1022,6 +1179,30 @@ "ids.loop_temps.should_continue = 1 if current_access_indices else 0" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.uint256.split_64.a" + }, + { + "expr": "[cast(fp + 1, felt*)]", + "name": "starkware.cairo.common.uint256.split_64.high" + }, + { + "expr": "[cast(fp, felt*)]", + "name": "starkware.cairo.common.uint256.split_64.low" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.uint256.split_64.range_check_ptr" + } + ], + "hint_lines": [ + "ids.low = ids.a & ((1<<64) - 1)", + "ids.high = ids.a >> 64" + ] + }, { "allowed_expressions": [ { @@ -1168,10 +1349,6 @@ }, { "allowed_expressions": [ - { - "expr": "[cast(ap + (-1), felt*)]", - "name": "starkware.cairo.common.memcpy.memcpy.__temp46" - }, { "expr": "[cast(ap, felt*)]", "name": "starkware.cairo.common.memcpy.memcpy.continue_copying" @@ -1279,6 +1456,40 @@ "current_access_index = new_access_index" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-6), starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_add.a" + }, + { + "expr": "[cast(fp + (-4), starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_add.b" + }, + { + "expr": "[cast(fp + 3, felt*)]", + "name": "starkware.cairo.common.uint256.uint256_add.carry_high" + }, + { + "expr": "[cast(fp + 2, felt*)]", + "name": "starkware.cairo.common.uint256.uint256_add.carry_low" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.uint256.uint256_add.range_check_ptr" + }, + { + "expr": "[cast(fp, starkware.cairo.common.uint256.Uint256*)]", + "name": "starkware.cairo.common.uint256.uint256_add.res" + } + ], + "hint_lines": [ + "sum_low = ids.a.low + ids.b.low", + "ids.carry_low = 1 if sum_low >= ids.SHIFT else 0", + "sum_high = ids.a.high + ids.b.high + ids.carry_low", + "ids.carry_high = 1 if sum_high >= ids.SHIFT else 0" + ] + }, { "allowed_expressions": [ { @@ -1327,10 +1538,6 @@ }, { "allowed_expressions": [ - { - "expr": "[cast(ap + (-1), starkware.cairo.common.dict_access.DictAccess**)]", - "name": "starkware.cairo.common.dict.dict_squash.__temp31" - }, { "expr": "[cast(fp + (-3), starkware.cairo.common.dict_access.DictAccess**)]", "name": "starkware.cairo.common.dict.dict_squash.dict_accesses_end" @@ -1347,10 +1554,6 @@ "expr": "[cast(fp, starkware.cairo.common.dict_access.DictAccess**)]", "name": "starkware.cairo.common.dict.dict_squash.squashed_dict_start" }, - { - "expr": "[cast(ap + (-1), felt*)]", - "name": "starkware.cairo.common.memcpy.memcpy.__temp46" - }, { "expr": "[cast(ap, felt*)]", "name": "starkware.cairo.common.memcpy.memcpy.continue_copying"