Skip to content

Commit

Permalink
Feat/uint384 hint (#503)
Browse files Browse the repository at this point in the history
uint_utils
  • Loading branch information
StringNick authored Apr 15, 2024
1 parent cf4a3fd commit 28df66d
Show file tree
Hide file tree
Showing 12 changed files with 1,237 additions and 36 deletions.
186 changes: 186 additions & 0 deletions src/hint_processor/builtin_hint_codes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,189 @@ pub const NONDET_BIGINT3_V2 =
\\from starkware.cairo.common.cairo_secp.secp_utils import split
\\segments.write_arg(ids.res.address_, split(value))
;


// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib
pub const UINT384_UNSIGNED_DIV_REM =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=3)
\\assert len(quotient_split) == 3
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

pub const UINT384_SPLIT_128 =
\\ids.low = ids.a & ((1<<128) - 1)
\\ids.high = ids.a >> 128
;

pub const ADD_NO_UINT384_CHECK =
\\sum_d0 = ids.a.d0 + ids.b.d0
\\ids.carry_d0 = 1 if sum_d0 >= ids.SHIFT else 0
\\sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0
\\ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0
\\sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1
\\ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0
;

pub const UINT384_SQRT =
\\from starkware.python.math_utils import isqrt
\\
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift=128)
\\root = isqrt(a)
\\assert 0 <= root < 2 ** 192
\\root_split = split(root, num_bits_shift=128, length=3)
\\ids.root.d0 = root_split[0]
\\ids.root.d1 = root_split[1]
\\ids.root.d2 = root_split[2]
;

pub const SUB_REDUCED_A_AND_REDUCED_B =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack(ids.a, num_bits_shift = 128)
\\b = pack(ids.b, num_bits_shift = 128)
\\p = pack(ids.p, num_bits_shift = 128)
\\
\\res = (a - b) % p
\\
\\
\\res_split = split(res, num_bits_shift=128, length=3)
\\
\\ids.res.d0 = res_split[0]
\\ids.res.d1 = res_split[1]
\\ids.res.d2 = res_split[2]
;

pub const UINT384_SIGNED_NN = "memory[ap] = 1 if 0 <= (ids.a.d2 % PRIME) < 2 ** 127 else 0";

pub const UNSIGNED_DIV_REM_UINT768_BY_UINT384 =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\def pack_extended(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2, z.d3, z.d4, z.d5)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack_extended(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=6)
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\ids.quotient.d3 = quotient_split[3]
\\ids.quotient.d4 = quotient_split[4]
\\ids.quotient.d5 = quotient_split[5]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

