Skip to content

Commit

Permalink
feat: add sumcheck operations for cpu (PROOF-913) (#197)
Browse files Browse the repository at this point in the history
add sumcheck operations for cpu
  • Loading branch information
rnburn authored Oct 30, 2024
1 parent 1e0df4a commit 31f5800
Show file tree
Hide file tree
Showing 11 changed files with 518 additions and 0 deletions.
42 changes: 42 additions & 0 deletions sxt/proof/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,48 @@ load(
"sxt_cc_component",
)

sxt_cc_component(
name = "workspace",
with_test = False,
)

sxt_cc_component(
name = "driver",
with_test = False,
deps = [
":workspace",
"//sxt/base/container:span",
"//sxt/execution/async:future_fwd",
],
)

sxt_cc_component(
name = "cpu_driver",
impl_deps = [
":polynomial_utility",
"//sxt/base/container:stack_array",
"//sxt/base/error:panic",
"//sxt/base/num:ceil_log2",
"//sxt/execution/async:future",
"//sxt/memory/management:managed_array",
"//sxt/scalar25/operation:mul",
"//sxt/scalar25/operation:sub",
"//sxt/scalar25/operation:muladd",
"//sxt/scalar25/type:element",
"//sxt/scalar25/type:literal",
],
test_deps = [
"//sxt/execution/async:future",
"//sxt/scalar25/operation:overload",
"//sxt/scalar25/type:element",
"//sxt/scalar25/type:literal",
],
deps = [
":driver",
":workspace",
],
)

sxt_cc_component(
name = "transcript_utility",
impl_deps = [
Expand Down
159 changes: 159 additions & 0 deletions sxt/proof/sumcheck/cpu_driver.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/cpu_driver.h"

#include <algorithm>
#include <iostream>

#include "sxt/base/container/stack_array.h"
#include "sxt/base/error/panic.h"
#include "sxt/base/num/ceil_log2.h"
#include "sxt/execution/async/future.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/proof/sumcheck/polynomial_utility.h"
#include "sxt/scalar25/operation/mul.h"
#include "sxt/scalar25/operation/muladd.h"
#include "sxt/scalar25/operation/sub.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/type/literal.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// cpu_workspace
//--------------------------------------------------------------------------------------------------
namespace {
struct cpu_workspace final : public workspace {
memmg::managed_array<s25t::element> mles;
basct::cspan<std::pair<s25t::element, unsigned>> product_table;
basct::cspan<unsigned> product_terms;
unsigned n;
unsigned num_variables;
};
} // namespace

//--------------------------------------------------------------------------------------------------
// make_workspace
//--------------------------------------------------------------------------------------------------
xena::future<std::unique_ptr<workspace>>
cpu_driver::make_workspace(basct::cspan<s25t::element> mles,
basct::cspan<std::pair<s25t::element, unsigned>> product_table,
basct::cspan<unsigned> product_terms, unsigned n) const noexcept {
auto res = std::make_unique<cpu_workspace>();
res->mles = memmg::managed_array<s25t::element>{mles.begin(), mles.end()};
res->product_table = product_table;
res->product_terms = product_terms;
res->n = n;
res->num_variables = std::max(basn::ceil_log2(n), 1);
return xena::make_ready_future<std::unique_ptr<workspace>>(std::move(res));
}

//--------------------------------------------------------------------------------------------------
// sum
//--------------------------------------------------------------------------------------------------
xena::future<> cpu_driver::sum(basct::span<s25t::element> polynomial,
workspace& ws) const noexcept {
auto& work = static_cast<cpu_workspace&>(ws);
auto n = work.n;
auto mid = 1u << (work.num_variables - 1u);
SXT_RELEASE_ASSERT(work.n >= mid);

auto mles = work.mles.data();
auto product_table = work.product_table;
auto product_terms = work.product_terms;

for (auto& val : polynomial) {
val = {};
}

// expand paired terms
auto n1 = work.n - mid;
for (unsigned i = 0; i < n1; ++i) {
unsigned term_first = 0;
for (auto [mult, num_terms] : product_table) {
SXT_RELEASE_ASSERT(num_terms < polynomial.size());
auto terms = product_terms.subspan(term_first, num_terms);
SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element);
expand_products(p, mles + i, n, mid, terms);
for (unsigned term_index = 0; term_index < p.size(); ++term_index) {
s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]);
}
term_first += num_terms;
}
}

// expand terms where the corresponding pair is zero (i.e. n is not a power of 2)
for (unsigned i = n1; i < mid; ++i) {
unsigned term_first = 0;
for (auto [mult, num_terms] : product_table) {
auto terms = product_terms.subspan(term_first, num_terms);
SXT_STACK_ARRAY(p, num_terms + 1u, s25t::element);
partial_expand_products(p, mles + i, n, terms);
for (unsigned term_index = 0; term_index < p.size(); ++term_index) {
s25o::muladd(polynomial[term_index], mult, p[term_index], polynomial[term_index]);
}
term_first += num_terms;
}
}

return xena::make_ready_future();
}

