From 97f793e115b6267cddc94229a6f14342144830fb Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Fri, 29 Sep 2023 16:04:18 -0300 Subject: [PATCH] msm/pippenger: use single digit decomposition Signed-off-by: Ignacio Hagopian --- src/bandersnatch/points/extended.zig | 18 ++++++--- src/banderwagon/banderwagon.zig | 5 +++ src/bench.zig | 2 +- src/main.zig | 2 +- src/msm/pippenger.zig | 58 +++++++++++++++++++--------- src/msm/precomp.zig | 4 ++ src/multiproof/multiproof.zig | 3 +- 7 files changed, 65 insertions(+), 27 deletions(-) diff --git a/src/bandersnatch/points/extended.zig b/src/bandersnatch/points/extended.zig index abc109c..696844f 100644 --- a/src/bandersnatch/points/extended.zig +++ b/src/bandersnatch/points/extended.zig @@ -25,6 +25,14 @@ pub const ExtendedPointMSM = struct { }; } + pub fn neg(p: ExtendedPointMSM) ExtendedPointMSM { + return ExtendedPointMSM{ + .x = p.x.neg(), + .y = p.y, + .t = p.t.neg(), + }; + } + pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointMSM { const z_inv = p.z.inv().?; const x = p.x.mul(z_inv); @@ -73,12 +81,12 @@ pub const ExtendedPoint = struct { return comptime initUnsafe(gen.x, gen.y); } - pub fn neg(self: ExtendedPoint) ExtendedPoint { + pub fn neg(p: ExtendedPoint) ExtendedPoint { return ExtendedPoint{ - .x = self.x.neg(), - .y = self.y, - .t = self.t.neg(), - .z = self.z, + .x = p.x.neg(), + .y = p.y, + .t = p.t.neg(), + .z = p.z, }; } diff --git a/src/banderwagon/banderwagon.zig b/src/banderwagon/banderwagon.zig index 07c3808..8db4844 100644 --- a/src/banderwagon/banderwagon.zig +++ b/src/banderwagon/banderwagon.zig @@ -216,6 +216,7 @@ test "two torsion" { try std.testing.expect(result.equal(gen)); } +// TODO: rename and reference methods too. pub const ElementNormalized = struct { point: ExtendedPointMSM, @@ -252,6 +253,10 @@ pub const ElementNormalized = struct { return Element.fromElementNormalized(self).toBytes(); } + pub fn neg(self: ElementNormalized) ElementNormalized { + return ElementNormalized{ .point = ExtendedPointMSM.neg(self.point) }; + } + // TODO: move this. pub fn fromElements(result: []ElementNormalized, points: []const Element) void { var accumulator = Fp.one(); diff --git a/src/bench.zig b/src/bench.zig index 794f6b7..6974947 100644 --- a/src/bench.zig +++ b/src/bench.zig @@ -201,7 +201,7 @@ fn benchMultiproofs() !void { const mproof = try multiproof.MultiProof.init(xcrs); for (openings) |num_openings| { - std.debug.print("\tBenchmarking {} openings...", .{num_openings}); + std.debug.print("\tBenchmarking {} openings... ", .{num_openings}); var accum_proving: i64 = 0; var accum_verifying: i64 = 0; diff --git a/src/main.zig b/src/main.zig index 6e25b37..48c6dda 100644 --- a/src/main.zig +++ b/src/main.zig @@ -15,7 +15,7 @@ test "crs" { } test "msm" { - // _ = @import("msm/precomp.zig"); + _ = @import("msm/precomp.zig"); _ = @import("msm/pippenger.zig"); } diff --git a/src/msm/pippenger.zig b/src/msm/pippenger.zig index b9e7c78..5dae8c4 100644 --- a/src/msm/pippenger.zig +++ b/src/msm/pippenger.zig @@ -7,9 +7,8 @@ const Fr = banderwagon.Fr; pub fn Pippenger(comptime c: comptime_int) type { return struct { - const window_mask = (1 << c) - 1; const num_windows = std.math.divCeil(u8, Fr.BitSize, c) catch unreachable; - const num_buckets = (1 << c) - 1; + const num_buckets = 1 << (c - 1); pub fn msm(base_allocator: Allocator, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element { std.debug.assert(basis.len >= scalars_mont.len); @@ -18,33 +17,31 @@ pub fn Pippenger(comptime c: comptime_int) type { defer arena.deinit(); var allocator = arena.allocator(); - var scalars = try allocator.alloc(u256, scalars_mont.len); - for (0..scalars.len) |i| { - scalars[i] = scalars_mont[i].toInteger(); - } + var scalars_windows = try signedDigitDecomposition(allocator, scalars_mont); var result: ?Element = null; var buckets: [num_buckets]?Element = std.mem.zeroes([num_buckets]?Element); - var scalar_windows = try allocator.alloc(u16, scalars.len); for (0..num_windows) |w| { - // Partition scalars. - const w_idx = num_windows - w - 1; - for (0..scalars.len) |i| { - scalar_windows[i] = @as(u16, @intCast((scalars[i] >> @as(u8, @intCast(w_idx * c))) & window_mask)); - } - // Accumulate in buckets. for (0..buckets.len) |i| { buckets[i] = null; } - for (0..scalar_windows.len) |i| { - if (scalar_windows[i] == 0) { + for (0..scalars_mont.len) |i| { + var scalar_window = scalars_windows[i + w * scalars_mont.len]; + if (scalar_window == 0) { continue; } - if (buckets[scalar_windows[i] - 1] == null) { - buckets[scalar_windows[i] - 1] = Element.identity(); + + var adj_basis: ElementNormalized = basis[i]; + if (scalar_window < 0) { + adj_basis = ElementNormalized.neg(basis[i]); + scalar_window = -scalar_window; } - buckets[scalar_windows[i] - 1] = Element.mixedMsmAdd(buckets[scalar_windows[i] - 1].?, basis[i]); + const bucket_idx = @as(usize, @intCast(scalar_window)) - 1; + if (buckets[bucket_idx] == null) { + buckets[bucket_idx] = Element.identity(); + } + buckets[bucket_idx] = Element.mixedMsmAdd(buckets[bucket_idx].?, adj_basis); } // Aggregate buckets. @@ -82,6 +79,29 @@ pub fn Pippenger(comptime c: comptime_int) type { return result orelse Element.identity(); } + + fn signedDigitDecomposition(arena: Allocator, scalars_mont: []const Fr) ![]i16 { + const window_mask = (1 << c) - 1; + var scalars_windows = try arena.alloc(i16, scalars_mont.len * num_windows); + + for (0..scalars_mont.len) |i| { + const scalar = scalars_mont[i].toInteger(); + var carry: u1 = 0; + for (0..num_windows) |j| { + const curr_window = @as(u16, @intCast((scalar >> @as(u8, @intCast(j * c))) & window_mask)) + carry; + carry = 0; + if (curr_window >= 1 << (c - 1)) { + std.debug.assert(j != num_windows - 1); + scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)) - (1 << c); + carry = 1; + } else { + scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)); + } + } + } + + return scalars_windows; + } }; } @@ -95,7 +115,7 @@ test "correctness" { scalars[i] = Fr.fromInteger((i + 0x93434) *% 0x424242); } - inline for (2..8) |c| { + inline for (3..8) |c| { const pippenger = Pippenger(c); for (1..crs.DomainSize) |msm_length| { diff --git a/src/msm/precomp.zig b/src/msm/precomp.zig index 70cfd0a..829d822 100644 --- a/src/msm/precomp.zig +++ b/src/msm/precomp.zig @@ -5,6 +5,10 @@ const Element = banderwagon.Element; const ElementNormalized = banderwagon.ElementNormalized; const Fr = banderwagon.Fr; +// This implementation is based on: +// Faster Montgomery multiplication andMulti-Scalar-Multiplication for SNARKs +// https://tches.iacr.org/index.php/TCHES/article/view/10972/10279 +// plus some extra tricks from Ignacio Hagopian. pub fn PrecompMSM( comptime _t: comptime_int, comptime _b: comptime_int, diff --git a/src/multiproof/multiproof.zig b/src/multiproof/multiproof.zig index 7b56585..54f6dbb 100644 --- a/src/multiproof/multiproof.zig +++ b/src/multiproof/multiproof.zig @@ -205,7 +205,8 @@ pub const MultiProof = struct { Cs[i] = query.C; E_coefficients[i] = Fr.mul(powers_of_r[i], helper_scalar_den[queries[i].z]); } - const E = try pippenger.Pippenger(8).msm(allocator, Cs, E_coefficients); + // TODO: make the window size be dynamically calculated. + const E = try pippenger.Pippenger(11).msm(allocator, Cs, E_coefficients); transcript.appendPoint(E, "E"); // Check IPA proof.