Skip to content

Commit

Permalink
feat: add c bindings for variable length multi-exponentiations (PROOF…
Browse files Browse the repository at this point in the history
…-896) (#167)

* add variable length multiexponentiation

* add stub for c api

* add c api stub

* fill in c bridge

* reformat

* fill in tests

* fill in backend

* fill in cpu backend

* refactor

* fill in cbindings

* add no output tests

* fill in tests

* doc

* minor changes
  • Loading branch information
rnburn authored Aug 16, 2024
1 parent 69974db commit 31ee6a6
Show file tree
Hide file tree
Showing 14 changed files with 304 additions and 0 deletions.
29 changes: 29 additions & 0 deletions cbindings/blitzar_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,35 @@ void sxt_fixed_packed_multiexponentiation(void* res, const struct sxt_multiexp_h
const unsigned* output_bit_table, unsigned num_outputs,
unsigned n, const uint8_t* scalars);

/**
* Compute a varying lengthing multiexponentiation of scalars in packed format using a handle to
* pre-specified generators.
*
* On completion `res` contains an array of size `num_outputs` for the multiexponentiation
* of the given `scalars` array.
*
* An entry output_bit_table[output_index] specifies the number of scalar bits used for
* output_index and output_lengths[output_index] specifies the length used for output_index.
*
* Note: output_lengths must be sorted in ascending order
*
* Put
* bit_sum = sum_{output_index} output_bit_table[output_index]
* and let num_bytes denote the smallest integer greater than or equal to bit_sum that is a
* multiple of 8.
*
* Let n denote the length of the longest output. Then `scalars` specifies a contiguous
* multi-dimension `num_bytes` by `n` array laid out in a packed column-major order as specified by
* output_bit_table. A given row determines the scalar exponents for generator g_i with the output
* scalars packed contiguously and padded with zeros.
*
* Note: `res` must match the generator type of the curve. See `sxt_multiexp_handle_new` for
* the types.
*/
void sxt_fixed_vlen_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle,
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars);
#ifdef __cplusplus
} // extern "C"
#endif
13 changes: 13 additions & 0 deletions cbindings/fixed_pedersen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ void sxt_fixed_packed_multiexponentiation(void* res, const struct sxt_multiexp_h
backend->fixed_multiexponentiation(res, h->curve_id, *h->partition_table_accessor,
output_bit_table, num_outputs, n, scalars);
}

//--------------------------------------------------------------------------------------------------
// sxt_fixed_vlen_multiexponentiation
//--------------------------------------------------------------------------------------------------
void sxt_fixed_vlen_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle,
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars) {
auto backend = cbn::get_backend();
auto h = reinterpret_cast<const cbnb::multiexp_handle*>(handle);
backend->fixed_multiexponentiation(res, h->curve_id, *h->partition_table_accessor,
output_bit_table, output_lengths, num_outputs, scalars);
}
34 changes: 34 additions & 0 deletions cbindings/fixed_pedersen.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,40 @@ TEST_CASE("we can compute multi-exponentiations with a fixed set of generators")
REQUIRE(res[1] == generators[0]);
}

SECTION("we can compute a multiexponentiation of varying length") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_GPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);

uint8_t scalars[] = {0b1011, 0b1101};
unsigned bit_table[] = {3, 1};
unsigned lengths[] = {1, 2};
c21t::element_p3 res[2];
sxt_fixed_vlen_multiexponentiation(res, h.h, bit_table, lengths, 2, scalars);
REQUIRE(res[0] == 3 * generators[0]);
REQUIRE(res[1] == generators[0] + generators[1]);
}

SECTION("we can compute a multiexponentiation of varying length on the host") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_CPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);

uint8_t scalars[] = {0b1011, 0b1101};
unsigned bit_table[] = {3, 1};
unsigned lengths[] = {1, 2};
c21t::element_p3 res[2];
sxt_fixed_vlen_multiexponentiation(res, h.h, bit_table, lengths, 2, scalars);
REQUIRE(res[0] == 3 * generators[0]);
REQUIRE(res[1] == generators[0] + generators[1]);
}

SECTION("we can compute a multiexponentiation in packed form with three generators") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_GPU_BACKEND, 0};
Expand Down
18 changes: 18 additions & 0 deletions sxt/cbindings/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,24 @@ sxt_cc_component(
],
)

