diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 23b1f938555..a6af21f7fb6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -189,6 +189,9 @@ * The `qml.clifford_t_decomposition` has been improved to use less gates when decomposing `qml.PhaseShift`. [(#6842)](https://github.com/PennyLaneAI/pennylane/pull/6842) + +* `qml.qchem.taper` now handles wire ordering for the tapered observables more robustly. + [(#6954)](https://github.com/PennyLaneAI/pennylane/pull/6954) * A `ParametrizedMidMeasure` class is added to represent a mid-circuit measurement in an arbitrary measurement basis in the XY, YZ or ZX plane. diff --git a/pennylane/qchem/tapering.py b/pennylane/qchem/tapering.py index 3029de2b555..d9df538874c 100644 --- a/pennylane/qchem/tapering.py +++ b/pennylane/qchem/tapering.py @@ -293,44 +293,40 @@ def _taper_pauli_sentence(ps_h, generators, paulixops, paulix_sector): for ps in _split_pauli_sentence(ps_h, max_size=PAULI_SENTENCE_MEMORY_SPLITTING_SIZE): ts_ps += ps_u @ ps @ ps_u # helps restrict the peak memory usage for u @ h @ u - wireset = ps_u.wires + ps_h.wires + wireset = ps_h.wires + ps_u.wires wiremap = dict(zip(list(wireset.toset()), range(len(wireset) + 1))) paulix_wires = [x.wires[0] for x in paulixops] - o = [] - val = np.ones(len(ts_ps)) - - wires_tap = [i for i in ts_ps.wires if i not in paulix_wires] - wiremap_tap = dict(zip(wires_tap, range(len(wires_tap) + 1))) - - for i, pw_coeff in enumerate(ts_ps.items()): - pw, _ = pw_coeff + wires_tap = [i for i in wiremap.keys() if i not in paulix_wires] + wires_ord = list(range(len(wires_tap))) + wiremap_tap = dict(zip(wires_tap, wires_ord)) + obs, val = [], qml.math.ones(len(ts_ps)) + for i, pw in enumerate(ts_ps.keys()): for idx, w in enumerate(paulix_wires): if pw[w] == "X": val[i] *= paulix_sector[idx] - o.append( - qml.pauli.string_to_pauli_word( - "".join([pw[wiremap[i]] for i in wires_tap]), - wire_map=wiremap_tap, + obs.append( + qml.pauli.PauliWord({wiremap_tap[wire]: pw[wire] for wire in wires_tap}).operation( + wire_order=wires_ord ) ) - c = qml.math.stack(qml.math.multiply(val * complex(1.0), list(ts_ps.values()))) - - tapered_ham = qml.simplify(qml.dot(c, o)) - # If simplified Hamiltonian is missing wires, then add wires manually for consistency - if set(wires_tap) != tapered_ham.wires.toset(): - identity_op = functools.reduce( - lambda i, j: i @ j, - [ - qml.Identity(wire) - for wire in Wires.unique_wires([tapered_ham.wires, Wires(wires_tap)]) - ], - ) + interface = qml.math.get_deep_interface(list(ps_h.values())) + coeffs = qml.math.multiply(val, qml.math.array(list(ts_ps.values()), like=interface)) + + if interface == "jax" and qml.math.is_abstract(coeffs): + tapered_ham = qml.sum(*(qml.s_prod(coeff, op) for coeff, op in zip(coeffs, obs))) + else: + if qml.math.all(qml.math.abs(qml.math.imag(coeffs)) <= 1e-8): + coeffs = qml.math.real(coeffs) + tapered_ham = qml.simplify(0.0 * qml.Identity(wires=wires_ord) + qml.dot(coeffs, obs)) - return tapered_ham + (0.0 * identity_op) + # If simplified Hamiltonian is missing wires due to simplification, + # then add wires manually for consistency + if set(wires_ord) != tapered_ham.wires.toset(): + return 0.0 * qml.Identity(wires=wires_ord) + tapered_ham return tapered_ham @@ -481,7 +477,7 @@ def taper_hf(generators, paulixops, paulix_sector, num_electrons, num_wires): # taper the HF observable using the symmetries obtained from the molecular hamiltonian fermop_taper = _taper_pauli_sentence(ferm_ps, generators, paulixops, paulix_sector) fermop_ps = pauli_sentence(fermop_taper) - fermop_mat = _binary_matrix_from_pws(list(fermop_ps), len(fermop_taper.wires)) + fermop_mat = _binary_matrix_from_pws(list(fermop_ps), len(fermop_ps.wires)) # build a wireset to match wires with that of the tapered Hamiltonian gen_wires = Wires.all_wires([generator.wires for generator in generators]) @@ -611,7 +607,8 @@ def _build_generator(operation, wire_order, op_gen=None): return op_gen -# pylint: disable=too-many-branches, too-many-arguments, inconsistent-return-statements, no-member +# pylint: disable=too-many-branches, too-many-arguments, too-many-positional-arguments +# pylint: disable=inconsistent-return-statements, no-member def taper_operation( operation, generators, paulixops, paulix_sector, wire_order, op_wires=None, op_gen=None ): diff --git a/tests/qchem/test_tapering.py b/tests/qchem/test_tapering.py index 05388b27060..512bd969f8f 100644 --- a/tests/qchem/test_tapering.py +++ b/tests/qchem/test_tapering.py @@ -989,3 +989,62 @@ def test_split_pauli_sentence(ps_size, max_size): split_sentence = {**split_sentence, **ps} assert sentence == qml.pauli.PauliSentence(split_sentence) + + +@pytest.mark.parametrize( + ("symbols", "geometry"), + [(["Li", "H"], np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 3.13]]))], +) +def test_taper_wire_order(symbols, geometry): + r"""Test that a tapering workflow results in correct order of wires.""" + + molecule = qml.qchem.Molecule(symbols, geometry) + hamiltonian, num_wires = qml.qchem.molecular_hamiltonian(molecule) + + generators = qml.symmetry_generators(hamiltonian) + paulixops = qml.paulix_ops(generators, num_wires) + paulix_sector = optimal_sector(hamiltonian, generators, molecule.n_electrons) + + tapered_ham = qml.taper(hamiltonian, generators, paulixops, paulix_sector) + assert tapered_ham.wires.tolist() == list(sorted(tapered_ham.wires)) + + +@pytest.mark.jax +@pytest.mark.parametrize( + ("symbols", "geometry", "charge"), + [ + ( + ["H", "H"], + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.40104295]], requires_grad=True), + 0, + ), + ( + ["He", "H"], + np.array( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4588684632]], + requires_grad=True, + ), + 1, + ), + ], +) +def test_taper_jax_jit(symbols, geometry, charge): + r"""Test that an observable can be tapred within a jax-jit workflow.""" + + import jax + + molecule = qml.qchem.Molecule(symbols, jax.numpy.array(geometry), charge) + hamiltonian, num_wires = qml.qchem.molecular_hamiltonian(molecule) + + generators = qml.symmetry_generators(hamiltonian) + paulixops = qml.paulix_ops(generators, num_wires) + paulix_sector = tuple(optimal_sector(hamiltonian, generators, molecule.n_electrons)) + + tapered_ham1 = qml.simplify(qml.taper(hamiltonian, generators, paulixops, paulix_sector)) + tapered_ham2 = qml.simplify( + jax.jit(qml.taper, static_argnums=[3])(hamiltonian, generators, paulixops, paulix_sector) + ) + + assert qml.math.get_deep_interface(tapered_ham1.terms()[0]) == "jax" + assert qml.math.get_deep_interface(tapered_ham2.terms()[0]) == "jax" + qml.assert_equal(tapered_ham1, tapered_ham2)