Skip to content

Commit

Permalink
uint256 hint utils + test (keep-starknet-strange#440)
Browse files Browse the repository at this point in the history
* uint256 hint utils + test

* fix review
  • Loading branch information
StringNick authored Mar 4, 2024
1 parent 4b322e5 commit c5ebc7b
Show file tree
Hide file tree
Showing 10 changed files with 1,251 additions and 27 deletions.
98 changes: 98 additions & 0 deletions src/hint_processor/builtin_hint_codes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,101 @@ pub const SPLIT_OUTPUT_MID_LOW_HIGH =
\\tmp, ids.output1_low = divmod(ids.output1, 256 ** 7)
\\ids.output1_high, ids.output1_mid = divmod(tmp, 2 ** 128)
;

pub const BIGINT_TO_UINT256 = "ids.low = (ids.x.d0 + ids.x.d1 * ids.BASE) & ((1 << 128) - 1)";
pub const UINT256_ADD =
\\sum_low = ids.a.low + ids.b.low
\\ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
\\sum_high = ids.a.high + ids.b.high + ids.carry_low
\\ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
;

pub const UINT256_ADD_LOW =
\\sum_low = ids.a.low + ids.b.low
\\ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
;

pub const UINT128_ADD =
\\res = ids.a + ids.b
\\ids.carry = 1 if res >= ids.SHIFT else 0
;

pub const UINT256_SUB =
\\def split(num: int, num_bits_shift: int = 128, length: int = 2):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int = 128) -> int:
\\ limbs = (z.low, z.high)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a)
\\b = pack(ids.b)
\\res = (a - b)%2**256
\\res_split = split(res)
\\ids.res.low = res_split[0]
\\ids.res.high = res_split[1]
;

pub const UINT256_SQRT =
\\from starkware.python.math_utils import isqrt
\\n = (ids.n.high << 128) + ids.n.low
\\root = isqrt(n)
\\assert 0 <= root < 2 ** 128
\\ids.root.low = root
\\ids.root.high = 0
;

pub const UINT256_SQRT_FELT =
\\from starkware.python.math_utils import isqrt
\\n = (ids.n.high << 128) + ids.n.low
\\root = isqrt(n)
\\assert 0 <= root < 2 ** 128
\\ids.root = root
;

pub const UINT256_SIGNED_NN = "memory[ap] = 1 if 0 <= (ids.a.high % PRIME) < 2 ** 127 else 0";

pub const UINT256_UNSIGNED_DIV_REM =
\\a = (ids.a.high << 128) + ids.a.low
\\div = (ids.div.high << 128) + ids.div.low
\\quotient, remainder = divmod(a, div)
\\
\\ids.quotient.low = quotient & ((1 << 128) - 1)
\\ids.quotient.high = quotient >> 128
\\ids.remainder.low = remainder & ((1 << 128) - 1)
\\ids.remainder.high = remainder >> 128
;

pub const UINT256_EXPANDED_UNSIGNED_DIV_REM =
\\a = (ids.a.high << 128) + ids.a.low
\\div = (ids.div.b23 << 128) + ids.div.b01
\\quotient, remainder = divmod(a, div)
\\
\\ids.quotient.low = quotient & ((1 << 128) - 1)
\\ids.quotient.high = quotient >> 128
\\ids.remainder.low = remainder & ((1 << 128) - 1)
\\ids.remainder.high = remainder >> 128
;

pub const UINT256_MUL_DIV_MOD =
\\a = (ids.a.high << 128) + ids.a.low
\\b = (ids.b.high << 128) + ids.b.low
\\div = (ids.div.high << 128) + ids.div.low
\\quotient, remainder = divmod(a * b, div)
\\
\\ids.quotient_low.low = quotient & ((1 << 128) - 1)
\\ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
\\ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
\\ids.quotient_high.high = quotient >> 384
\\ids.remainder.low = remainder & ((1 << 128) - 1)
\\ids.remainder.high = remainder >> 128
;