sxt_cc_component(
name = "computational_backend_utility",
impl_deps = [
"//sxt/base/error:assert",
"//sxt/base/num:divide_up",
],
test_deps = [
"//sxt/base/test:unit_test",
],
deps = [
"//sxt/base/container:span",
],
)

sxt_cc_component(
name = "gpu_backend",
impl_deps = [
":computational_backend_utility",
"//sxt/base/error:assert",
"//sxt/base/num:divide_up",
"//sxt/proof/transcript:transcript",
Expand Down Expand Up @@ -56,6 +71,7 @@ sxt_cc_component(
"//sxt/multiexp/curve:multiexponentiation",
"//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/multiexp/pippenger2:variable_length_multiexponentiation",
"//sxt/seqcommit/generator:precomputed_generators",
"//sxt/proof/inner_product:proof_descriptor",
"//sxt/proof/inner_product:proof_computation",
Expand All @@ -71,6 +87,7 @@ sxt_cc_component(
sxt_cc_component(
name = "cpu_backend",
impl_deps = [
":computational_backend_utility",
"//sxt/base/error:panic",
"//sxt/base/num:round_up",
"//sxt/cbindings/base:curve_id_utility",
Expand Down Expand Up @@ -103,6 +120,7 @@ sxt_cc_component(
"//sxt/memory/management:managed_array",
"//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/multiexp/pippenger2:variable_length_multiexponentiation",
"//sxt/ristretto/type:compressed_element",
"//sxt/ristretto/operation:compression",
"//sxt/multiexp/base:exponent_sequence",
Expand Down
6 changes: 6 additions & 0 deletions sxt/cbindings/backend/computational_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,11 @@ class computational_backend {
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table, unsigned num_outputs,
unsigned n, const uint8_t* scalars) const noexcept = 0;

virtual void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars) const noexcept = 0;
};
} // namespace sxt::cbnbck
51 changes: 51 additions & 0 deletions sxt/cbindings/backend/computational_backend_utility.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/cbindings/backend/computational_backend_utility.h"

#include <algorithm>

#include "sxt/base/error/assert.h"
#include "sxt/base/num/divide_up.h"

namespace sxt::cbnbck {
//--------------------------------------------------------------------------------------------------
// make_scalars_span
//--------------------------------------------------------------------------------------------------
basct::cspan<uint8_t> make_scalars_span(const uint8_t* data,
basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths) noexcept {
auto num_outputs = output_bit_table.size();
SXT_DEBUG_ASSERT(output_lengths.size() == num_outputs);

unsigned output_bit_sum = 0;
unsigned n = 0;
unsigned prev_len = 0;
for (unsigned output_index = 0; output_index < num_outputs; ++output_index) {
auto width = output_bit_table[output_index];
SXT_RELEASE_ASSERT(width > 0, "output bit width must be positive");
auto len = output_lengths[output_index];
SXT_RELEASE_ASSERT(len >= prev_len, "output lengths must be sorted in ascending order");

output_bit_sum += width;
n = std::max(n, len);
prev_len = len;
}

auto output_num_bytes = basn::divide_up(output_bit_sum, 8u);
return basct::cspan<uint8_t>{data, output_num_bytes * n};
}
} // namespace sxt::cbnbck
30 changes: 30 additions & 0 deletions sxt/cbindings/backend/computational_backend_utility.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cstdint>

#include "sxt/base/container/span.h"

namespace sxt::cbnbck {
//--------------------------------------------------------------------------------------------------
// make_scalars_span
//--------------------------------------------------------------------------------------------------
basct::cspan<uint8_t> make_scalars_span(const uint8_t* data,
basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths) noexcept;
} // namespace sxt::cbnbck
46 changes: 46 additions & 0 deletions sxt/cbindings/backend/computational_backend_utility.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/cbindings/backend/computational_backend_utility.h"

#include <vector>

#include "sxt/base/test/unit_test.h"

using namespace sxt;
using namespace sxt::cbnbck;

TEST_CASE("we can make a span for the referenced scalars") {
uint8_t data[16];

std::vector<unsigned> output_bit_table, output_lengths;

SECTION("we handle an output of length 1") {
output_bit_table = {1};
output_lengths = {1};
auto span = make_scalars_span(data, output_bit_table, output_lengths);
REQUIRE(span.size() == 1);
REQUIRE(span.data() == data);
}

SECTION("we handle multiple outputs") {
output_bit_table = {1, 8};
output_lengths = {1, 2};
auto span = make_scalars_span(data, output_bit_table, output_lengths);
REQUIRE(span.size() == 4);
REQUIRE(span.data() == data);
}
}
19 changes: 19 additions & 0 deletions sxt/cbindings/backend/cpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "sxt/base/error/assert.h"
#include "sxt/base/error/panic.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/cbindings/backend/computational_backend_utility.h"
#include "sxt/cbindings/base/curve_id_utility.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
Expand Down Expand Up @@ -51,6 +52,7 @@
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/variable_length_multiexponentiation.h"
#include "sxt/proof/inner_product/cpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
#include "sxt/proof/inner_product/proof_descriptor.h"
Expand Down Expand Up @@ -194,6 +196,23 @@ void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id
});
}

void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars) const noexcept {
cbnb::switch_curve_type(curve_id, [&]<class U, class T>(std::type_identity<U>,
std::type_identity<T>) noexcept {
basct::span<T> res_span{static_cast<T*>(res), num_outputs};
basct::cspan<unsigned> output_bit_table_span{output_bit_table, num_outputs};
basct::cspan<unsigned> output_lengths_span{output_lengths, num_outputs};
auto scalars_span = make_scalars_span(scalars, output_bit_table_span, output_lengths_span);
mtxpp2::multiexponentiate<T>(res_span,
static_cast<const mtxpp2::partition_table_accessor<U>&>(accessor),
output_bit_table_span, output_lengths_span, scalars_span);
});
}