// equal to UNSIGNED_DIV_REM_UINT768_BY_UINT384 but with some whitespace removed
// in the `num = num >> num_bits_shift` and between `pack` and `pack_extended`
pub const UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED =
\\def split(num: int, num_bits_shift: int, length: int):
\\ a = []
\\ for _ in range(length):
\\ a.append( num & ((1 << num_bits_shift) - 1) )
\\ num = num >> num_bits_shift
\\ return tuple(a)
\\
\\def pack(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\def pack_extended(z, num_bits_shift: int) -> int:
\\ limbs = (z.d0, z.d1, z.d2, z.d3, z.d4, z.d5)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\a = pack_extended(ids.a, num_bits_shift = 128)
\\div = pack(ids.div, num_bits_shift = 128)
\\
\\quotient, remainder = divmod(a, div)
\\
\\quotient_split = split(quotient, num_bits_shift=128, length=6)
\\
\\ids.quotient.d0 = quotient_split[0]
\\ids.quotient.d1 = quotient_split[1]
\\ids.quotient.d2 = quotient_split[2]
\\ids.quotient.d3 = quotient_split[3]
\\ids.quotient.d4 = quotient_split[4]
\\ids.quotient.d5 = quotient_split[5]
\\
\\remainder_split = split(remainder, num_bits_shift=128, length=3)
\\ids.remainder.d0 = remainder_split[0]
\\ids.remainder.d1 = remainder_split[1]
\\ids.remainder.d2 = remainder_split[2]
;

pub const INV_MOD_P_UINT512 =
\\def pack_512(u, num_bits_shift: int) -> int:
\\ limbs = (u.d0, u.d1, u.d2, u.d3)
\\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
\\
\\x = pack_512(ids.x, num_bits_shift = 128)
\\p = ids.p.low + (ids.p.high << 128)
\\x_inverse_mod_p = pow(x,-1, p)
\\
\\x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
\\
\\ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
\\ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
;

20 changes: 13 additions & 7 deletions src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,22 @@ pub fn BigIntN(comptime NUM_LIMBS: usize) type {
return .{ .limbs = limbs };
}

pub fn insertFromVarName(self: *Self, allocator: std.mem.allocator, var_name: []const u8, vm: *CairoVM, ids_data: std.StringHashMap(HintReference), ap_tracking: ApTracking) !void {
pub fn insertFromVarName(
self: *const Self,
allocator: std.mem.Allocator,
var_name: []const u8,
vm: *CairoVM,
ids_data: std.StringHashMap(HintReference),
ap_tracking: ApTracking,
) !void {
const addr = try hint_utils.getRelocatableFromVarName(var_name, vm, ids_data, ap_tracking);
inline for (0..NUM_LIMBS) |i| {
try vm.insertInMemory(allocator, addr + i, self.limbs[i]);
try vm.insertInMemory(allocator, try addr.addUint(i), MaybeRelocatable.fromFelt(self.limbs[i]));
}
}

pub fn pack(self: *const Self, allocator: std.mem.Allocator) !Int {
const result = packBigInt(allocator, NUM_LIMBS, self.limbs, 128);
return result;
return packBigInt(allocator, NUM_LIMBS, self.limbs, 128);
}

pub fn pack86(self: *const Self, allocator: std.mem.Allocator) !Int {
Expand All @@ -80,9 +86,9 @@ pub fn BigIntN(comptime NUM_LIMBS: usize) type {
return result;
}

pub fn split(self: *Self, num: Int) Self {
const limbs = splitBigInt(std.mem.Allocator, num, self.limbs.len, 128);
return self.fromValues(limbs);

pub fn split(allocator: std.mem.Allocator, num: Int) !Self {
return Self.fromValues(try splitBigInt(allocator, num, NUM_LIMBS, 128));
}

// @TODO: implement from. It is dependent on split function.
Expand Down
21 changes: 21 additions & 0 deletions src/hint_processor/hint_processor_def.zig
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ const segments = @import("segments.zig");

const bigint_utils = @import("../hint_processor/builtin_hint_processor/secp/bigint_utils.zig");
const bigint = @import("bigint.zig");
const uint384 = @import("uint384.zig");
const inv_mod_p_uint512 = @import("vrf/inv_mod_p_uint512.zig");


const deserialize_utils = @import("../parser/deserialize_utils.zig");

Expand Down Expand Up @@ -384,6 +387,24 @@ pub const CairoVMHintProcessor = struct {
try bigint.bigintPackDivModHint(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.BIGINT_SAFE_DIV, hint_data.code)) {
try bigint.bigIntSafeDivHint(allocator, vm, exec_scopes, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_UNSIGNED_DIV_REM, hint_data.code)) {
try uint384.uint384UnsignedDivRem(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_SPLIT_128, hint_data.code)) {
try uint384.uint384Split128(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.ADD_NO_UINT384_CHECK, hint_data.code)) {
try uint384.addNoUint384Check(allocator, vm, hint_data.ids_data, hint_data.ap_tracking, constants);
} else if (std.mem.eql(u8, hint_codes.UINT384_SQRT, hint_data.code)) {
try uint384.uint384Sqrt(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UINT384_SIGNED_NN, hint_data.code)) {
try uint384.uint384SignedNn(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.SUB_REDUCED_A_AND_REDUCED_B, hint_data.code)) {
try uint384.subReducedAAndReducedB(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSIGNED_DIV_REM_UINT768_BY_UINT384, hint_data.code)) {
try uint384.unsignedDivRemUint768ByUint384(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED, hint_data.code)) {
try uint384.unsignedDivRemUint768ByUint384(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else if (std.mem.eql(u8, hint_codes.INV_MOD_P_UINT512, hint_data.code)) {
try inv_mod_p_uint512.invModPUint512(allocator, vm, hint_data.ids_data, hint_data.ap_tracking);
} else {
std.log.err("not implemented: {s}\n", .{hint_data.code});
return HintError.HintNotImplemented;
Expand Down
4 changes: 4 additions & 0 deletions src/hint_processor/testing_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ pub fn checkMemory(mem: *Memory, comptime rows: anytype) !void {
pub fn checkMemoryAddress(mem: *Memory, data: anytype) !void {
const expected = if (data[1].len == 2) MaybeRelocatable.fromRelocatable(Relocatable.init(data[1][0], data[1][1])) else MaybeRelocatable.fromInt(u256, data[1][0]);

errdefer {
std.log.err("failed expect: {any}, got: {any}\n", .{ expected, mem.get(Relocatable.init(data[0][0], data[0][1])) });
}

try std.testing.expectEqual(expected, mem.get(Relocatable.init(data[0][0], data[0][1])));
}

Expand Down
10 changes: 10 additions & 0 deletions src/hint_processor/uint256_utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const ApTracking = @import("../vm/types/programjson.zig").ApTracking;
const HintData = @import("hint_processor_def.zig").HintData;
const ExecutionScopes = @import("../vm/types/execution_scopes.zig").ExecutionScopes;

const Int = @import("std").math.big.int.Managed;
const helper = @import("../math/fields/helper.zig");
const MathError = @import("../vm/error.zig").MathError;
const HintError = @import("../vm/error.zig").HintError;
Expand Down Expand Up @@ -55,6 +56,15 @@ pub const Uint256 = struct {
pub fn split(comptime T: type, num: T) Self {
return Self.init(Felt252.fromInt(T, num & std.math.maxInt(u128)), Felt252.fromInt(T, num >> 128));
}

pub fn pack(self: Self, allocator: std.mem.Allocator) !Int {
var result = try Int.initSet(allocator, self.high.toInteger());
errdefer result.deinit();

try result.shiftLeft(&result, 128);
try result.addScalar(&result, self.low.toInteger());
return result;
}
// converting self to biguint value
// optimize by using biguint
// right now using u512, so to not use allocator with big int
Expand Down
Loading

0 comments on commit 28df66d

Please sign in to comment.