From 31ee6a6f93359268b1e7636bb82805b07a5921f3 Mon Sep 17 00:00:00 2001 From: Ryan Burn Date: Fri, 16 Aug 2024 10:54:53 -0700 Subject: [PATCH] feat: add c bindings for variable length multi-exponentiations (PROOF-896) (#167) * add variable length multiexponentiation * add stub for c api * add c api stub * fill in c bridge * reformat * fill in tests * fill in backend * fill in cpu backend * refactor * fill in cbindings * add no output tests * fill in tests * doc * minor changes --- cbindings/blitzar_api.h | 29 +++++++++++ cbindings/fixed_pedersen.cc | 13 +++++ cbindings/fixed_pedersen.t.cc | 34 +++++++++++++ sxt/cbindings/backend/BUILD | 18 +++++++ sxt/cbindings/backend/computational_backend.h | 6 +++ .../backend/computational_backend_utility.cc | 51 +++++++++++++++++++ .../backend/computational_backend_utility.h | 30 +++++++++++ .../computational_backend_utility.t.cc | 46 +++++++++++++++++ sxt/cbindings/backend/cpu_backend.cc | 19 +++++++ sxt/cbindings/backend/cpu_backend.h | 6 +++ sxt/cbindings/backend/gpu_backend.cc | 20 ++++++++ sxt/cbindings/backend/gpu_backend.h | 6 +++ .../variable_length_multiexponentiation.h | 8 +++ .../variable_length_multiexponentiation.t.cc | 18 +++++++ 14 files changed, 304 insertions(+) create mode 100644 sxt/cbindings/backend/computational_backend_utility.cc create mode 100644 sxt/cbindings/backend/computational_backend_utility.h create mode 100644 sxt/cbindings/backend/computational_backend_utility.t.cc diff --git a/cbindings/blitzar_api.h b/cbindings/blitzar_api.h index 7fef6e0d..7e7f6acf 100644 --- a/cbindings/blitzar_api.h +++ b/cbindings/blitzar_api.h @@ -642,6 +642,35 @@ void sxt_fixed_packed_multiexponentiation(void* res, const struct sxt_multiexp_h const unsigned* output_bit_table, unsigned num_outputs, unsigned n, const uint8_t* scalars); +/** + * Compute a varying lengthing multiexponentiation of scalars in packed format using a handle to + * pre-specified generators. + * + * On completion `res` contains an array of size `num_outputs` for the multiexponentiation + * of the given `scalars` array. + * + * An entry output_bit_table[output_index] specifies the number of scalar bits used for + * output_index and output_lengths[output_index] specifies the length used for output_index. + * + * Note: output_lengths must be sorted in ascending order + * + * Put + * bit_sum = sum_{output_index} output_bit_table[output_index] + * and let num_bytes denote the smallest integer greater than or equal to bit_sum that is a + * multiple of 8. + * + * Let n denote the length of the longest output. Then `scalars` specifies a contiguous + * multi-dimension `num_bytes` by `n` array laid out in a packed column-major order as specified by + * output_bit_table. A given row determines the scalar exponents for generator g_i with the output + * scalars packed contiguously and padded with zeros. + * + * Note: `res` must match the generator type of the curve. See `sxt_multiexp_handle_new` for + * the types. + */ +void sxt_fixed_vlen_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle, + const unsigned* output_bit_table, + const unsigned* output_lengths, unsigned num_outputs, + const uint8_t* scalars); #ifdef __cplusplus } // extern "C" #endif diff --git a/cbindings/fixed_pedersen.cc b/cbindings/fixed_pedersen.cc index 0ea93411..2ddb172a 100644 --- a/cbindings/fixed_pedersen.cc +++ b/cbindings/fixed_pedersen.cc @@ -66,3 +66,16 @@ void sxt_fixed_packed_multiexponentiation(void* res, const struct sxt_multiexp_h backend->fixed_multiexponentiation(res, h->curve_id, *h->partition_table_accessor, output_bit_table, num_outputs, n, scalars); } + +//-------------------------------------------------------------------------------------------------- +// sxt_fixed_vlen_multiexponentiation +//-------------------------------------------------------------------------------------------------- +void sxt_fixed_vlen_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle, + const unsigned* output_bit_table, + const unsigned* output_lengths, unsigned num_outputs, + const uint8_t* scalars) { + auto backend = cbn::get_backend(); + auto h = reinterpret_cast(handle); + backend->fixed_multiexponentiation(res, h->curve_id, *h->partition_table_accessor, + output_bit_table, output_lengths, num_outputs, scalars); +} diff --git a/cbindings/fixed_pedersen.t.cc b/cbindings/fixed_pedersen.t.cc index 11cdc0ec..f28b898a 100644 --- a/cbindings/fixed_pedersen.t.cc +++ b/cbindings/fixed_pedersen.t.cc @@ -91,6 +91,40 @@ TEST_CASE("we can compute multi-exponentiations with a fixed set of generators") REQUIRE(res[1] == generators[0]); } + SECTION("we can compute a multiexponentiation of varying length") { + cbn::reset_backend_for_testing(); + const sxt_config config = {SXT_GPU_BACKEND, 0}; + REQUIRE(sxt_init(&config) == 0); + + wrapped_handle h{generators.data(), 2}; + REQUIRE(h.h != nullptr); + + uint8_t scalars[] = {0b1011, 0b1101}; + unsigned bit_table[] = {3, 1}; + unsigned lengths[] = {1, 2}; + c21t::element_p3 res[2]; + sxt_fixed_vlen_multiexponentiation(res, h.h, bit_table, lengths, 2, scalars); + REQUIRE(res[0] == 3 * generators[0]); + REQUIRE(res[1] == generators[0] + generators[1]); + } + + SECTION("we can compute a multiexponentiation of varying length on the host") { + cbn::reset_backend_for_testing(); + const sxt_config config = {SXT_CPU_BACKEND, 0}; + REQUIRE(sxt_init(&config) == 0); + + wrapped_handle h{generators.data(), 2}; + REQUIRE(h.h != nullptr); + + uint8_t scalars[] = {0b1011, 0b1101}; + unsigned bit_table[] = {3, 1}; + unsigned lengths[] = {1, 2}; + c21t::element_p3 res[2]; + sxt_fixed_vlen_multiexponentiation(res, h.h, bit_table, lengths, 2, scalars); + REQUIRE(res[0] == 3 * generators[0]); + REQUIRE(res[1] == generators[0] + generators[1]); + } + SECTION("we can compute a multiexponentiation in packed form with three generators") { cbn::reset_backend_for_testing(); const sxt_config config = {SXT_GPU_BACKEND, 0}; diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 93e2bda6..088dd24b 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -16,9 +16,24 @@ sxt_cc_component( ], ) +sxt_cc_component( + name = "computational_backend_utility", + impl_deps = [ + "//sxt/base/error:assert", + "//sxt/base/num:divide_up", + ], + test_deps = [ + "//sxt/base/test:unit_test", + ], + deps = [ + "//sxt/base/container:span", + ], +) + sxt_cc_component( name = "gpu_backend", impl_deps = [ + ":computational_backend_utility", "//sxt/base/error:assert", "//sxt/base/num:divide_up", "//sxt/proof/transcript:transcript", @@ -56,6 +71,7 @@ sxt_cc_component( "//sxt/multiexp/curve:multiexponentiation", "//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility", "//sxt/multiexp/pippenger2:multiexponentiation", + "//sxt/multiexp/pippenger2:variable_length_multiexponentiation", "//sxt/seqcommit/generator:precomputed_generators", "//sxt/proof/inner_product:proof_descriptor", "//sxt/proof/inner_product:proof_computation", @@ -71,6 +87,7 @@ sxt_cc_component( sxt_cc_component( name = "cpu_backend", impl_deps = [ + ":computational_backend_utility", "//sxt/base/error:panic", "//sxt/base/num:round_up", "//sxt/cbindings/base:curve_id_utility", @@ -103,6 +120,7 @@ sxt_cc_component( "//sxt/memory/management:managed_array", "//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility", "//sxt/multiexp/pippenger2:multiexponentiation", + "//sxt/multiexp/pippenger2:variable_length_multiexponentiation", "//sxt/ristretto/type:compressed_element", "//sxt/ristretto/operation:compression", "//sxt/multiexp/base:exponent_sequence", diff --git a/sxt/cbindings/backend/computational_backend.h b/sxt/cbindings/backend/computational_backend.h index 3668b07b..844f1f55 100644 --- a/sxt/cbindings/backend/computational_backend.h +++ b/sxt/cbindings/backend/computational_backend.h @@ -117,5 +117,11 @@ class computational_backend { const mtxpp2::partition_table_accessor_base& accessor, const unsigned* output_bit_table, unsigned num_outputs, unsigned n, const uint8_t* scalars) const noexcept = 0; + + virtual void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id, + const mtxpp2::partition_table_accessor_base& accessor, + const unsigned* output_bit_table, + const unsigned* output_lengths, unsigned num_outputs, + const uint8_t* scalars) const noexcept = 0; }; } // namespace sxt::cbnbck diff --git a/sxt/cbindings/backend/computational_backend_utility.cc b/sxt/cbindings/backend/computational_backend_utility.cc new file mode 100644 index 00000000..ab19e08f --- /dev/null +++ b/sxt/cbindings/backend/computational_backend_utility.cc @@ -0,0 +1,51 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/cbindings/backend/computational_backend_utility.h" + +#include + +#include "sxt/base/error/assert.h" +#include "sxt/base/num/divide_up.h" + +namespace sxt::cbnbck { +//-------------------------------------------------------------------------------------------------- +// make_scalars_span +//-------------------------------------------------------------------------------------------------- +basct::cspan make_scalars_span(const uint8_t* data, + basct::cspan output_bit_table, + basct::cspan output_lengths) noexcept { + auto num_outputs = output_bit_table.size(); + SXT_DEBUG_ASSERT(output_lengths.size() == num_outputs); + + unsigned output_bit_sum = 0; + unsigned n = 0; + unsigned prev_len = 0; + for (unsigned output_index = 0; output_index < num_outputs; ++output_index) { + auto width = output_bit_table[output_index]; + SXT_RELEASE_ASSERT(width > 0, "output bit width must be positive"); + auto len = output_lengths[output_index]; + SXT_RELEASE_ASSERT(len >= prev_len, "output lengths must be sorted in ascending order"); + + output_bit_sum += width; + n = std::max(n, len); + prev_len = len; + } + + auto output_num_bytes = basn::divide_up(output_bit_sum, 8u); + return basct::cspan{data, output_num_bytes * n}; +} +} // namespace sxt::cbnbck diff --git a/sxt/cbindings/backend/computational_backend_utility.h b/sxt/cbindings/backend/computational_backend_utility.h new file mode 100644 index 00000000..865ca45a --- /dev/null +++ b/sxt/cbindings/backend/computational_backend_utility.h @@ -0,0 +1,30 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "sxt/base/container/span.h" + +namespace sxt::cbnbck { +//-------------------------------------------------------------------------------------------------- +// make_scalars_span +//-------------------------------------------------------------------------------------------------- +basct::cspan make_scalars_span(const uint8_t* data, + basct::cspan output_bit_table, + basct::cspan output_lengths) noexcept; +} // namespace sxt::cbnbck diff --git a/sxt/cbindings/backend/computational_backend_utility.t.cc b/sxt/cbindings/backend/computational_backend_utility.t.cc new file mode 100644 index 00000000..c60d1823 --- /dev/null +++ b/sxt/cbindings/backend/computational_backend_utility.t.cc @@ -0,0 +1,46 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/cbindings/backend/computational_backend_utility.h" + +#include + +#include "sxt/base/test/unit_test.h" + +using namespace sxt; +using namespace sxt::cbnbck; + +TEST_CASE("we can make a span for the referenced scalars") { + uint8_t data[16]; + + std::vector output_bit_table, output_lengths; + + SECTION("we handle an output of length 1") { + output_bit_table = {1}; + output_lengths = {1}; + auto span = make_scalars_span(data, output_bit_table, output_lengths); + REQUIRE(span.size() == 1); + REQUIRE(span.data() == data); + } + + SECTION("we handle multiple outputs") { + output_bit_table = {1, 8}; + output_lengths = {1, 2}; + auto span = make_scalars_span(data, output_bit_table, output_lengths); + REQUIRE(span.size() == 4); + REQUIRE(span.data() == data); + } +} diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 6c834169..3284f546 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -23,6 +23,7 @@ #include "sxt/base/error/assert.h" #include "sxt/base/error/panic.h" #include "sxt/base/num/divide_up.h" +#include "sxt/cbindings/backend/computational_backend_utility.h" #include "sxt/cbindings/base/curve_id_utility.h" #include "sxt/curve21/operation/add.h" #include "sxt/curve21/operation/double.h" @@ -51,6 +52,7 @@ #include "sxt/multiexp/curve/multiexponentiation.h" #include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h" #include "sxt/multiexp/pippenger2/multiexponentiation.h" +#include "sxt/multiexp/pippenger2/variable_length_multiexponentiation.h" #include "sxt/proof/inner_product/cpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" @@ -194,6 +196,23 @@ void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id }); } +void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id, + const mtxpp2::partition_table_accessor_base& accessor, + const unsigned* output_bit_table, + const unsigned* output_lengths, unsigned num_outputs, + const uint8_t* scalars) const noexcept { + cbnb::switch_curve_type(curve_id, [&](std::type_identity, + std::type_identity) noexcept { + basct::span res_span{static_cast(res), num_outputs}; + basct::cspan output_bit_table_span{output_bit_table, num_outputs}; + basct::cspan output_lengths_span{output_lengths, num_outputs}; + auto scalars_span = make_scalars_span(scalars, output_bit_table_span, output_lengths_span); + mtxpp2::multiexponentiate(res_span, + static_cast&>(accessor), + output_bit_table_span, output_lengths_span, scalars_span); + }); +} + //-------------------------------------------------------------------------------------------------- // get_cpu_backend //-------------------------------------------------------------------------------------------------- diff --git a/sxt/cbindings/backend/cpu_backend.h b/sxt/cbindings/backend/cpu_backend.h index bd96882f..1993bc11 100644 --- a/sxt/cbindings/backend/cpu_backend.h +++ b/sxt/cbindings/backend/cpu_backend.h @@ -72,6 +72,12 @@ class cpu_backend final : public computational_backend { const mtxpp2::partition_table_accessor_base& accessor, const unsigned* output_bit_table, unsigned num_outputs, unsigned n, const uint8_t* scalars) const noexcept override; + + void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id, + const mtxpp2::partition_table_accessor_base& accessor, + const unsigned* output_bit_table, const unsigned* output_lengths, + unsigned num_outputs, + const uint8_t* scalars) const noexcept override; }; //-------------------------------------------------------------------------------------------------- diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index eab008bd..dbd5bfc6 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -21,6 +21,7 @@ #include "sxt/base/error/assert.h" #include "sxt/base/num/divide_up.h" +#include "sxt/cbindings/backend/computational_backend_utility.h" #include "sxt/cbindings/base/curve_id_utility.h" #include "sxt/curve21/operation/add.h" #include "sxt/curve21/operation/double.h" @@ -51,6 +52,7 @@ #include "sxt/multiexp/curve/multiexponentiation.h" #include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h" #include "sxt/multiexp/pippenger2/multiexponentiation.h" +#include "sxt/multiexp/pippenger2/variable_length_multiexponentiation.h" #include "sxt/proof/inner_product/gpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" @@ -235,6 +237,24 @@ void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id }); } +void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id, + const mtxpp2::partition_table_accessor_base& accessor, + const unsigned* output_bit_table, + const unsigned* output_lengths, unsigned num_outputs, + const uint8_t* scalars) const noexcept { + cbnb::switch_curve_type( + curve_id, [&](std::type_identity, std::type_identity) noexcept { + basct::span res_span{static_cast(res), num_outputs}; + basct::cspan output_bit_table_span{output_bit_table, num_outputs}; + basct::cspan output_lengths_span{output_lengths, num_outputs}; + auto scalars_span = make_scalars_span(scalars, output_bit_table_span, output_lengths_span); + auto fut = mtxpp2::async_multiexponentiate( + res_span, static_cast&>(accessor), + output_bit_table_span, output_lengths_span, scalars_span); + xens::get_scheduler().run(); + }); +} + //-------------------------------------------------------------------------------------------------- // get_gpu_backend //-------------------------------------------------------------------------------------------------- diff --git a/sxt/cbindings/backend/gpu_backend.h b/sxt/cbindings/backend/gpu_backend.h index 45f26717..b5da9a47 100644 --- a/sxt/cbindings/backend/gpu_backend.h +++ b/sxt/cbindings/backend/gpu_backend.h @@ -74,6 +74,12 @@ class gpu_backend final : public computational_backend { const mtxpp2::partition_table_accessor_base& accessor, const unsigned* output_bit_table, unsigned num_outputs, unsigned n, const uint8_t* scalars) const noexcept override; + + void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id, + const mtxpp2::partition_table_accessor_base& accessor, + const unsigned* output_bit_table, const unsigned* output_lengths, + unsigned num_outputs, + const uint8_t* scalars) const noexcept override; }; //-------------------------------------------------------------------------------------------------- diff --git a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h index 00c8dd68..69643b8b 100644 --- a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h +++ b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h @@ -160,11 +160,16 @@ multiexponentiate_impl(basct::span res, const partition_table_accessor& ac auto num_outputs = res.size(); auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u); auto num_output_bytes = basn::divide_up(num_products, 8); + if (num_outputs == 0) { + co_return; + } SXT_DEBUG_ASSERT( // clang-format off scalars.size() % num_output_bytes == 0 // clang-format on ); + + // compute products basdv::stream stream; memr::async_device_resource resource{stream}; memmg::managed_array products{num_products, &resource}; @@ -221,6 +226,9 @@ void multiexponentiate(basct::span res, const partition_table_accessor& ac auto num_outputs = res.size(); auto num_products = std::accumulate(output_bit_table.begin(), output_bit_table.end(), 0u); auto num_output_bytes = basn::divide_up(num_products, 8); + if (num_outputs == 0) { + return; + } auto n = scalars.size() / num_output_bytes; SXT_DEBUG_ASSERT( // clang-format off diff --git a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.t.cc b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.t.cc index 4c4f3a1e..da4abf96 100644 --- a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.t.cc +++ b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.t.cc @@ -44,6 +44,24 @@ TEST_CASE("we can compute multiexponentiations with varying lengths") { std::vector output_bit_table(1); std::vector output_lengths(1); + SECTION("we handle no outputs") { + res.clear(); + output_bit_table.clear(); + output_lengths.clear(); + scalars.clear(); + auto fut = + async_multiexponentiate(res, *accessor, output_bit_table, output_lengths, scalars); + REQUIRE(fut.ready()); + } + + SECTION("we handle no outputs on the host") { + res.clear(); + output_bit_table.clear(); + output_lengths.clear(); + scalars.clear(); + multiexponentiate(res, *accessor, output_bit_table, output_lengths, scalars); + } + SECTION("we can compute a multiexponentiation of length zero") { output_bit_table[0] = 1; output_lengths[0] = 0;