From 988d6042ca60837b9fc933d2f2b7f23cd7b7c3c0 Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Fri, 11 Oct 2024 17:22:28 +0000 Subject: [PATCH] Update Stim backend to support conditionals and mid-circuit measurements Signed-off-by: Ben Howe --- python/tests/kernel/test_kernel_features.py | 10 +- runtime/nvqir/stim/StimCircuitSimulator.cpp | 145 ++++++++++++++---- targettests/execution/qir_cond_for_break.cpp | 1 + targettests/execution/qir_cond_for_loop-1.cpp | 1 + targettests/execution/qir_cond_for_loop-2.cpp | 1 + targettests/execution/qir_cond_for_loop-3.cpp | 1 + targettests/execution/qir_cond_for_loop-4.cpp | 1 + targettests/execution/qir_cond_for_loop-5.cpp | 1 + targettests/execution/qir_cond_for_loop-6.cpp | 1 + targettests/execution/qir_simple_cond-1.cpp | 1 + unittests/integration/builder_tester.cpp | 24 ++- 11 files changed, 156 insertions(+), 31 deletions(-) diff --git a/python/tests/kernel/test_kernel_features.py b/python/tests/kernel/test_kernel_features.py index b6267cb75d1..ae6d12e20c5 100644 --- a/python/tests/kernel/test_kernel_features.py +++ b/python/tests/kernel/test_kernel_features.py @@ -263,10 +263,15 @@ def kernel(theta: float): assert np.isclose(want_exp, -1.13, atol=1e-2) -def test_dynamic_circuit(): +@pytest.mark.parametrize('target', ['default', 'stim']) +def test_dynamic_circuit(target): """Test that we correctly sample circuits with mid-circuit measurements and conditionals.""" + if target == 'stim': + save_target = cudaq.get_target() + cudaq.set_target('stim') + @cudaq.kernel def simple(): q = cudaq.qvector(2) @@ -297,6 +302,9 @@ def simple(): assert '0' in c0 and '1' in c0 assert '00' in counts and '11' in counts + if target == 'stim': + cudaq.set_target(save_target) + def test_teleport(): diff --git a/runtime/nvqir/stim/StimCircuitSimulator.cpp b/runtime/nvqir/stim/StimCircuitSimulator.cpp index f95cc41cf94..47265554c07 100644 --- a/runtime/nvqir/stim/StimCircuitSimulator.cpp +++ b/runtime/nvqir/stim/StimCircuitSimulator.cpp @@ -24,12 +24,35 @@ namespace nvqir { /// https://github.com/quantumlib/Stim. class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { protected: - stim::Circuit stimCircuit; + // Follow Stim naming convention (W) for bit width (required for templates). + static constexpr std::size_t W = stim::MAX_BITWORD_WIDTH; + + /// @brief Number of measurements performed so far. + std::size_t num_measurements = 0; + + /// @brief Top-level random engine. Stim simulator RNGs are based off of this + /// engine. std::mt19937_64 randomEngine; + /// @brief Stim Tableau simulator (noiseless) + std::unique_ptr> tableau; + + /// @brief Stim Frame/Flip simulator (used to generate multiple shots) + std::unique_ptr> sampleSim; + /// @brief Grow the state vector by one qubit. void addQubitToState() override { addQubitsToState(1); } + /// @brief Get the batch size to use for the Stim sample simulator. + std::size_t getBatchSize() { + // Default to single shot + std::size_t batch_size = 1; + if (getExecutionContext() && getExecutionContext()->name == "sample" && + !getExecutionContext()->hasConditionalsOnMeasureResults) + batch_size = getExecutionContext()->shots; + return batch_size; + } + /// @brief Override the default sized allocation of qubits /// here to be a bit more efficient than the default implementation void addQubitsToState(std::size_t qubitCount, @@ -37,11 +60,52 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { if (stateDataIn) throw std::runtime_error("The Stim simulator does not support " "initialization of qubits from state data."); - return; + + if (!tableau) { + cudaq::info("Creating new Stim Tableau simulator"); + // Bump the randomEngine before cloning and giving to the Tableau + // simulator. + randomEngine.discard( + std::uniform_int_distribution(1, 30)(randomEngine)); + tableau = std::make_unique>( + std::mt19937_64(randomEngine), /*num_qubits=*/0, /*sign_bias=*/+0); + } + if (!sampleSim) { + auto batch_size = getBatchSize(); + cudaq::info("Creating new Stim frame simulator with batch size {}", + batch_size); + // Bump the randomEngine before cloning and giving to the sample + // simulator. + randomEngine.discard( + std::uniform_int_distribution(1, 30)(randomEngine)); + sampleSim = std::make_unique>( + stim::CircuitStats(), + stim::FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, batch_size, + std::mt19937_64(randomEngine)); + sampleSim->reset_all(); + } } /// @brief Reset the qubit state. - void deallocateStateImpl() override { stimCircuit.clear(); } + void deallocateStateImpl() override { + tableau.reset(); + // Update the randomEngine so that future invocations will use the updated + // RNG state. + if (sampleSim) + randomEngine = std::move(sampleSim->rng); + sampleSim.reset(); + num_measurements = 0; + } + + /// @brief Apply operation to all Stim simulators. + void applyOpToSims(const std::string &gate_name, + const std::vector &targets) { + stim::Circuit tempCircuit; + cudaq::info("Calling applyOpToSims {} - {}", gate_name, targets); + tempCircuit.safe_append_u(gate_name, targets); + tableau->safe_do_circuit(tempCircuit); + sampleSim->safe_do_circuit(tempCircuit); + } /// @brief Apply the noise channel on \p qubits void applyNoiseChannel(const std::string_view gateName, @@ -78,19 +142,21 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { cudaq::info("Applying {} kraus channels to qubits {}", krausChannels.size(), stimTargets); + stim::Circuit noiseOps; for (auto &channel : krausChannels) { if (channel.noise_type == cudaq::noise_model_type::bit_flip_channel) - stimCircuit.safe_append_ua("X_ERROR", stimTargets, - channel.parameters[0]); + noiseOps.safe_append_ua("X_ERROR", stimTargets, channel.parameters[0]); else if (channel.noise_type == cudaq::noise_model_type::phase_flip_channel) - stimCircuit.safe_append_ua("Z_ERROR", stimTargets, - channel.parameters[0]); + noiseOps.safe_append_ua("Z_ERROR", stimTargets, channel.parameters[0]); else if (channel.noise_type == cudaq::noise_model_type::depolarization_channel) - stimCircuit.safe_append_ua("DEPOLARIZE1", stimTargets, - channel.parameters[0]); + noiseOps.safe_append_ua("DEPOLARIZE1", stimTargets, + channel.parameters[0]); } + // Only apply the noise operations to the sample simulator (not the Tableau + // simulator). + sampleSim->safe_do_circuit(noiseOps); } void applyGate(const GateApplicationTask &task) override { @@ -119,7 +185,7 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { for (auto t : task.targets) stimTargets.push_back(t); try { - stimCircuit.safe_append_u(gateName, stimTargets); + applyOpToSims(gateName, stimTargets); } catch (std::out_of_range &e) { throw std::runtime_error( fmt::format("Gate not supported by Stim simulator: {}. Note that " @@ -137,14 +203,31 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { return 0; } - /// @brief Measure the qubit and return the result. Collapse the - /// state vector. - bool measureQubit(const std::size_t index) override { return false; } + /// @brief Measure the qubit and return the result. + bool measureQubit(const std::size_t index) override { + // Perform measurement + applyOpToSims( + "M", std::vector{static_cast(index)}); + num_measurements++; + + // Get the tableau bit that was just generated. + const std::vector &v = tableau->measurement_record.storage; + const bool tableauBit = *v.crbegin(); + + // Get the mid-circuit sample to be XOR-ed with tableauBit. + bool sampleSimBit = + sampleSim->m_record.storage[num_measurements - 1][/*shot=*/0]; + + // Calculate the result. + bool result = tableauBit ^ sampleSimBit; + + return result; + } QubitOrdering getQubitOrdering() const override { return QubitOrdering::msb; } public: - StimCircuitSimulator() { + StimCircuitSimulator() : randomEngine(std::random_device{}()) { // Populate the correct name so it is printed correctly during // deconstructor. summaryData.name = name(); @@ -161,26 +244,33 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { /// @param index 0-based index of qubit to reset void resetQubit(const std::size_t index) override { flushGateQueue(); - stimCircuit.safe_append_u( + applyOpToSims( "R", std::vector{static_cast(index)}); } /// @brief Sample the multi-qubit state. cudaq::ExecutionResult sample(const std::vector &qubits, const int shots) override { + assert(shots <= sampleSim->batch_size); std::vector stimTargetQubits(qubits.begin(), qubits.end()); - stimCircuit.safe_append_u("M", stimTargetQubits); - if (false) { - std::stringstream ss; - ss << stimCircuit << '\n'; - cudaq::log("Stim circuit is\n{}", ss.str()); + applyOpToSims("M", stimTargetQubits); + num_measurements += stimTargetQubits.size(); + + // Generate a reference sample + const std::vector &v = tableau->measurement_record.storage; + stim::simd_bits ref(v.size()); + for (size_t k = 0; k < v.size(); k++) + ref[k] ^= v[k]; + + // Now XOR results on a per-shot basis + stim::simd_bit_table sample = sampleSim->m_record.storage; + auto nShots = sampleSim->batch_size; + if (ref.not_zero()) { + sample = stim::transposed_vs_ref(nShots, sample, ref); + sample = sample.transposed(); } - auto ref_sample = stim::TableauSimulator< - stim::MAX_BITWORD_WIDTH>::reference_sample_circuit(stimCircuit); - stim::simd_bit_table sample = - stim::sample_batch_measurements(stimCircuit, ref_sample, shots, - randomEngine, false); - size_t bits_per_sample = stimCircuit.count_measurements(); + + size_t bits_per_sample = num_measurements; std::vector sequentialData; // Only retain the final "qubits.size()" measurements. All other // measurements were mid-circuit measurements that have been previously @@ -190,9 +280,8 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { CountsDictionary counts; for (std::size_t shot = 0; shot < shots; shot++) { std::string aShot(qubits.size(), '0'); - for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++) { + for (std::size_t b = first_bit_to_save; b < bits_per_sample; b++) aShot[b - first_bit_to_save] = sample[b][shot] ? '1' : '0'; - } counts[aShot]++; sequentialData.push_back(std::move(aShot)); } diff --git a/targettests/execution/qir_cond_for_break.cpp b/targettests/execution/qir_cond_for_break.cpp index e89650e8dde..fa84f358c8c 100644 --- a/targettests/execution/qir_cond_for_break.cpp +++ b/targettests/execution/qir_cond_for_break.cpp @@ -7,6 +7,7 @@ ******************************************************************************/ // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t #include diff --git a/targettests/execution/qir_cond_for_loop-1.cpp b/targettests/execution/qir_cond_for_loop-1.cpp index f2751647738..3d98172db5e 100644 --- a/targettests/execution/qir_cond_for_loop-1.cpp +++ b/targettests/execution/qir_cond_for_loop-1.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_cond_for_loop-2.cpp b/targettests/execution/qir_cond_for_loop-2.cpp index be76c1f98ce..a8b393a58f1 100644 --- a/targettests/execution/qir_cond_for_loop-2.cpp +++ b/targettests/execution/qir_cond_for_loop-2.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_cond_for_loop-3.cpp b/targettests/execution/qir_cond_for_loop-3.cpp index a1a0b4d44f4..78a47ceb370 100644 --- a/targettests/execution/qir_cond_for_loop-3.cpp +++ b/targettests/execution/qir_cond_for_loop-3.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_cond_for_loop-4.cpp b/targettests/execution/qir_cond_for_loop-4.cpp index bec087d1fdf..4eb04970ed4 100644 --- a/targettests/execution/qir_cond_for_loop-4.cpp +++ b/targettests/execution/qir_cond_for_loop-4.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_cond_for_loop-5.cpp b/targettests/execution/qir_cond_for_loop-5.cpp index 641bcbcb9d6..91c441029b3 100644 --- a/targettests/execution/qir_cond_for_loop-5.cpp +++ b/targettests/execution/qir_cond_for_loop-5.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_cond_for_loop-6.cpp b/targettests/execution/qir_cond_for_loop-6.cpp index 0a7ec8ea953..d1c8325c5b2 100644 --- a/targettests/execution/qir_cond_for_loop-6.cpp +++ b/targettests/execution/qir_cond_for_loop-6.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: CUDAQ_DEFAULT_SIMULATOR=stim nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/targettests/execution/qir_simple_cond-1.cpp b/targettests/execution/qir_simple_cond-1.cpp index 79e8d5000c3..f08cfb52dc0 100644 --- a/targettests/execution/qir_simple_cond-1.cpp +++ b/targettests/execution/qir_simple_cond-1.cpp @@ -8,6 +8,7 @@ // clang-format off // RUN: nvq++ %cpp_std --target quantinuum --emulate %s -o %t && %t | FileCheck %s +// RUN: nvq++ %cpp_std --target stim --enable-mlir %s -o %t && %t | FileCheck %s // RUN: nvq++ -std=c++17 --enable-mlir %s -o %t // clang-format on diff --git a/unittests/integration/builder_tester.cpp b/unittests/integration/builder_tester.cpp index d4053441b88..6bf29634193 100644 --- a/unittests/integration/builder_tester.cpp +++ b/unittests/integration/builder_tester.cpp @@ -575,7 +575,7 @@ CUDAQ_TEST(BuilderTester, checkSwap) { // Conditional execution on the tensornet backend is slow for a large number of // shots. -#if !defined(CUDAQ_BACKEND_TENSORNET) && !defined(CUDAQ_BACKEND_STIM) +#if !defined(CUDAQ_BACKEND_TENSORNET) CUDAQ_TEST(BuilderTester, checkConditional) { { cudaq::set_random_seed(13); @@ -985,7 +985,6 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) { EXPECT_EQ(counts.count("0", "c1"), 1000); EXPECT_EQ(counts.count("1", "c0"), 1000); - return; } { @@ -1005,6 +1004,27 @@ CUDAQ_TEST(BuilderTester, checkMidCircuitMeasure) { EXPECT_EQ(counts.count("1", "hello2"), 0); EXPECT_EQ(counts.count("0", "hello2"), 1000); } + + { + // Force conditional sample + auto entryPoint = cudaq::make_kernel(); + auto q = entryPoint.qalloc(2); + entryPoint.h(q[0]); + auto mres = entryPoint.mz(q[0], "res0"); + entryPoint.c_if(mres, [&]() { entryPoint.x(q[1]); }); + entryPoint.mz(q, "final"); + + printf("%s\n", entryPoint.to_quake().c_str()); + auto counts = cudaq::sample(entryPoint); + counts.dump(); + + EXPECT_GT(counts.count("0", "res0"), 0); + EXPECT_GT(counts.count("1", "res0"), 0); + EXPECT_GT(counts.count("00", "final"), 0); + EXPECT_EQ(counts.count("01", "final"), 0); + EXPECT_EQ(counts.count("10", "final"), 0); + EXPECT_GT(counts.count("11", "final"), 0); + } } #endif