diff --git a/demonstrations/tutorial_eqnn_force_field.metadata.json b/demonstrations/tutorial_eqnn_force_field.metadata.json
index 99dd1bcc61..21043e02ae 100644
--- a/demonstrations/tutorial_eqnn_force_field.metadata.json
+++ b/demonstrations/tutorial_eqnn_force_field.metadata.json
@@ -9,7 +9,7 @@
         }
     ],
     "dateOfPublication": "2024-03-12T00:00:00+00:00",
-    "dateOfLastModification": "2024-11-06T00:00:00+00:00",
+    "dateOfLastModification": "2024-11-25T00:00:00+00:00",
     "categories": [
         "Quantum Machine Learning",
         "Quantum Chemistry"
diff --git a/demonstrations/tutorial_eqnn_force_field.py b/demonstrations/tutorial_eqnn_force_field.py
index 45373bf0b4..6f58512d73 100644
--- a/demonstrations/tutorial_eqnn_force_field.py
+++ b/demonstrations/tutorial_eqnn_force_field.py
@@ -143,6 +143,10 @@ import matplotlib.pyplot as plt
 import sklearn
 
 
+######################################################################
+# To speed up the computation, we also import `Catalyst <https://docs.pennylane.ai/projects/catalyst>`_, a just-in-time (JIT) compiler for PennyLane quantum programs.
+import catalyst
+
 ######################################################################
 # Let us construct Pauli matrices, which are used to build the Hamiltonian.
 X = np.array([[0, 1], [1, 0]])
@@ -301,10 +305,12 @@ def noise_layer(epsilon, wires):
 #################################
 
 
-dev = qml.device("default.qubit", wires=num_qubits)
+######################################################################
+# We will run our program on ``lightning.qubit``, our performant state-vector simulator.
+dev = qml.device("lightning.qubit", wires=num_qubits)
 
 
-@qml.qnode(dev, interface="jax")
+@qml.qnode(dev)
 def vqlm(data, params):
 
     weights = params["params"]["weights"]
@@ -396,25 +402,27 @@ def vqlm(data, params):
 )
 
 #################################
-# We will know define the cost function and how to train the model using Jax. We will use the mean-square-error loss function.
-# To speed up the computation, we use the decorator ``@jax.jit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the
-# benefit that all following executions will be significantly faster, see the `Jax docs on jitting <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_.
+# We will now define the cost function and how to train the model using Jax. We will use the mean-squared-error loss function.
+# We use the decorator ``@catalyst.qjit`` to perform just-in-time compilation. This means the first execution will typically take a little longer, with the
+# benefit that all following executions will be significantly faster (see the `Catalyst documentation <https://docs.pennylane.ai/projects/catalyst>`_).
 #################################
 
 from jax.example_libraries import optimizers
 
 # We vectorize the model over the data points
-vec_vqlm = jax.vmap(vqlm, (0, None), 0)
+vec_vqlm = catalyst.vmap(
+    vqlm,
+    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
+    out_axes=0,
+)
 
 
 # Mean-squared-error loss function
-@jax.jit
 def mse_loss(predictions, targets):
     return jnp.mean(0.5 * (predictions - targets) ** 2)
 
 
 # Make prediction and compute the loss
-@jax.jit
 def cost(weights, loss_data):
     data, E_target, F_target = loss_data
     E_pred = vec_vqlm(data, weights)
@@ -424,17 +432,19 @@ def cost(weights, loss_data):
 
 
 # Perform one training step
-@jax.jit
+# This function is called many times during training, so we qjit it: the compilation cost is paid once and amortized over all subsequent calls.
+@catalyst.qjit
 def train_step(step_i, opt_state, loss_data):
 
     net_params = get_params(opt_state)
-    loss, grads = jax.value_and_grad(cost, argnums=0)(net_params, loss_data)
-
+    loss = cost(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
     return loss, opt_update(step_i, grads, opt_state)
 
 
 # Return prediction and loss at inference times, e.g. for testing
-@jax.jit
+# This function is also called repeatedly, so we qjit it as well.
+@catalyst.qjit
 def inference(loss_data, opt_state):
 
     data, E_target, F_target = loss_data
@@ -475,11 +485,12 @@ def inference(loss_data, opt_state):
 
 # We train our VQLM using stochastic gradient descent.
 
-num_batches = 5000  # number of optimization steps
-batch_size = 256  # number of training data per batch
+num_batches = 200  # number of optimization steps (reduced from 5000)
+batch_size = 5  # number of training data per batch (reduced from 256)
 
 
 for ibatch in range(num_batches):
+    # print(ibatch)
 
     # select a batch of training points
     batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
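
For readers unfamiliar with the Catalyst calls introduced above, the snippet below is a minimal, self-contained sketch of the same pattern: ``catalyst.vmap`` to batch a QNode over data, ``@catalyst.qjit`` to compile a step that is called many times, and ``catalyst.grad(..., method="fd")`` for finite-difference gradients. The toy circuit, the names ``circuit``, ``batched_circuit``, ``loss`` and ``loss_and_grad``, the dummy data, and the step size ``h=1e-7`` are illustrative assumptions and are not code from the demo.

    import pennylane as qml
    import jax.numpy as jnp
    import catalyst

    # Same performant state-vector simulator as in the diff.
    dev = qml.device("lightning.qubit", wires=2)

    @qml.qnode(dev)
    def circuit(x, w):
        # Toy circuit: encode the input, apply one trainable rotation.
        qml.RX(x, wires=0)
        qml.RY(w, wires=1)
        qml.CNOT(wires=[0, 1])
        return qml.expval(qml.PauliZ(1))

    # Vectorize over a batch of inputs while broadcasting the single weight,
    # mirroring vec_vqlm = catalyst.vmap(vqlm, in_axes=(0, ...), out_axes=0).
    batched_circuit = catalyst.vmap(circuit, in_axes=(0, None), out_axes=0)

    def loss(w, xs, targets):
        # Mean-squared-error loss over the batch of predictions.
        preds = batched_circuit(xs, w)
        return jnp.mean(0.5 * (preds - targets) ** 2)

    # qjit the whole step so that repeated calls reuse the compiled program;
    # the gradient is taken by finite differences, as in train_step above.
    @catalyst.qjit
    def loss_and_grad(w, xs, targets):
        value = loss(w, xs, targets)
        grad = catalyst.grad(loss, method="fd", h=1e-7, argnums=0)(w, xs, targets)
        return value, grad

    # Dummy data, purely for illustration.
    xs = jnp.linspace(0.0, jnp.pi, 8)
    targets = jnp.zeros(8)
    print(loss_and_grad(0.3, xs, targets))

Since ``method="fd"`` estimates gradients by finite differences, the result depends on the step size ``h``; the value used in the sketch is only a placeholder.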