Skip to content

Commit

Permalink
msm/pippenger: use single digit decomposition
Browse files Browse the repository at this point in the history
Signed-off-by: Ignacio Hagopian <[email protected]>
  • Loading branch information
jsign committed Sep 29, 2023
1 parent 2bc6ebb commit 97f793e
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 27 deletions.
18 changes: 13 additions & 5 deletions src/bandersnatch/points/extended.zig
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ pub const ExtendedPointMSM = struct {
};
}

pub fn neg(p: ExtendedPointMSM) ExtendedPointMSM {
return ExtendedPointMSM{
.x = p.x.neg(),
.y = p.y,
.t = p.t.neg(),
};
}

pub fn fromExtendedPoint(p: ExtendedPoint) ExtendedPointMSM {
const z_inv = p.z.inv().?;
const x = p.x.mul(z_inv);
Expand Down Expand Up @@ -73,12 +81,12 @@ pub const ExtendedPoint = struct {
return comptime initUnsafe(gen.x, gen.y);
}

pub fn neg(self: ExtendedPoint) ExtendedPoint {
pub fn neg(p: ExtendedPoint) ExtendedPoint {
return ExtendedPoint{
.x = self.x.neg(),
.y = self.y,
.t = self.t.neg(),
.z = self.z,
.x = p.x.neg(),
.y = p.y,
.t = p.t.neg(),
.z = p.z,
};
}

Expand Down
5 changes: 5 additions & 0 deletions src/banderwagon/banderwagon.zig
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ test "two torsion" {
try std.testing.expect(result.equal(gen));
}

// TODO: rename and reference methods too.
pub const ElementNormalized = struct {
point: ExtendedPointMSM,

Expand Down Expand Up @@ -252,6 +253,10 @@ pub const ElementNormalized = struct {
return Element.fromElementNormalized(self).toBytes();
}

pub fn neg(self: ElementNormalized) ElementNormalized {
return ElementNormalized{ .point = ExtendedPointMSM.neg(self.point) };
}

// TODO: move this.
pub fn fromElements(result: []ElementNormalized, points: []const Element) void {
var accumulator = Fp.one();
Expand Down
2 changes: 1 addition & 1 deletion src/bench.zig
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ fn benchMultiproofs() !void {

const mproof = try multiproof.MultiProof.init(xcrs);
for (openings) |num_openings| {
std.debug.print("\tBenchmarking {} openings...", .{num_openings});
std.debug.print("\tBenchmarking {} openings... ", .{num_openings});

var accum_proving: i64 = 0;
var accum_verifying: i64 = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ test "crs" {
}

test "msm" {
// _ = @import("msm/precomp.zig");
_ = @import("msm/precomp.zig");
_ = @import("msm/pippenger.zig");
}

Expand Down
58 changes: 39 additions & 19 deletions src/msm/pippenger.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ const Fr = banderwagon.Fr;

pub fn Pippenger(comptime c: comptime_int) type {
return struct {
const window_mask = (1 << c) - 1;
const num_windows = std.math.divCeil(u8, Fr.BitSize, c) catch unreachable;
const num_buckets = (1 << c) - 1;
const num_buckets = 1 << (c - 1);

pub fn msm(base_allocator: Allocator, basis: []const ElementNormalized, scalars_mont: []const Fr) !Element {
std.debug.assert(basis.len >= scalars_mont.len);
Expand All @@ -18,33 +17,31 @@ pub fn Pippenger(comptime c: comptime_int) type {
defer arena.deinit();
var allocator = arena.allocator();

var scalars = try allocator.alloc(u256, scalars_mont.len);
for (0..scalars.len) |i| {
scalars[i] = scalars_mont[i].toInteger();
}
var scalars_windows = try signedDigitDecomposition(allocator, scalars_mont);

var result: ?Element = null;
var buckets: [num_buckets]?Element = std.mem.zeroes([num_buckets]?Element);
var scalar_windows = try allocator.alloc(u16, scalars.len);
for (0..num_windows) |w| {
// Partition scalars.
const w_idx = num_windows - w - 1;
for (0..scalars.len) |i| {
scalar_windows[i] = @as(u16, @intCast((scalars[i] >> @as(u8, @intCast(w_idx * c))) & window_mask));
}

// Accumulate in buckets.
for (0..buckets.len) |i| {
buckets[i] = null;
}
for (0..scalar_windows.len) |i| {
if (scalar_windows[i] == 0) {
for (0..scalars_mont.len) |i| {
var scalar_window = scalars_windows[i + w * scalars_mont.len];
if (scalar_window == 0) {
continue;
}
if (buckets[scalar_windows[i] - 1] == null) {
buckets[scalar_windows[i] - 1] = Element.identity();

var adj_basis: ElementNormalized = basis[i];
if (scalar_window < 0) {
adj_basis = ElementNormalized.neg(basis[i]);
scalar_window = -scalar_window;
}
buckets[scalar_windows[i] - 1] = Element.mixedMsmAdd(buckets[scalar_windows[i] - 1].?, basis[i]);
const bucket_idx = @as(usize, @intCast(scalar_window)) - 1;
if (buckets[bucket_idx] == null) {
buckets[bucket_idx] = Element.identity();
}
buckets[bucket_idx] = Element.mixedMsmAdd(buckets[bucket_idx].?, adj_basis);
}

// Aggregate buckets.
Expand Down Expand Up @@ -82,6 +79,29 @@ pub fn Pippenger(comptime c: comptime_int) type {

return result orelse Element.identity();
}

fn signedDigitDecomposition(arena: Allocator, scalars_mont: []const Fr) ![]i16 {
const window_mask = (1 << c) - 1;
var scalars_windows = try arena.alloc(i16, scalars_mont.len * num_windows);

for (0..scalars_mont.len) |i| {
const scalar = scalars_mont[i].toInteger();
var carry: u1 = 0;
for (0..num_windows) |j| {
const curr_window = @as(u16, @intCast((scalar >> @as(u8, @intCast(j * c))) & window_mask)) + carry;
carry = 0;
if (curr_window >= 1 << (c - 1)) {
std.debug.assert(j != num_windows - 1);
scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window)) - (1 << c);
carry = 1;
} else {
scalars_windows[(num_windows - 1 - j) * scalars_mont.len + i] = @as(i16, @intCast(curr_window));
}
}
}

return scalars_windows;
}
};
}

Expand All @@ -95,7 +115,7 @@ test "correctness" {
scalars[i] = Fr.fromInteger((i + 0x93434) *% 0x424242);
}

inline for (2..8) |c| {
inline for (3..8) |c| {
const pippenger = Pippenger(c);

for (1..crs.DomainSize) |msm_length| {
Expand Down
4 changes: 4 additions & 0 deletions src/msm/precomp.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ const Element = banderwagon.Element;
const ElementNormalized = banderwagon.ElementNormalized;
const Fr = banderwagon.Fr;

// This implementation is based on:
// Faster Montgomery multiplication andMulti-Scalar-Multiplication for SNARKs
// https://tches.iacr.org/index.php/TCHES/article/view/10972/10279
// plus some extra tricks from Ignacio Hagopian.
pub fn PrecompMSM(
comptime _t: comptime_int,
comptime _b: comptime_int,
Expand Down
3 changes: 2 additions & 1 deletion src/multiproof/multiproof.zig
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ pub const MultiProof = struct {
Cs[i] = query.C;
E_coefficients[i] = Fr.mul(powers_of_r[i], helper_scalar_den[queries[i].z]);
}
const E = try pippenger.Pippenger(8).msm(allocator, Cs, E_coefficients);
// TODO: make the window size be dynamically calculated.
const E = try pippenger.Pippenger(11).msm(allocator, Cs, E_coefficients);
transcript.appendPoint(E, "E");

// Check IPA proof.
Expand Down

0 comments on commit 97f793e

Please sign in to comment.