pub const SPLIT_64 =
\\ids.low = ids.a & ((1<<64) - 1)
\\ids.high = ids.a >> 64
;
26 changes: 24 additions & 2 deletions src/hint_processor/hint_processor_def.zig
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable;
const hint_codes = @import("builtin_hint_codes.zig");
const math_hints = @import("math_hints.zig");
const memcpy_hint_utils = @import("memcpy_hint_utils.zig");
const uint256_utils = @import("uint256_utils.zig");

const poseidon_utils = @import("poseidon_utils.zig");
const keccak_utils = @import("keccak_utils.zig");
const felt_bit_length = @import("felt_bit_length.zig");


const deserialize_utils = @import("../parser/deserialize_utils.zig");

const expect = std.testing.expect;
Expand Down Expand Up @@ -230,7 +230,29 @@ pub const CairoVMHintProcessor = struct {
try keccak_utils.splitOutputMidLowHigh(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.GET_FELT_BIT_LENGTH, hint_data.code)) {
try felt_bit_length.getFeltBitLength(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
}
} else if (std.mem.eql(u8, hint_codes.UINT256_ADD, hint_data.code)) {
try uint256_utils.uint256Add(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, false);
} else if (std.mem.eql(u8, hint_codes.UINT256_ADD_LOW, hint_data.code)) {
try uint256_utils.uint256Add(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, true);
} else if (std.mem.eql(u8, hint_codes.UINT128_ADD, hint_data.code)) {
try uint256_utils.uint128Add(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_SUB, hint_data.code)) {
try uint256_utils.uint256Sub(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.SPLIT_64, hint_data.code)) {
try uint256_utils.split64(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_SQRT, hint_data.code)) {
try uint256_utils.uint256Sqrt(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, false);
} else if (std.mem.eql(u8, hint_codes.UINT256_SQRT_FELT, hint_data.code)) {
try uint256_utils.uint256Sqrt(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, true);
} else if (std.mem.eql(u8, hint_codes.UINT256_SIGNED_NN, hint_data.code)) {
try uint256_utils.uint256SignedNn(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_UNSIGNED_DIV_REM, hint_data.code)) {
try uint256_utils.uint256UnsignedDivRem(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_EXPANDED_UNSIGNED_DIV_REM, hint_data.code)) {
try uint256_utils.uint256ExpandedUnsignedDivRem(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT256_MUL_DIV_MOD, hint_data.code)) {
try uint256_utils.uint256MulDivMod(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
}
}

