Skip to content

Commit

Permalink
feat: construct sumcheck proofs (PROOF-913) (#198)
Browse files Browse the repository at this point in the history
construct sumcheck proofs
  • Loading branch information
rnburn authored Oct 31, 2024
1 parent 31f5800 commit f32d350
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 0 deletions.
26 changes: 26 additions & 0 deletions sxt/proof/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ sxt_cc_component(
],
)

sxt_cc_component(
name = "proof_computation",
impl_deps = [
":driver",
":transcript_utility",
"//sxt/execution/async:future",
"//sxt/base/error:assert",
"//sxt/base/num:ceil_log2",
"//sxt/execution/async:coroutine",
"//sxt/scalar25/type:element",
],
test_deps = [
":cpu_driver",
"//sxt/base/test:unit_test",
"//sxt/execution/async:future",
"//sxt/execution/schedule:scheduler",
"//sxt/proof/transcript",
"//sxt/scalar25/operation:overload",
"//sxt/scalar25/type:element",
],
deps = [
"//sxt/base/container:span",
"//sxt/execution/async:future_fwd",
],
)

sxt_cc_component(
name = "verification",
impl_deps = [
Expand Down
67 changes: 67 additions & 0 deletions sxt/proof/sumcheck/proof_computation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/proof_computation.h"

#include "sxt/base/error/assert.h"
#include "sxt/base/num/ceil_log2.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/async/future.h"
#include "sxt/proof/sumcheck/driver.h"
#include "sxt/proof/sumcheck/transcript_utility.h"
#include "sxt/scalar25/type/element.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// prove_sum
//--------------------------------------------------------------------------------------------------
xena::future<> prove_sum(basct::span<s25t::element> polynomials,
basct::span<s25t::element> evaluation_point, prft::transcript& transcript,
const driver& drv, basct::cspan<s25t::element> mles,
basct::cspan<std::pair<s25t::element, unsigned>> product_table,
basct::cspan<unsigned> product_terms, unsigned n) noexcept {
SXT_RELEASE_ASSERT(0 < n);
auto num_variables = std::max(basn::ceil_log2(n), 1);
auto polynomial_length = polynomials.size() / num_variables;
auto num_mles = mles.size() / n;
SXT_RELEASE_ASSERT(
// clang-format off
polynomial_length > 1 &&
polynomials.size() == num_variables * polynomial_length &&
mles.size() == n * num_mles
// clang-format on
);

init_transcript(transcript, num_variables, polynomial_length - 1);

auto ws = co_await drv.make_workspace(mles, product_table, product_terms, n);

for (unsigned round_index = 0; round_index < num_variables; ++round_index) {
auto polynomial = polynomials.subspan(round_index * polynomial_length, polynomial_length);

// compute the round polynomial
co_await drv.sum(polynomial, *ws);

// draw the next random challenge
s25t::element r;
round_challenge(r, transcript, polynomial);
evaluation_point[round_index] = r;

// fold the polynomial
co_await drv.fold(*ws, r);
}
}
} // namespace sxt::prfsk
42 changes: 42 additions & 0 deletions sxt/proof/sumcheck/proof_computation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <utility>

#include "sxt/base/container/span.h"
#include "sxt/execution/async/future_fwd.h"

namespace sxt::prft {
class transcript;
}
namespace sxt::s25t {
class element;
}

namespace sxt::prfsk {
class driver;

//--------------------------------------------------------------------------------------------------
// prove_sum
//--------------------------------------------------------------------------------------------------
xena::future<> prove_sum(basct::span<s25t::element> polynomials,
basct::span<s25t::element> evaluation_point, prft::transcript& transcript,
const driver& drv, basct::cspan<s25t::element> mles,
basct::cspan<std::pair<s25t::element, unsigned>> product_table,
basct::cspan<unsigned> product_terms, unsigned n) noexcept;
} // namespace sxt::prfsk
147 changes: 147 additions & 0 deletions sxt/proof/sumcheck/proof_computation.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/proof_computation.h"

#include <utility>
#include <vector>

#include "sxt/base/test/unit_test.h"
#include "sxt/execution/async/future.h"
#include "sxt/execution/schedule/scheduler.h"
#include "sxt/proof/sumcheck/cpu_driver.h"
#include "sxt/proof/transcript/transcript.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using namespace sxt::prfsk;
using s25t::operator""_s25;

TEST_CASE("we can create a sumcheck proof") {
prft::transcript transcript{"abc"};
cpu_driver drv;
std::vector<s25t::element> polynomials(2);
std::vector<s25t::element> evaluation_point(1);
std::vector<s25t::element> mles = {
0x8_s25,
0x3_s25,
};
std::vector<std::pair<s25t::element, unsigned>> product_table = {
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0};

SECTION("we can prove a sum with n=1") {
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 1);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0]);
REQUIRE(polynomials[1] == -mles[0]);
}

SECTION("we can prove a sum with a single variable") {
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 2);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0]);
REQUIRE(polynomials[1] == mles[1] - mles[0]);
}

SECTION("we can prove a sum degree greater than 1") {
product_table = {
{0x1_s25, 2},
};
product_terms = {0, 0};
polynomials.resize(3);
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 2);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0] * mles[0]);
REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0]) * mles[0]);
REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[1] - mles[0]));
}

SECTION("we can prove a sum with multiple MLEs") {
product_table = {
{0x1_s25, 2},
};
product_terms = {0, 1};
polynomials.resize(3);
mles.push_back(0x7_s25);
mles.push_back(0x10_s25);
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 2);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0] * mles[2]);
REQUIRE(polynomials[1] == (mles[1] - mles[0]) * mles[2] + (mles[3] - mles[2]) * mles[0]);
REQUIRE(polynomials[2] == (mles[1] - mles[0]) * (mles[3] - mles[2]));
}

SECTION("we can prove a sum where the term multiplier is different from one") {
product_table[0].first = 0x2_s25;
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 2);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == 0x2_s25 * mles[0]);
REQUIRE(polynomials[1] == 0x2_s25 * (mles[1] - mles[0]));
}

SECTION("we can prove a sum with two variables") {
mles.push_back(0x4_s25);
mles.push_back(0x7_s25);
polynomials.resize(4);
evaluation_point.resize(2);
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 4);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0] + mles[1]);
REQUIRE(polynomials[1] == (mles[2] - mles[0]) + (mles[3] - mles[1]));

auto r = evaluation_point[0];
mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r;
mles[1] = mles[1] * (0x1_s25 - r) + mles[3] * r;

REQUIRE(polynomials[2] == mles[0]);
REQUIRE(polynomials[3] == mles[1] - mles[0]);
}

SECTION("we can prove a sum with n=3") {
mles.push_back(0x4_s25);
polynomials.resize(4);
evaluation_point.resize(2);
auto fut = prove_sum(polynomials, evaluation_point, transcript, drv, mles, product_table,
product_terms, 3);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(polynomials[0] == mles[0] + mles[1]);
REQUIRE(polynomials[1] == (mles[2] - mles[0]) - mles[1]);

auto r = evaluation_point[0];
mles[0] = mles[0] * (0x1_s25 - r) + mles[2] * r;
mles[1] = mles[1] * (0x1_s25 - r);

REQUIRE(polynomials[2] == mles[0]);
REQUIRE(polynomials[3] == mles[1] - mles[0]);
}
}

0 comments on commit f32d350

Please sign in to comment.