Skip to content

Commit

Permalink
feat/usort builtin (#441)
Browse files Browse the repository at this point in the history
* usort builtin

* some review fixes

* adopt to exec scope methods

* fix review

---------

Co-authored-by: lanaivina <[email protected]>
  • Loading branch information
StringNick and lana-shanghai authored Mar 6, 2024
1 parent a4d730e commit e316ce5
Show file tree
Hide file tree
Showing 6 changed files with 485 additions and 4 deletions.
38 changes: 38 additions & 0 deletions src/hint_processor/builtin_hint_codes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,48 @@ pub const SPLIT_64 =
\\ids.high = ids.a >> 64
;


pub const USORT_ENTER_SCOPE =
"vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))";
pub const USORT_BODY =
\\from collections import defaultdict
\\
\\input_ptr = ids.input
\\input_len = int(ids.input_len)
\\if __usort_max_size is not None:
\\ assert input_len <= __usort_max_size, (
\\ f"usort() can only be used with input_len<={__usort_max_size}. "
\\ f"Got: input_len={input_len}."
\\ )
\\
\\positions_dict = defaultdict(list)
\\for i in range(input_len):
\\ val = memory[input_ptr + i]
\\ positions_dict[val].append(i)
\\
\\output = sorted(positions_dict.keys())
\\ids.output_len = len(output)
\\ids.output = segments.gen_arg(output)
\\ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])
;

pub const USORT_VERIFY =
\\last_pos = 0
\\positions = positions_dict[ids.value][::-1]
;

pub const USORT_VERIFY_MULTIPLICITY_ASSERT = "assert len(positions) == 0";
pub const USORT_VERIFY_MULTIPLICITY_BODY =
\\current_pos = positions.pop()
\\ids.next_item_index = current_pos - last_pos
\\last_pos = current_pos + 1
;

pub const MEMSET_ENTER_SCOPE = "vm_enter_scope({'n': ids.n})";
pub const MEMSET_CONTINUE_LOOP =
\\n -= 1
\\ids.continue_loop = 1 if n > 0 else 0
;

pub const MEMCPY_CONTINUE_COPYING = "n -= 1 ids.continue_copying = 1 if n > 0 else 0";

11 changes: 11 additions & 0 deletions src/hint_processor/hint_processor_def.zig
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const math_hints = @import("math_hints.zig");
const memcpy_hint_utils = @import("memcpy_hint_utils.zig");
const memset_utils = @import("memset_utils.zig");
const uint256_utils = @import("uint256_utils.zig");
const usort = @import("usort.zig");

const poseidon_utils = @import("poseidon_utils.zig");
const keccak_utils = @import("keccak_utils.zig");
Expand Down Expand Up @@ -258,6 +259,16 @@ pub const CairoVMHintProcessor = struct {
try uint256_utils.uint256ExpandedUnsignedDivRem(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_MUL_DIV_MOD, hint_data.code)) {
try uint256_utils.uint256MulDivMod(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.USORT_ENTER_SCOPE, hint_data.code)) {
try usort.usortEnterScope(allocator, exec_scopes);
} else if (std.mem.eql(u8, hint_codes.USORT_BODY, hint_data.code)) {
try usort.usortBody(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.USORT_VERIFY, hint_data.code)) {
try usort.verifyUsort(vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.USORT_VERIFY_MULTIPLICITY_ASSERT, hint_data.code)) {
try usort.verifyMultiplicityAssert(exec_scopes);
} else if (std.mem.eql(u8, hint_codes.USORT_VERIFY_MULTIPLICITY_BODY, hint_data.code)) {
try usort.verifyMultiplicityBody(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.MEMSET_ENTER_SCOPE, hint_data.code)) {
try memset_utils.memsetEnterScope(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.MEMCPY_ENTER_SCOPE, hint_data.code)) {
Expand Down
Loading

0 comments on commit e316ce5

Please sign in to comment.