//--------------------------------------------------------------------------------------------------
// get_cpu_backend
//--------------------------------------------------------------------------------------------------
Expand Down
6 changes: 6 additions & 0 deletions sxt/cbindings/backend/cpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class cpu_backend final : public computational_backend {
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table, unsigned num_outputs, unsigned n,
const uint8_t* scalars) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table, const unsigned* output_lengths,
unsigned num_outputs,
const uint8_t* scalars) const noexcept override;
};

//--------------------------------------------------------------------------------------------------
Expand Down
20 changes: 20 additions & 0 deletions sxt/cbindings/backend/gpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "sxt/base/error/assert.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/cbindings/backend/computational_backend_utility.h"
#include "sxt/cbindings/base/curve_id_utility.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
Expand Down Expand Up @@ -51,6 +52,7 @@
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/variable_length_multiexponentiation.h"
#include "sxt/proof/inner_product/gpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
#include "sxt/proof/inner_product/proof_descriptor.h"
Expand Down Expand Up @@ -235,6 +237,24 @@ void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id
});
}

void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table,
const unsigned* output_lengths, unsigned num_outputs,
const uint8_t* scalars) const noexcept {
cbnb::switch_curve_type(
curve_id, [&]<class U, class T>(std::type_identity<U>, std::type_identity<T>) noexcept {
basct::span<T> res_span{static_cast<T*>(res), num_outputs};
basct::cspan<unsigned> output_bit_table_span{output_bit_table, num_outputs};
basct::cspan<unsigned> output_lengths_span{output_lengths, num_outputs};
auto scalars_span = make_scalars_span(scalars, output_bit_table_span, output_lengths_span);
auto fut = mtxpp2::async_multiexponentiate<T>(
res_span, static_cast<const mtxpp2::partition_table_accessor<U>&>(accessor),
output_bit_table_span, output_lengths_span, scalars_span);
xens::get_scheduler().run();
});
}

//--------------------------------------------------------------------------------------------------
// get_gpu_backend
//--------------------------------------------------------------------------------------------------
Expand Down
6 changes: 6 additions & 0 deletions sxt/cbindings/backend/gpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ class gpu_backend final : public computational_backend {
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table, unsigned num_outputs, unsigned n,
const uint8_t* scalars) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
const unsigned* output_bit_table, const unsigned* output_lengths,
unsigned num_outputs,
const uint8_t* scalars) const noexcept override;
};

//--------------------------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 31ee6a6

Please sign in to comment.