//--------------------------------------------------------------------------------------------------
// fold
//--------------------------------------------------------------------------------------------------
xena::future<> cpu_driver::fold(workspace& ws, const s25t::element& r) const noexcept {
using s25t::operator""_s25;

auto& work = static_cast<cpu_workspace&>(ws);
auto n = work.n;
auto mid = 1u << (work.num_variables - 1u);
auto num_mles = work.mles.size() / n;
SXT_RELEASE_ASSERT(
// clang-format off
work.n >= mid && work.mles.size() % n == 0
// clang-format on
);

auto mles = work.mles.data();
s25t::element one_m_r = 0x1_s25;
s25o::sub(one_m_r, one_m_r, r);
auto n1 = work.n - mid;
for (auto mle_index = 0; mle_index < num_mles; ++mle_index) {
auto data = mles + n * mle_index;

// fold paired terms
for (unsigned i = 0; i < n1; ++i) {
auto val = data[i];
s25o::mul(val, val, one_m_r);
s25o::muladd(val, r, data[mid + i], val);
data[i] = val;
}

// fold terms paired with zero
for (unsigned i = n1; i < mid; ++i) {
auto val = data[i];
s25o::mul(val, val, one_m_r);
data[i] = val;
}
}

work.n = mid;
--work.num_variables;
work.mles.shrink(num_mles * mid);
return xena::make_ready_future();
}
} // namespace sxt::prfsk
37 changes: 37 additions & 0 deletions sxt/proof/sumcheck/cpu_driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "sxt/proof/sumcheck/driver.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// cpu_driver
//--------------------------------------------------------------------------------------------------
class cpu_driver final : public driver {
public:
// driver
xena::future<std::unique_ptr<workspace>>
make_workspace(basct::cspan<s25t::element> mles,
basct::cspan<std::pair<s25t::element, unsigned>> product_table,
basct::cspan<unsigned> product_terms, unsigned n) const noexcept override;

xena::future<> sum(basct::span<s25t::element> polynomial, workspace& ws) const noexcept override;

xena::future<> fold(workspace& ws, const s25t::element& r) const noexcept override;
};
} // namespace sxt::prfsk
116 changes: 116 additions & 0 deletions sxt/proof/sumcheck/cpu_driver.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/cpu_driver.h"

#include <vector>

#include "sxt/base/test/unit_test.h"
#include "sxt/execution/async/future.h"
#include "sxt/proof/sumcheck/workspace.h"
#include "sxt/scalar25/operation/overload.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using namespace sxt::prfsk;
using s25t::operator""_s25;

TEST_CASE("we can perform the primitive operations for sumcheck proofs") {
std::vector<s25t::element> mles;
std::vector<std::pair<s25t::element, unsigned>> product_table{
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0};

std::vector<s25t::element> p(2);
cpu_driver drv;

SECTION("we can sum a polynomial with n = 1") {
std::vector<s25t::element> mles = {0x123_s25};
auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value();
auto fut = drv.sum(p, *ws);
REQUIRE(fut.ready());
REQUIRE(p[0] == mles[0]);
REQUIRE(p[1] == -mles[0]);
}

SECTION("we can sum a polynomial with a non-unity multiplier") {
std::vector<s25t::element> mles = {0x123_s25};
product_table[0].first = 0x2_s25;
auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value();
auto fut = drv.sum(p, *ws);
REQUIRE(fut.ready());
REQUIRE(p[0] == 0x2_s25 * mles[0]);
REQUIRE(p[1] == -0x2_s25 * mles[0]);
}

SECTION("we can sum a polynomial with n = 2") {
std::vector<s25t::element> mles = {0x123_s25, 0x456_s25};
auto ws = drv.make_workspace(mles, product_table, product_terms, 2).value();
auto fut = drv.sum(p, *ws);
REQUIRE(fut.ready());
REQUIRE(p[0] == mles[0]);
REQUIRE(p[1] == mles[1] - mles[0]);
}

SECTION("we can sum a polynomial with two MLEs added together") {
std::vector<s25t::element> mles = {0x123_s25, 0x456_s25};
std::vector<std::pair<s25t::element, unsigned>> product_table{
{0x1_s25, 1},
{0x1_s25, 1},
};
std::vector<unsigned> product_terms = {0, 1};

auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value();
auto fut = drv.sum(p, *ws);
REQUIRE(fut.ready());
REQUIRE(p[0] == mles[0] + mles[1]);
REQUIRE(p[1] == -mles[0] - mles[1]);
}

SECTION("we can sum a polynomial with two MLEs multiplied together") {
std::vector<s25t::element> mles = {0x123_s25, 0x456_s25};
std::vector<std::pair<s25t::element, unsigned>> product_table{
{0x1_s25, 2},
};
std::vector<unsigned> product_terms = {0, 1};
p.resize(3);

auto ws = drv.make_workspace(mles, product_table, product_terms, 1).value();
auto fut = drv.sum(p, *ws);
REQUIRE(fut.ready());
REQUIRE(p[0] == mles[0] * mles[1]);
REQUIRE(p[1] == -mles[0] * mles[1] - mles[1] * mles[0]);
REQUIRE(p[2] == mles[0] * mles[1]);
}

SECTION("we can fold mles") {
std::vector<s25t::element> mles = {0x123_s25, 0x456_s25, 0x789_s25};
auto ws = drv.make_workspace(mles, product_table, product_terms, 3).value();
auto r = 0xabc123_s25;
auto fut = drv.fold(*ws, r);
REQUIRE(fut.ready());
fut = drv.sum(p, *ws);
REQUIRE(fut.ready());

mles[0] = (0x1_s25 - r) * mles[0] + r * mles[2];
mles[1] = (0x1_s25 - r) * mles[1];

REQUIRE(p[0] == mles[0]);
REQUIRE(p[1] == mles[1] - mles[0]);
}
}
17 changes: 17 additions & 0 deletions sxt/proof/sumcheck/driver.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/driver.h"
Loading

0 comments on commit 31f5800

Please sign in to comment.