diff --git a/build.zig b/build.zig index 3f3cb246..2b2eb173 100644 --- a/build.zig +++ b/build.zig @@ -141,6 +141,7 @@ pub fn build(b: *std.Build) void { .target = target, .optimize = optimize, .filter = test_filter, + .single_threaded = false, }); // Add dependency modules to the tests. diff --git a/src/hint_processor/builtin_hint_codes.zig b/src/hint_processor/builtin_hint_codes.zig index 9a6dc9b2..9fdf9ba0 100644 --- a/src/hint_processor/builtin_hint_codes.zig +++ b/src/hint_processor/builtin_hint_codes.zig @@ -1483,3 +1483,117 @@ pub const PACK_MODN_DIV_MODN = \\s = pack(ids.s, PRIME) % N \\value = res = div_mod(x, s, N) ; + +pub const UINT384_GET_SQUARE_ROOT = + \\from starkware.python.math_utils import is_quad_residue, sqrt + \\ + \\def split(num: int, num_bits_shift: int = 128, length: int = 3): + \\ a = [] + \\ for _ in range(length): + \\ a.append( num & ((1 << num_bits_shift) - 1) ) + \\ num = num >> num_bits_shift + \\ return tuple(a) + \\ + \\def pack(z, num_bits_shift: int = 128) -> int: + \\ limbs = (z.d0, z.d1, z.d2) + \\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + \\ + \\ + \\generator = pack(ids.generator) + \\x = pack(ids.x) + \\p = pack(ids.p) + \\ + \\success_x = is_quad_residue(x, p) + \\root_x = sqrt(x, p) if success_x else None + \\ + \\success_gx = is_quad_residue(generator*x, p) + \\root_gx = sqrt(generator*x, p) if success_gx else None + \\ + \\# Check that one is 0 and the other is 1 + \\if x != 0: + \\ assert success_x + success_gx ==1 + \\ + \\# `None` means that no root was found, but we need to transform these into a felt no matter what + \\if root_x == None: + \\ root_x = 0 + \\if root_gx == None: + \\ root_gx = 0 + \\ids.success_x = int(success_x) + \\ids.success_gx = int(success_gx) + \\split_root_x = split(root_x) + \\split_root_gx = split(root_gx) + \\ids.sqrt_x.d0 = split_root_x[0] + \\ids.sqrt_x.d1 = split_root_x[1] + \\ids.sqrt_x.d2 = split_root_x[2] + \\ids.sqrt_gx.d0 = split_root_gx[0] + \\ids.sqrt_gx.d1 = split_root_gx[1] + \\ids.sqrt_gx.d2 = split_root_gx[2] +; + +pub const UINT256_GET_SQUARE_ROOT = + \\from starkware.python.math_utils import is_quad_residue, sqrt + \\ + \\def split(a: int): + \\ return (a & ((1 << 128) - 1), a >> 128) + \\ + \\def pack(z) -> int: + \\ return z.low + (z.high << 128) + \\ + \\generator = pack(ids.generator) + \\x = pack(ids.x) + \\p = pack(ids.p) + \\ + \\success_x = is_quad_residue(x, p) + \\root_x = sqrt(x, p) if success_x else None + \\success_gx = is_quad_residue(generator*x, p) + \\root_gx = sqrt(generator*x, p) if success_gx else None + \\ + \\# Check that one is 0 and the other is 1 + \\if x != 0: + \\ assert success_x + success_gx == 1 + \\ + \\# `None` means that no root was found, but we need to transform these into a felt no matter what + \\if root_x == None: + \\ root_x = 0 + \\if root_gx == None: + \\ root_gx = 0 + \\ids.success_x = int(success_x) + \\ids.success_gx = int(success_gx) + \\split_root_x = split(root_x) + \\# print('split root x', split_root_x) + \\split_root_gx = split(root_gx) + \\ids.sqrt_x.low = split_root_x[0] + \\ids.sqrt_x.high = split_root_x[1] + \\ids.sqrt_gx.low = split_root_gx[0] + \\ids.sqrt_gx.high = split_root_gx[1] +; + +pub const UINT384_DIV = + \\from starkware.python.math_utils import div_mod + \\ + \\def split(num: int, num_bits_shift: int, length: int): + \\ a = [] + \\ for _ in range(length): + \\ a.append( num & ((1 << num_bits_shift) - 1) ) + \\ num = num >> num_bits_shift + \\ return tuple(a) + \\ + \\def pack(z, num_bits_shift: int) -> int: + \\ limbs = (z.d0, z.d1, z.d2) + \\ return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + \\ + \\a = pack(ids.a, num_bits_shift = 128) + \\b = pack(ids.b, num_bits_shift = 128) + \\p = pack(ids.p, num_bits_shift = 128) + \\# For python3.8 and above the modular inverse can be computed as follows: + \\# b_inverse_mod_p = pow(b, -1, p) + \\# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils + \\b_inverse_mod_p = div_mod(1, b, p) + \\ + \\ + \\b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + \\ + \\ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] + \\ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] + \\ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2] +; diff --git a/src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig b/src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig index e5c0f535..c6376dac 100644 --- a/src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig +++ b/src/hint_processor/builtin_hint_processor/secp/bigint_utils.zig @@ -24,6 +24,7 @@ const hint_codes = @import("../../builtin_hint_codes.zig"); pub const BigInt3 = BigIntN(3); pub const Uint384 = BigIntN(3); +pub const Uint256 = BigIntN(2); pub const Uint512 = BigIntN(4); pub const BigInt5 = BigIntN(5); pub const Uint768 = BigIntN(6); @@ -216,7 +217,6 @@ test "Get BigInt3 from base address with missing member should fail" { defer vm.segments.memory.deinitData(std.testing.allocator); - try std.testing.expectError(HintError.IdentifierHasNoMember, BigInt3.fromBaseAddr(Relocatable{ .segment_index = 0, .offset = 0 }, &vm)); } @@ -234,7 +234,6 @@ test "Get BigInt5 from base address with missing member should fail" { defer vm.segments.memory.deinitData(std.testing.allocator); - try std.testing.expectError(HintError.IdentifierHasNoMember, BigInt5.fromBaseAddr(Relocatable{ .segment_index = 0, .offset = 0 }, &vm)); } @@ -331,7 +330,6 @@ test "BigIntUtils: get bigint5 from var name with missing member should fail" { var ids_data = try testing_utils.setupIdsForTestWithoutMemory(std.testing.allocator, &.{"x"}); defer ids_data.deinit(); - try std.testing.expectError(HintError.IdentifierHasNoMember, BigInt5.fromVarName("x", &vm, ids_data, .{})); } diff --git a/src/hint_processor/field_arithmetic.zig b/src/hint_processor/field_arithmetic.zig new file mode 100644 index 00000000..29a1a968 --- /dev/null +++ b/src/hint_processor/field_arithmetic.zig @@ -0,0 +1,507 @@ +const std = @import("std"); +const CairoVM = @import("../vm/core.zig").CairoVM; +const HintReference = @import("hint_processor_def.zig").HintReference; +const HintProcessor = @import("hint_processor_def.zig").CairoVMHintProcessor; +const Felt252 = @import("../math/fields/starknet.zig").Felt252; +const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable; +const MaybeRelocatable = @import("../vm/memory/relocatable.zig").MaybeRelocatable; +const ApTracking = @import("../vm/types/programjson.zig").ApTracking; +const HintData = @import("hint_processor_def.zig").HintData; +const ExecutionScopes = @import("../vm/types/execution_scopes.zig").ExecutionScopes; + +const BigInt3 = @import("builtin_hint_processor/secp/bigint_utils.zig").BigInt3; +const Uint384 = @import("builtin_hint_processor/secp/bigint_utils.zig").Uint384; +const BigInt5 = @import("builtin_hint_processor/secp/bigint_utils.zig").BigInt5; +const BigIntN = @import("builtin_hint_processor/secp/bigint_utils.zig").BigIntN; + +const MathError = @import("../vm/error.zig").MathError; +const HintError = @import("../vm/error.zig").HintError; +const CairoVMError = @import("../vm/error.zig").CairoVMError; + +const Int = @import("std").math.big.int.Managed; +const BASE = @import("../math/fields/constants.zig").BASE; + +const hint_codes = @import("builtin_hint_codes.zig"); +const hint_utils = @import("hint_utils.zig"); +const testing_utils = @import("testing_utils.zig"); +const field_helper = @import("../math/fields/helper.zig"); +const safeDivBigInt = @import("../math/fields/helper.zig").safeDivBigInt; + +// Implements Hint: +// %{ +// from starkware.python.math_utils import is_quad_residue, sqrt + +// def split(num: int, num_bits_shift: int = 128, length: int = 3): +// a = [] +// for _ in range(length): +// a.append( num & ((1 << num_bits_shift) - 1) ) +// num = num >> num_bits_shift +// return tuple(a) + +// def pack(z, num_bits_shift: int = 128) -> int: +// limbs = (z.d0, z.d1, z.d2) +// return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +// generator = pack(ids.generator) +// x = pack(ids.x) +// p = pack(ids.p) + +// success_x = is_quad_residue(x, p) +// root_x = sqrt(x, p) if success_x else None + +// success_gx = is_quad_residue(generator*x, p) +// root_gx = sqrt(generator*x, p) if success_gx else None + +// # Check that one is 0 and the other is 1 +// if x != 0: +// assert success_x + success_gx ==1 + +// # `None` means that no root was found, but we need to transform these into a felt no matter what +// if root_x == None: +// root_x = 0 +// if root_gx == None: +// root_gx = 0 +// ids.success_x = int(success_x) +// ids.success_gx = int(success_gx) +// split_root_x = split(root_x) +// split_root_gx = split(root_gx) +// ids.sqrt_x.d0 = split_root_x[0] +// ids.sqrt_x.d1 = split_root_x[1] +// ids.sqrt_x.d2 = split_root_x[2] +// ids.sqrt_gx.d0 = split_root_gx[0] +// ids.sqrt_gx.d1 = split_root_gx[1] +// ids.sqrt_gx.d2 = split_root_gx[2] +// %} +pub fn u384GetSquareRoot( + allocator: std.mem.Allocator, + vm: *CairoVM, + ids_data: std.StringHashMap(HintReference), + ap_tracking: ApTracking, +) !void { + try bigIntIntGetSquareRoot(allocator, vm, ids_data, ap_tracking, 3); +} + +// Implements Hint: +// %{ +// from starkware.python.math_utils import is_quad_residue, sqrt + +// def split(a: int): +// return (a & ((1 << 128) - 1), a >> 128) + +// def pack(z) -> int: +// return z.low + (z.high << 128) + +// generator = pack(ids.generator) +// x = pack(ids.x) +// p = pack(ids.p) + +// success_x = is_quad_residue(x, p) +// root_x = sqrt(x, p) if success_x else None +// success_gx = is_quad_residue(generator*x, p) +// root_gx = sqrt(generator*x, p) if success_gx else None + +// # Check that one is 0 and the other is 1 +// if x != 0: +// assert success_x + success_gx == 1 + +// # `None` means that no root was found, but we need to transform these into a felt no matter what +// if root_x == None: +// root_x = 0 +// if root_gx == None: +// root_gx = 0 +// ids.success_x = int(success_x) +// ids.success_gx = int(success_gx) +// split_root_x = split(root_x) +// # print('split root x', split_root_x) +// split_root_gx = split(root_gx) +// ids.sqrt_x.low = split_root_x[0] +// ids.sqrt_x.high = split_root_x[1] +// ids.sqrt_gx.low = split_root_gx[0] +// ids.sqrt_gx.high = split_root_gx[1] +// %} +pub fn u256GetSquareRoot( + allocator: std.mem.Allocator, + vm: *CairoVM, + ids_data: std.StringHashMap(HintReference), + ap_tracking: ApTracking, +) !void { + try bigIntIntGetSquareRoot(allocator, vm, ids_data, ap_tracking, 2); +} + +pub fn bigIntIntGetSquareRoot( + allocator: std.mem.Allocator, + vm: *CairoVM, + ids_data: std.StringHashMap(HintReference), + ap_tracking: ApTracking, + comptime NUM_LIMBS: usize, +) !void { + var generator = try (try BigIntN(NUM_LIMBS).fromVarName("generator", vm, ids_data, ap_tracking)).pack(allocator); + defer generator.deinit(); + + var x = try (try BigIntN(NUM_LIMBS).fromVarName("x", vm, ids_data, ap_tracking)).pack(allocator); + defer x.deinit(); + var p = try (try BigIntN(NUM_LIMBS).fromVarName("p", vm, ids_data, ap_tracking)).pack(allocator); + defer p.deinit(); + + const success_x = try field_helper.isQuadResidue(allocator, x, p); + + var root_x = if (success_x) + (try field_helper.sqrtPrimePower(allocator, x, p)) orelse try Int.initSet(allocator, 0) + else + try Int.initSet(allocator, 0); + defer root_x.deinit(); + + var gx = try Int.init(allocator); + defer gx.deinit(); + + try gx.mul(&generator, &x); + + const success_gx = try field_helper.isQuadResidue(allocator, gx, p); + + var root_gx = if (success_gx) + (try field_helper.sqrtPrimePower(allocator, gx, p)) orelse try Int.initSet(allocator, 0) + else + try Int.initSet(allocator, 0); + defer root_gx.deinit(); + + if (!x.eqlZero() and success_x == success_gx) + return HintError.AssertionFailed; + + try hint_utils.insertValueFromVarName( + allocator, + "success_x", + MaybeRelocatable.fromInt(u8, if (success_x) 1 else 0), + vm, + ids_data, + ap_tracking, + ); + try hint_utils.insertValueFromVarName( + allocator, + "success_gx", + MaybeRelocatable.fromInt(u8, if (success_gx) 1 else 0), + vm, + ids_data, + ap_tracking, + ); + + try (try BigIntN(NUM_LIMBS).split(allocator, root_x)).insertFromVarName(allocator, "sqrt_x", vm, ids_data, ap_tracking); + try (try BigIntN(NUM_LIMBS).split(allocator, root_gx)).insertFromVarName(allocator, "sqrt_gx", vm, ids_data, ap_tracking); +} + +// Implements Hint: +// %{ +// from starkware.python.math_utils import div_mod + +// def split(num: int, num_bits_shift: int, length: int): +// a = [] +// for _ in range(length): +// a.append( num & ((1 << num_bits_shift) - 1) ) +// num = num >> num_bits_shift +// return tuple(a) + +// def pack(z, num_bits_shift: int) -> int: +// limbs = (z.d0, z.d1, z.d2) +// return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +// a = pack(ids.a, num_bits_shift = 128) +// b = pack(ids.b, num_bits_shift = 128) +// p = pack(ids.p, num_bits_shift = 128) +// # For python3.8 and above the modular inverse can be computed as follows: +// # b_inverse_mod_p = pow(b, -1, p) +// # Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils +// b_inverse_mod_p = div_mod(1, b, p) + +// b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + +// ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] +// ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] +// ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2] +// %} +pub fn uint384Div( + allocator: std.mem.Allocator, + vm: *CairoVM, + ids_data: std.StringHashMap(HintReference), + ap_tracking: ApTracking, +) !void { + // Note: ids.a is not used here, nor is it used by following hints, so we dont need to extract it. + var b = try (try Uint384.fromVarName("b", vm, ids_data, ap_tracking)).pack(allocator); + defer b.deinit(); + var p = try (try Uint384.fromVarName("p", vm, ids_data, ap_tracking)).pack(allocator); + defer p.deinit(); + + if (b.eqlZero()) + return MathError.DividedByZero; + + var tmp = try Int.init(allocator); + defer tmp.deinit(); + + var b_inverse_mod_p = try field_helper.mulInv(allocator, b, p); + defer b_inverse_mod_p.deinit(); + + try tmp.divFloor(&b_inverse_mod_p, &b_inverse_mod_p, &p); + + b_inverse_mod_p.abs(); + + const b_inverse_mod_p_split = try Uint384.split(allocator, b_inverse_mod_p); + try b_inverse_mod_p_split.insertFromVarName(allocator, "b_inverse_mod_p", vm, ids_data, ap_tracking); +} + +test "FieldArithmetic: run u384 getSquareOk goldilocks prime" { + var vm = try testing_utils.initVMWithRangeCheck(std.testing.allocator); + defer vm.deinit(); + defer vm.segments.memory.deinitData(std.testing.allocator); + + //Initialize fp + vm.run_context.fp.* = 14; + //Create hint_data + var ids_data = try testing_utils.setupIdsNonContinuousIdsData(std.testing.allocator, &.{ + .{ "p", -14 }, + .{ "x", -11 }, + .{ "generator", -8 }, + .{ "sqrt_x", -5 }, + .{ "sqrt_gx", -2 }, + .{ "success_x", 1 }, + .{ "success_gx", 2 }, + }); + defer ids_data.deinit(); + //Insert ids into memory + + try vm.segments.memory.setUpMemory(std.testing.allocator, .{ + // p + .{ .{ 1, 0 }, .{18446744069414584321} }, + .{ .{ 1, 1 }, .{0} }, + .{ .{ 1, 2 }, .{0} }, + + //x + .{ .{ 1, 3 }, .{25} }, + .{ .{ 1, 4 }, .{0} }, + .{ .{ 1, 5 }, .{0} }, + + //generator + .{ .{ 1, 6 }, .{7} }, + .{ .{ 1, 7 }, .{0} }, + .{ .{ 1, 8 }, .{0} }, + }); + + //Execute the hint + try testing_utils.runHint(std.testing.allocator, &vm, ids_data, hint_codes.UINT384_GET_SQUARE_ROOT, undefined, undefined); + //Check hint memory inserts + try testing_utils.checkMemory(vm.segments.memory, .{ + // sqrt_x + .{ .{ 1, 9 }, .{5} }, + .{ .{ 1, 10 }, .{0} }, + .{ .{ 1, 11 }, .{0} }, + // sqrt_gx + .{ .{ 1, 12 }, .{0} }, + .{ .{ 1, 13 }, .{0} }, + .{ .{ 1, 14 }, .{0} }, + // success_x + .{ .{ 1, 15 }, .{1} }, + // success_gx + .{ .{ 1, 16 }, .{0} }, + }); +} + +test "FieldArithmetic: run u384 getSquareOk success gx" { + var vm = try testing_utils.initVMWithRangeCheck(std.testing.allocator); + defer vm.deinit(); + defer vm.segments.memory.deinitData(std.testing.allocator); + + //Initialize fp + vm.run_context.fp.* = 14; + //Create hint_data + var ids_data = try testing_utils.setupIdsNonContinuousIdsData(std.testing.allocator, &.{ + .{ "p", -14 }, + .{ "x", -11 }, + .{ "generator", -8 }, + .{ "sqrt_x", -5 }, + .{ "sqrt_gx", -2 }, + .{ "success_x", 1 }, + .{ "success_gx", 2 }, + }); + defer ids_data.deinit(); + //Insert ids into memory + + try vm.segments.memory.setUpMemory(std.testing.allocator, .{ + // p + .{ .{ 1, 0 }, .{3} }, + .{ .{ 1, 1 }, .{0} }, + .{ .{ 1, 2 }, .{0} }, + + //x + .{ .{ 1, 3 }, .{17} }, + .{ .{ 1, 4 }, .{0} }, + .{ .{ 1, 5 }, .{0} }, + + //generator + .{ .{ 1, 6 }, .{71} }, + .{ .{ 1, 7 }, .{0} }, + .{ .{ 1, 8 }, .{0} }, + }); + + //Execute the hint + try testing_utils.runHint(std.testing.allocator, &vm, ids_data, hint_codes.UINT384_GET_SQUARE_ROOT, undefined, undefined); + //Check hint memory inserts + try testing_utils.checkMemory(vm.segments.memory, .{ + // sqrt_x + .{ .{ 1, 9 }, .{0} }, + .{ .{ 1, 10 }, .{0} }, + .{ .{ 1, 11 }, .{0} }, + // sqrt_gx + .{ .{ 1, 12 }, .{1} }, + .{ .{ 1, 13 }, .{0} }, + .{ .{ 1, 14 }, .{0} }, + // success_x + .{ .{ 1, 15 }, .{0} }, + // success_gx + .{ .{ 1, 16 }, .{1} }, + }); +} + +test "FieldArithmetic: run u384 getSquareOk no successes" { + var vm = try testing_utils.initVMWithRangeCheck(std.testing.allocator); + defer vm.deinit(); + defer vm.segments.memory.deinitData(std.testing.allocator); + + //Initialize fp + vm.run_context.fp.* = 14; + //Create hint_data + var ids_data = try testing_utils.setupIdsNonContinuousIdsData(std.testing.allocator, &.{ + .{ "p", -14 }, + .{ "x", -11 }, + .{ "generator", -8 }, + .{ "sqrt_x", -5 }, + .{ "sqrt_gx", -2 }, + .{ "success_x", 1 }, + .{ "success_gx", 2 }, + }); + defer ids_data.deinit(); + //Insert ids into memory + + try vm.segments.memory.setUpMemory(std.testing.allocator, .{ + // p + .{ .{ 1, 0 }, .{3} }, + .{ .{ 1, 1 }, .{0} }, + .{ .{ 1, 2 }, .{0} }, + + //x + .{ .{ 1, 3 }, .{17} }, + .{ .{ 1, 4 }, .{0} }, + .{ .{ 1, 5 }, .{0} }, + + //generator + .{ .{ 1, 6 }, .{1} }, + .{ .{ 1, 7 }, .{0} }, + .{ .{ 1, 8 }, .{0} }, + }); + + //Execute the hint + try std.testing.expectError( + HintError.AssertionFailed, + testing_utils.runHint( + std.testing.allocator, + &vm, + ids_data, + hint_codes.UINT256_GET_SQUARE_ROOT, + undefined, + undefined, + ), + ); +} + +test "FieldArithmetic: run u384 div ok" { + var vm = try testing_utils.initVMWithRangeCheck(std.testing.allocator); + defer vm.deinit(); + defer vm.segments.memory.deinitData(std.testing.allocator); + + //Initialize fp + vm.run_context.fp.* = 11; + //Create hint_data + var ids_data = try testing_utils.setupIdsNonContinuousIdsData(std.testing.allocator, &.{ + .{ "a", -11 }, + .{ "b", -8 }, + .{ "p", -5 }, + .{ "b_inverse_mod_p", -2 }, + }); + defer ids_data.deinit(); + //Insert ids into memory + + try vm.segments.memory.setUpMemory(std.testing.allocator, .{ + // a + .{ .{ 1, 0 }, .{25} }, + .{ .{ 1, 1 }, .{0} }, + .{ .{ 1, 2 }, .{0} }, + + //b + .{ .{ 1, 3 }, .{5} }, + .{ .{ 1, 4 }, .{0} }, + .{ .{ 1, 5 }, .{0} }, + + //p + .{ .{ 1, 6 }, .{31} }, + .{ .{ 1, 7 }, .{0} }, + .{ .{ 1, 8 }, .{0} }, + }); + + //Execute the hint + try testing_utils.runHint( + std.testing.allocator, + &vm, + ids_data, + hint_codes.UINT384_DIV, + undefined, + undefined, + ); + //Check hint memory inserts + try testing_utils.checkMemory(vm.segments.memory, .{ + // b_inverse_mod_p + .{ .{ 1, 9 }, .{25} }, + .{ .{ 1, 10 }, .{0} }, + .{ .{ 1, 11 }, .{0} }, + }); +} + +test "FieldArithmetic: run u384 div b is zero" { + var vm = try testing_utils.initVMWithRangeCheck(std.testing.allocator); + defer vm.deinit(); + defer vm.segments.memory.deinitData(std.testing.allocator); + + //Initialize fp + vm.run_context.fp.* = 11; + //Create hint_data + var ids_data = try testing_utils.setupIdsNonContinuousIdsData(std.testing.allocator, &.{ + .{ "a", -11 }, + .{ "b", -8 }, + .{ "p", -5 }, + .{ "b_inverse_mod_p", -2 }, + }); + defer ids_data.deinit(); + //Insert ids into memory + + try vm.segments.memory.setUpMemory(std.testing.allocator, .{ + // a + .{ .{ 1, 0 }, .{25} }, + .{ .{ 1, 1 }, .{0} }, + .{ .{ 1, 2 }, .{0} }, + + //b + .{ .{ 1, 3 }, .{0} }, + .{ .{ 1, 4 }, .{0} }, + .{ .{ 1, 5 }, .{0} }, + + //p + .{ .{ 1, 6 }, .{31} }, + .{ .{ 1, 7 }, .{0} }, + .{ .{ 1, 8 }, .{0} }, + }); + + //Execute the hint + try std.testing.expectError(MathError.DividedByZero, testing_utils.runHint( + std.testing.allocator, + &vm, + ids_data, + hint_codes.UINT384_DIV, + undefined, + undefined, + )); +} diff --git a/src/hint_processor/hint_processor_def.zig b/src/hint_processor/hint_processor_def.zig index b4d79ce4..de9f379c 100644 --- a/src/hint_processor/hint_processor_def.zig +++ b/src/hint_processor/hint_processor_def.zig @@ -56,6 +56,8 @@ const print_utils = @import("./print.zig"); const testing_utils = @import("testing_utils.zig"); const blake2s_utils = @import("blake2s_utils.zig"); +const field_arithmetic = @import("field_arithmetic.zig"); + const HintError = @import("../vm/error.zig").HintError; const expect = std.testing.expect; @@ -850,6 +852,27 @@ pub const CairoVMHintProcessor = struct { hint_data.ids_data, hint_data.ap_tracking, ); + } else if (std.mem.eql(u8, hint_codes.UINT384_GET_SQUARE_ROOT, hint_data.code)) { + try field_arithmetic.u384GetSquareRoot( + allocator, + vm, + hint_data.ids_data, + hint_data.ap_tracking, + ); + } else if (std.mem.eql(u8, hint_codes.UINT256_GET_SQUARE_ROOT, hint_data.code)) { + try field_arithmetic.u256GetSquareRoot( + allocator, + vm, + hint_data.ids_data, + hint_data.ap_tracking, + ); + } else if (std.mem.eql(u8, hint_codes.UINT384_DIV, hint_data.code)) { + try field_arithmetic.uint384Div( + allocator, + vm, + hint_data.ids_data, + hint_data.ap_tracking, + ); } else { std.log.err("not implemented: {s}\n", .{hint_data.code}); return HintError.HintNotImplemented; diff --git a/src/integration_tests.zig b/src/integration_tests.zig index c3eb9aac..1b999894 100644 --- a/src/integration_tests.zig +++ b/src/integration_tests.zig @@ -59,10 +59,8 @@ pub fn main() !void { .{ .pathname = "cairo_programs/ec_op.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/ec_recover.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/ed25519_ec.json", .layout = "all_cairo" }, - // TODO: HintNotImplemented error field arithmetic - // .{ .pathname = "cairo_programs/ed25519_field.json", .layout = "all_cairo" }, - // TODO: field arithmetic - // .{ .pathname = "cairo_programs/efficient_secp256r1_ec.json", .layout = "all_cairo" }, + .{ .pathname = "cairo_programs/ed25519_field.json", .layout = "all_cairo" }, + .{ .pathname = "cairo_programs/efficient_secp256r1_ec.json", .layout = "all_cairo" }, // TODO: sha256 // .{ .pathname = "cairo_programs/example_blake2s.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/example_program.json", .layout = "all_cairo" }, @@ -72,8 +70,7 @@ pub fn main() !void { .{ .pathname = "cairo_programs/fast_ec_add_v3.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/fibonacci.json", .layout = "plain" }, - // TODO: field arithmetic - // .{ .pathname = "cairo_programs/field_arithmetic.json", .layout = "all_cairo" }, + .{ .pathname = "cairo_programs/field_arithmetic.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/finalize_blake2s.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/finalize_blake2s_v2_hint.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/find_element.json", .layout = "all_cairo" }, @@ -174,7 +171,6 @@ pub fn main() !void { // .{ .pathname = "cairo_programs/sha256_test.json", .layout = "all_cairo" }, // .{ .pathname = "cairo_programs/sha256.json", .layout = "all_cairo" }, - // TODO: error .{ .pathname = "cairo_programs/signature.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/signed_div_rem.json", .layout = "all_cairo" }, .{ .pathname = "cairo_programs/simple_print.json", .layout = "all_cairo" }, diff --git a/src/lib.zig b/src/lib.zig index 84756412..c4fa0ccf 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -89,6 +89,8 @@ pub const hint_processor = struct { pub usingnamespace @import("hint_processor/vrf/fq.zig"); pub usingnamespace @import("hint_processor/vrf/pack.zig"); pub usingnamespace @import("hint_processor/builtin_hint_processor/secp/signature.zig"); + + pub usingnamespace @import("hint_processor/field_arithmetic.zig"); }; pub const parser = struct { diff --git a/src/math/fields/helper.zig b/src/math/fields/helper.zig index 3f340d11..b4ce0277 100644 --- a/src/math/fields/helper.zig +++ b/src/math/fields/helper.zig @@ -26,29 +26,32 @@ pub fn multiplyModulus(a: u512, b: u512, modulus: u512) u512 { } pub fn multiplyModulusBigInt(allocator: std.mem.Allocator, a: Int, b: Int, modulus: Int) !Int { - var value = try Int.init(allocator); - errdefer value.deinit(); - - var tmp = try Int.init(allocator); - defer tmp.deinit(); - - try tmp.mul(&a, &b); + var result = try Int.init(allocator); + errdefer result.deinit(); - try tmp.divFloor(&value, &tmp, &modulus); + try multiplyModulusBigIntWithPtr(allocator, a, b, modulus, &result); - return value; + return result; } -pub fn multiplyModulusBigIntWithPtr(allocator: std.mem.Allocator, a: Int, b: Int, modulus: Int, value: *Int) !void { +pub fn multiplyModulusBigIntWithPtr(allocator: std.mem.Allocator, a: Int, b: Int, modulus: Int, result: *Int) !void { var tmp = try Int.init(allocator); defer tmp.deinit(); try tmp.mul(&a, &b); - try tmp.divFloor(value, &tmp, &modulus); + try tmp.divFloor(result, &tmp, &modulus); } pub fn powModulusBigInt(allocator: std.mem.Allocator, b: Int, e: Int, modulus: Int) !Int { + var result = try Int.initSet(allocator, 0); + errdefer result.deinit(); + + try powModulusBigIntWithPtr(allocator, b, e, modulus, &result); + return result; +} + +pub fn powModulusBigIntWithPtr(allocator: std.mem.Allocator, b: Int, e: Int, modulus: Int, result: *Int) !void { var base = try b.clone(); defer base.deinit(); @@ -61,11 +64,8 @@ pub fn powModulusBigInt(allocator: std.mem.Allocator, b: Int, e: Int, modulus: I var tmp2 = try Int.initSet(allocator, 1); defer tmp2.deinit(); - var result = try Int.initSet(allocator, 0); - errdefer result.deinit(); - if (modulus.eql(tmp)) - return result; + return; try tmp.divFloor(&base, &base, &modulus); @@ -76,15 +76,13 @@ pub fn powModulusBigInt(allocator: std.mem.Allocator, b: Int, e: Int, modulus: I try tmp.bitAnd(&exponent, &tmp); if (tmp.eql(tmp2)) { - try multiplyModulusBigIntWithPtr(allocator, result, base, modulus, &result); + try multiplyModulusBigIntWithPtr(allocator, result.*, base, modulus, result); } try multiplyModulusBigIntWithPtr(allocator, base, base, modulus, &base); try exponent.shiftRight(&exponent, 1); } - - return result; } pub fn powModulus(b: u512, e: u512, modulus: u512) u512 { @@ -115,6 +113,150 @@ pub fn legendre(a: u512, p: u512) u512 { return powModulus(a, (p - 1) / 2, p); } +pub fn legendreBigIntWithPtr(allocator: std.mem.Allocator, a: Int, p: Int, result: *Int) !void { + var tmp = try p.clone(); + defer tmp.deinit(); + var tmp2 = try Int.initSet(allocator, 2); + defer tmp2.deinit(); + + try tmp.addScalar(&tmp, -1); + try tmp.divFloor(&tmp2, &tmp, &tmp2); + + try powModulusBigIntWithPtr(allocator, a, tmp, p, result); +} + +pub fn legendreBigInt(allocator: std.mem.Allocator, a: Int, p: Int) !Int { + var tmp = try p.clone(); + defer tmp.deinit(); + var tmp2 = try Int.initSet(allocator, 2); + + try tmp.addScalar(&tmp, -1); + try tmp.divFloor(&tmp2, &tmp, &tmp2); + + return powModulusBigInt(allocator, a, tmp, p); +} + +pub fn tonelliShanksBigInt(allocator: std.mem.Allocator, n: Int, p: Int) !struct { Int, Int, bool } { + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + var result: struct { Int, Int, bool } = undefined; + + inline for (0..2) |i| { + result[i] = try Int.init(allocator); + errdefer { + inline for (0..i) |j| result[j].deinit(); + } + } + errdefer { + inline for (0..2) |i| result[i].deinit(); + } + result[2] = false; + + var tmp = try legendreBigInt(arena.allocator(), n, p); + + var tmp2 = try Int.initSet(arena.allocator(), 1); + + if (!tmp.eql(tmp2)) + return result; + + result[2] = true; + + // Factor out powers of 2 from p - 1 + var q = try p.cloneWithDifferentAllocator(arena.allocator()); + + try q.addScalar(&p, -1); + + var s = try Int.initSet(arena.allocator(), 0); + + try tmp2.set(2); + while (q.isEven()) { + try q.divFloor(&tmp, &q, &tmp2); + try s.addScalar(&s, 1); + } + + try tmp2.set(1); + + var tmp3 = try Int.init(arena.allocator()); + + if (s.eql(tmp2)) { + try tmp3.set(4); + try tmp2.addScalar(&p, 1); + try tmp2.divFloor( + &tmp, + &tmp2, + &tmp3, + ); + const res = try powModulusBigInt(arena.allocator(), n, tmp2, p); + + try result[0].copy(res.toConst()); + try result[1].sub(&p, &res); + + result[2] = true; + return result; + } + + // Find a non-square z such as ( z | p ) = -1 + var z = try Int.initSet(arena.allocator(), 2); + + try legendreBigIntWithPtr(allocator, z, p, &tmp); + try tmp2.addScalar(&p, -1); + while (!tmp.eql(tmp2)) { + try z.addScalar(&z, 1); + + try legendreBigIntWithPtr(allocator, z, p, &tmp); + } + + var c = try powModulusBigInt(arena.allocator(), z, q, p); + var t = try powModulusBigInt(arena.allocator(), n, q, p); + var m = try s.clone(); + + try tmp.addScalar(&q, 1); + + try tmp.shiftRight(&tmp, 1); + + var res = try powModulusBigInt(arena.allocator(), n, tmp, p); + + try tmp3.set(1); + + var i = try Int.initSet(arena.allocator(), 1); + var b = try Int.initSet(arena.allocator(), 0); + + while (!t.eql(tmp3)) { + try i.set(1); + + try multiplyModulusBigIntWithPtr(arena.allocator(), t, t, p, &z); + + try tmp2.addScalar(&m, -1); + + while (!z.eql(tmp3) and i.order(tmp2).compare(.lt)) { + try i.addScalar(&i, 1); + + try multiplyModulusBigIntWithPtr(arena.allocator(), z, z, p, &z); + } + + try tmp2.set(1); + try tmp.sub(&m, &i); + try tmp.addScalar(&tmp, -1); + + try b.shiftLeft(&tmp2, try tmp.to(usize)); + + try powModulusBigIntWithPtr(arena.allocator(), c, b, p, &b); + try multiplyModulusBigIntWithPtr(arena.allocator(), b, b, p, &c); + try multiplyModulusBigIntWithPtr(arena.allocator(), t, c, p, &t); + + try m.copy(i.toConst()); + + try multiplyModulusBigIntWithPtr(arena.allocator(), res, b, p, &res); + } + + try result[0].copy(res.toConst()); + try result[1].sub(&p, &res); + result[2] = true; + + return result; +} + pub fn tonelliShanks(n: u512, p: u512) struct { u512, u512, bool } { if (legendre(n, p) != 1) { return .{ 0, 0, false }; @@ -325,7 +467,7 @@ pub fn safeDivBigInt(x: i512, y: i512) !i512 { return result[0]; } -pub fn isPrime(allocator: std.mem.Allocator, n: Int) !bool { +pub fn isPrimeU64(allocator: std.mem.Allocator, n: Int) !bool { var n_c = try n.clone(); defer n_c.deinit(); @@ -354,6 +496,75 @@ pub fn isPrime(allocator: std.mem.Allocator, n: Int) !bool { return false; } +pub fn isPrime(allocator: std.mem.Allocator, n: Int) !bool { + var tmp = try Int.initSet(allocator, 2); + defer tmp.deinit(); + var tmp1 = try Int.initSet(allocator, 3); + defer tmp1.deinit(); + var tmp2 = try Int.initSet(allocator, 5); + defer tmp2.deinit(); + + if (n.order(tmp).compare(.lt)) return false; + + if (n.eql(tmp) or n.eql(tmp1) or n.eql(tmp2)) return true; + + try tmp.divFloor(&tmp1, &n, &tmp); + if (tmp1.eqlZero()) + return false; + + var n_sub = try Int.init(allocator); + defer n_sub.deinit(); + + try n_sub.addScalar(&n, -1); + + var exponent = try n_sub.clone(); + defer exponent.deinit(); + + const trials = try trailingZeroesBigInt(exponent); + try exponent.shiftRight(&exponent, trials); + + const buf = try n.toString(allocator, 10, .lower); + defer allocator.free(buf); + + for (1..(buf.len + 2)) |i| { + try tmp.set(2); + try tmp.addScalar(&tmp, i); + try powModulusBigIntWithPtr(allocator, tmp, exponent, n, &tmp); + + try tmp1.set(1); + if (tmp.eql(tmp1) or tmp.eql(n_sub)) continue; + + var flag = false; + for (1..trials) |_| { + try tmp.mul(&tmp, &tmp); + try tmp1.divFloor(&tmp, &tmp, &n); + + try tmp1.set(1); + if (tmp.eql(tmp1)) return false; + + if (tmp.eql(n_sub)) { + flag = true; + break; + } + } + + if (flag) continue; + + return false; + } + + return true; +} + +pub fn trailingZeroesBigInt(n: Int) !usize { + const i: usize = for (0.., n.limbs) |i, digit| { + if (digit != 0) break i; + } else 0; + + const zeros = @ctz(n.limbs[i]); + return i * @bitSizeOf(std.math.big.Limb) + zeros; +} + // Ported from sympy implementation // Simplified as a & p are nonnegative // Asumes p is a prime number @@ -369,7 +580,9 @@ pub fn isQuadResidue(allocator: std.mem.Allocator, a: Int, p: Int) !bool { errdefer a_c.deinit(); try tmp.divFloor(&a_c, &a, &p); break :blk a_c; - } else try a.clone(); + } else blk: { + break :blk try a.clone(); + }; defer a_new.deinit(); var tmp2 = try Int.initSet(allocator, 3); @@ -383,12 +596,188 @@ pub fn isQuadResidue(allocator: std.mem.Allocator, a: Int, p: Int) !bool { try tmp.addScalar(&p, -1); try tmp2.set(2); - var result = try powModulusBigInt(allocator, tmp, tmp2, p); + try tmp.divFloor(&tmp2, &tmp, &tmp2); + + try powModulusBigIntWithPtr(allocator, a_new, tmp, p, &tmp); + + try tmp2.set(1); + + return tmp.eql(tmp2); +} + +// Adapted from sympy _sqrt_prime_power with k == 1 +pub fn sqrtPrimePower(allocator: std.mem.Allocator, a: Int, p: Int) !?Int { + if (p.eqlZero() or !(try isPrime(allocator, p))) { + return null; + } + + var result = try Int.init(allocator); + errdefer result.deinit(); + + var tmp = try Int.init(allocator); + defer tmp.deinit(); + + var tmp1 = try Int.init(allocator); + defer tmp1.deinit(); + var tmp2 = try Int.init(allocator); + defer tmp2.deinit(); + + var two = try Int.initSet(allocator, 2); + defer two.deinit(); + + try tmp.divFloor(&result, &a, &p); + if (p.eql(two)) + return result; + + try tmp.addScalar(&p, -1); + try tmp2.divFloor(&tmp1, &tmp, &two); + + try powModulusBigIntWithPtr(allocator, result, tmp2, p, &tmp); + try tmp2.set(1); + + if (!(a.order(two).compare(.lt) or tmp.eql(tmp2))) { + result.deinit(); + return null; + } + + try tmp1.set(4); + + try tmp.divFloor(&tmp2, &p, &tmp1); + try tmp.set(3); + + if (tmp2.eql(tmp)) { + try tmp.addScalar(&p, 1); + try tmp2.set(4); + try tmp1.divFloor(&tmp2, &tmp, &tmp2); + try powModulusBigIntWithPtr(allocator, result, tmp1, p, &result); + try tmp.sub(&p, &result); + + if (result.order(tmp).compare(.gt)) { + try result.copy(tmp.toConst()); + } + + return result; + } + + try tmp2.set(8); + try tmp.divFloor(&tmp1, &p, &tmp2); + try tmp2.set(5); + + if (tmp1.eql(tmp2)) { + try tmp.addScalar(&p, -1); + try tmp1.set(4); + try tmp.divFloor(&tmp1, &tmp, &tmp1); + + try powModulusBigIntWithPtr(allocator, result, tmp, p, &tmp1); + + try tmp.set(1); + + // tmp1 is sign + if (tmp1.eql(tmp)) { + try tmp.addScalar(&p, 3); + try tmp1.set(8); + + try tmp.divFloor(&tmp1, &tmp, &tmp1); + + try powModulusBigIntWithPtr(allocator, result, tmp, p, &result); + + try tmp.sub(&p, &result); + + if (result.order(tmp).compare(.gt)) { + try result.copy(tmp.toConst()); + } + + return result; + } else { + try tmp1.addScalar(&p, -5); + try tmp.set(8); + try tmp.divFloor(&tmp1, &tmp1, &tmp); + + try tmp2.set(4); + try tmp2.mul(&tmp2, &result); + + try powModulusBigIntWithPtr(allocator, tmp2, tmp, p, &tmp); + + // b==tmp + try tmp1.mul(&result, &tmp); + try tmp2.set(2); + try tmp1.mul(&tmp1, &tmp2); + try tmp.divFloor(&tmp2, &tmp1, &p); + + // x==tmp2 + try powModulusBigIntWithPtr(allocator, tmp2, two, p, &tmp); + if (tmp.eql(result)) { + try result.copy(tmp2.toConst()); + return result; + } + } + } defer result.deinit(); - try tmp.set(1); + var val1, var val2, const succ = try tonelliShanksBigInt(allocator, result, p); + if (!succ) { + return null; + } + + if (val1.order(val2).compare(.lt)) { + val2.deinit(); + return val1; + } - return result.eql(tmp); + val1.deinit(); + return val2; +} + +///Returns num_a^-1 mod p +pub fn mulInv(allocator: std.mem.Allocator, num_a: Int, p: Int) !Int { + var result = try Int.initSet(allocator, 0); + errdefer result.deinit(); + + if (num_a.eqlZero()) + return result; + + var a = try num_a.clone(); + defer a.deinit(); + a.abs(); + + var x_sign = blk: { + var res = try Int.initSet(allocator, 0); + errdefer res.deinit(); + if (!num_a.eqlZero()) if (num_a.isPositive()) try res.set(1) else try res.set(-1); + break :blk res; + }; + defer x_sign.deinit(); + + var b = try p.clone(); + defer b.deinit(); + b.abs(); + + var x = try Int.initSet(allocator, 1); + defer x.deinit(); + var r = try Int.initSet(allocator, 0); + defer r.deinit(); + + var c = try Int.initSet(allocator, 0); + defer c.deinit(); + var q = try Int.initSet(allocator, 0); + defer q.deinit(); + + var tmp = try Int.init(allocator); + defer tmp.deinit(); + + while (!b.eqlZero()) { + try q.divFloor(&c, &a, &b); + + try result.mul(&q, &r); + try x.sub(&x, &result); + std.mem.swap(Int, &r, &x); + try tmp.copy(b.toConst()); + try b.copy(c.toConst()); + try a.copy(tmp.toConst()); + } + + try result.mul(&x, &x_sign); + return result; } test "Helper: extendedGCD big" { @@ -411,3 +800,66 @@ test "Helper: extendedGCD big" { try std.testing.expectEqual(result.x, try res2.x.to(i512)); try std.testing.expectEqual(result.y, try res2.y.to(i512)); } + +test "Helper: tonelli-shanks ok" { + const val = tonelliShanks(2, 113); + + var n = try Int.initSet(std.testing.allocator, 2); + defer n.deinit(); + var p = try Int.initSet(std.testing.allocator, 113); + defer p.deinit(); + + var val2 = try tonelliShanksBigInt(std.testing.allocator, n, p); + defer { + inline for (0..2) |i| val2[i].deinit(); + } + + try std.testing.expectEqual(val[0], try val2[0].to(u512)); + try std.testing.expectEqual(val[1], try val2[1].to(u512)); + try std.testing.expectEqual(val[2], val2[2]); +} + +test "Helper: SqrtPrimePower" { + var n = try Int.initSet(std.testing.allocator, 25); + defer n.deinit(); + var p = try Int.initSet(std.testing.allocator, 18446744069414584321); + defer p.deinit(); + + var result = (try sqrtPrimePower(std.testing.allocator, n, p)).?; + defer result.deinit(); + + try std.testing.expect(try result.to(u8) == 5); +} + +test "Helper: SqrtPrimePower p is zero" { + var n = try Int.initSet(std.testing.allocator, 1); + defer n.deinit(); + var p = try Int.initSet(std.testing.allocator, 0); + defer p.deinit(); + + try std.testing.expect(try sqrtPrimePower(std.testing.allocator, n, p) == null); +} + +test "Helper: SqrtPrimePower mod 8 is 5 sign not one" { + var n = try Int.initSet(std.testing.allocator, 676); + defer n.deinit(); + var p = try Int.initSet(std.testing.allocator, 9956234341095173); + defer p.deinit(); + + var result = (try sqrtPrimePower(std.testing.allocator, n, p)).?; + defer result.deinit(); + + try std.testing.expectEqual(try result.to(u64), 9956234341095147); +} + +test "Helper: SqrtPrimePower mod 8 is 5 sign is one" { + var n = try Int.initSet(std.testing.allocator, 130283432663); + defer n.deinit(); + var p = try Int.initSet(std.testing.allocator, 743900351477); + defer p.deinit(); + + var result = (try sqrtPrimePower(std.testing.allocator, n, p)).?; + defer result.deinit(); + + try std.testing.expectEqual(try result.to(u64), 123538694848); +}