// Executes the hint which's data is provided by a dynamic structure previously created by compile_hint
Expand Down
8 changes: 4 additions & 4 deletions src/hint_processor/keccak_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ pub fn splitInput(
const inputs_ptr = try hint_utils.getPtrFromVarName("inputs", vm, ids_data, ap_tracking);
const binding = try vm.getFelt(try inputs_ptr.addUint(input_key));
const split = Felt252.pow2Const(8 * exponent);
const high_low = try helper.divRem(binding.toInteger(), split.toInteger());
const high_low = try helper.divRem(u256, binding.toInteger(), split.toInteger());
var buffer: [20]u8 = undefined;

try hint_utils.insertValueFromVarName(
Expand Down Expand Up @@ -224,7 +224,7 @@ pub fn splitOutput(
var buffer: [30]u8 = undefined;
const output = try hint_utils.getIntegerFromVarName(try std.fmt.bufPrint(buffer[0..], "output{d}", .{num}), vm, ids_data, ap_tracking);

const high_low = try helper.divRem(output.toInteger(), Felt252.pow2Const(128).toInteger());
const high_low = try helper.divRem(u256, output.toInteger(), Felt252.pow2Const(128).toInteger());
try hint_utils.insertValueFromVarName(
allocator,
try std.fmt.bufPrint(buffer[0..], "output{d}_high", .{num}),
Expand Down Expand Up @@ -450,8 +450,8 @@ pub fn splitOutputMidLowHigh(
ap_tracking: ApTracking,
) !void {
const output1 = try hint_utils.getIntegerFromVarName("output1", vm, ids_data, ap_tracking);
const tmp_output1_low = try helper.divRem(output1.toInteger(), Felt252.pow2Const(8 * 7).toInteger());
const output1_high_output1_mid = try helper.divRem(tmp_output1_low[0], Felt252.pow2Const(128).toInteger());
const tmp_output1_low = try helper.divRem(u256, output1.toInteger(), Felt252.pow2Const(8 * 7).toInteger());
const output1_high_output1_mid = try helper.divRem(u256, tmp_output1_low[0], Felt252.pow2Const(128).toInteger());

try hint_utils.insertValueFromVarName(allocator, "output1_high", MaybeRelocatable.fromInt(u256, output1_high_output1_mid[0]), vm, ids_data, ap_tracking);
try hint_utils.insertValueFromVarName(allocator, "output1_mid", MaybeRelocatable.fromInt(u256, output1_high_output1_mid[1]), vm, ids_data, ap_tracking);
Expand Down
20 changes: 2 additions & 18 deletions src/hint_processor/math_hints.zig
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,6 @@ pub fn assertNotEqual(
}
}

fn isqrt(n: u256) !u256 {
var x = n;
var y = (n + 1) >> @as(u32, 1);

while (y < x) {
x = y;
y = (@divFloor(n, x) + x) >> @as(u32, 1);
}

if (!(std.math.pow(u256, x, 2) <= n and n < std.math.pow(u256, x + 1, 2))) {
return error.FailedToGetSqrt;
}

return x;
}

//Implements hint: from starkware.python.math_utils import isqrt
// value = ids.value % PRIME
// assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)."
Expand All @@ -192,7 +176,7 @@ pub fn sqrt(
return HintError.ValueOutside250BitRange;
}

const root = Felt252.fromInt(u256, isqrt(mod_value.toInteger()) catch unreachable);
const root = Felt252.fromInt(u256, field_helper.isqrt(u256, mod_value.toInteger()) catch unreachable);

try hint_utils.insertValueFromVarName(
allocator,
Expand Down Expand Up @@ -225,7 +209,7 @@ pub fn unsignedDivRem(
if (div.isZero() or div.gt(divPrimeByBound(b))) return HintError.OutOfValidRange;
} else if (div.isZero()) return HintError.OutOfValidRange;

const qr = try (field_helper.divRem(value.toInteger(), div.toInteger()) catch MathError.DividedByZero);
const qr = try (field_helper.divRem(u256, value.toInteger(), div.toInteger()) catch MathError.DividedByZero);

try hint_utils.insertValueFromVarName(allocator, "r", MaybeRelocatable.fromInt(u256, qr[1]), vm, ids_data, ap_tracking);
try hint_utils.insertValueFromVarName(allocator, "q", MaybeRelocatable.fromInt(u256, qr[0]), vm, ids_data, ap_tracking);
Expand Down
11 changes: 11 additions & 0 deletions src/hint_processor/testing_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ const MaybeRelocatable = relocatable.MaybeRelocatable;
const IdsManager = @import("hint_utils.zig").IdsManager;
const HintReference = @import("../hint_processor/hint_processor_def.zig").HintReference;

pub fn setupIdsNonContinuousIdsData(allocator: std.mem.Allocator, data: []const struct { []const u8, i32 }) !std.StringHashMap(HintReference) {
var ids_data = std.StringHashMap(HintReference).init(allocator);
errdefer ids_data.deinit();

for (data) |d| {
try ids_data.put(d[0], HintReference.initSimple(d[1]));
}

return ids_data;
}

pub fn setupIdsForTestWithoutMemory(allocator: std.mem.Allocator, data: []const []const u8) !std.StringHashMap(HintReference) {
var result = std.StringHashMap(HintReference).init(allocator);
errdefer result.deinit();
Expand Down
Loading

0 comments on commit c5ebc7b

Please sign in to comment.