Gradient estimators
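One of the most important elements of an optimization algorithm is computing the gradient of a physical observable's expectation value with respect to a set of circuit parameters:

$$\nabla_{\vec{\theta}} \langle O \rangle(\vec{\theta}) = \left( \frac{\partial \langle O \rangle(\vec{\theta})}{\partial \theta_1}, \ldots, \frac{\partial \langle O \rangle(\vec{\theta})}{\partial \theta_n} \right), \qquad \langle O \rangle(\vec{\theta}) = \langle \psi(\vec{\theta}) | O | \psi(\vec{\theta}) \rangle.$$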
In this tutorial, we introduce the gradient estimators provided by QURI Parts. They are:
- Numerical gradient estimator: a gradient estimator based on the finite difference method.
- Parameter shift gradient estimator: a gradient estimator based on the parameter shift rule.
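To make the difference between the two approaches concrete, here is a toy sketch in plain Python, not QURI Parts code. It assumes a single qubit prepared as RY(θ)|0⟩ with RY(θ) = exp(-iθY/2), for which ⟨Z⟩ = cos θ and the exact derivative is -sin θ. The finite difference result is only an approximation whose quality depends on the step size. The two-term parameter shift rule, which holds for gates of the form exp(-iθP/2) with P a Pauli operator, is exact up to floating-point rounding.

```python
import math


def expval_z(theta):
    # Toy model: <Z> after RY(theta)|0>, assuming RY(theta) = exp(-i theta Y / 2).
    return math.cos(theta)


def finite_difference(f, x, delta=1e-6):
    # Central difference: evaluate f at x +/- delta/2 and divide by delta.
    return (f(x + delta / 2) - f(x - delta / 2)) / delta


def parameter_shift(f, x):
    # Two-term parameter shift rule for a gate exp(-i x P / 2), P a Pauli operator.
    return (f(x + math.pi / 2) - f(x - math.pi / 2)) / 2


theta = 0.3
print("finite difference:", finite_difference(expval_z, theta))
print("parameter shift:  ", parameter_shift(expval_z, theta))
print("exact -sin(theta):", -math.sin(theta))
```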
Prerequisite
QURI Parts modules used in this tutorial: quri-parts-circuit, quri-parts-core, and quri-parts-qulacs. You can install them as follows:
!pip install "quri-parts[qulacs]"
Interface
A gradient estimator is represented by the GradientEstimator interface. It represents a function that estimates the gradient of the expectation value of a given Operator for a given parametric state at given parameter values (the third argument). Its function signature is:
from typing import Callable, Sequence, Union
from typing_extensions import TypeAlias, TypeVar
from quri_parts.core.estimator import Estimatable, Estimates
from quri_parts.core.state import ParametricCircuitQuantumState, ParametricQuantumStateVector
# Generic type of parametric states
_ParametricStateT = TypeVar(
"_ParametricStateT",
bound=Union[ParametricCircuitQuantumState, ParametricQuantumStateVector],
)
# Function signature of a `GradientEstimator` defined in QURI Parts.
GradientEstimator: TypeAlias = Callable[
[Estimatable, _ParametricStateT, Sequence[float]],
Estimates[complex],
]
You can create a GradientEstimator from a generating function; such functions are usually named create_..._gradient_estimator. To create a GradientEstimator, you need to pass a ConcurrentParametricQuantumEstimator to the generating function. Here, we use the one provided by quri_parts.qulacs:
from quri_parts.qulacs.estimator import create_qulacs_vector_concurrent_parametric_estimator
concurrent_parametric_estimator = create_qulacs_vector_concurrent_parametric_estimator()
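The generating-function pattern can be illustrated without QURI Parts. The sketch below is a toy: ToyEstimates, toy_concurrent_estimator and create_toy_numerical_gradient_estimator are hypothetical names, the "state" is just a Python function of the parameters, and the gradient uses a central difference with shifts of ±delta/2, which is an assumption for illustration rather than a copy of the library code. It shows why a concurrent estimator is a natural input: all 2n shifted parameter vectors are evaluated in a single call.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence

# Hypothetical stand-ins, not QURI Parts classes.
ToyState = Callable[[Sequence[float]], float]  # maps parameters to <O>
ToyConcurrentEstimator = Callable[[object, ToyState, Sequence[Sequence[float]]], List[complex]]


@dataclass
class ToyEstimates:
    values: List[complex]  # one gradient component per parameter


def toy_concurrent_estimator(operator: object, state: ToyState,
                             params_list: Sequence[Sequence[float]]) -> List[complex]:
    # Evaluates the expectation value for every parameter vector in one call.
    # The operator is ignored because the toy state already returns <O>.
    return [complex(state(params)) for params in params_list]


def create_toy_numerical_gradient_estimator(estimator: ToyConcurrentEstimator, delta: float):
    def gradient_estimator(operator: object, state: ToyState,
                           params: Sequence[float]) -> ToyEstimates:
        shifted = []
        for i in range(len(params)):
            plus, minus = list(params), list(params)
            plus[i] += delta / 2
            minus[i] -= delta / 2
            shifted += [plus, minus]
        values = estimator(operator, state, shifted)  # a single batched call
        return ToyEstimates([(values[2 * i] - values[2 * i + 1]) / delta
                             for i in range(len(params))])

    return gradient_estimator


def toy_state(params: Sequence[float]) -> float:
    return math.cos(params[0]) * math.sin(params[1])


grad_estimator = create_toy_numerical_gradient_estimator(toy_concurrent_estimator, delta=1e-6)
print(grad_estimator(None, toy_state, [0.1, 0.2]).values)
```

The returned closure has the same three positional arguments as a GradientEstimator (operator, state, parameter values), which is what lets the factory hide the choice of estimator and step size from the caller.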
Preparation
Let's prepare the operator and the parametric state we use throughout this tutorial.
from quri_parts.core.operator import Operator, pauli_label
operator = Operator({
pauli_label("X0 Y1"): 0.5,
pauli_label("Z0 X1"): 0.2,
})
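The linear mapping of the parametric circuit is slightly different from previous sections. Here, the circuit parameters $(\theta, \phi)$ and the gate parameters $(\Theta_1, \Theta_2, \Theta_3)$ of the RX, RY and RZ gates are related via

$$\begin{pmatrix} \Theta_1 \\ \Theta_2 \\ \Theta_3 \end{pmatrix} = \begin{pmatrix} \pi/2 & \pi/3 \\ -\pi/2 & \pi/3 \\ \pi/3 & -\pi/2 \end{pmatrix} \begin{pmatrix} \theta \\ \phi \end{pmatrix} + \begin{pmatrix} \pi/2 \\ 0 \\ -\pi/2 \end{pmatrix}.$$

This particular mapping is chosen for aesthetic reasons, with the later discussion of the parameter shift rule in mind.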
import numpy as np
from quri_parts.circuit import LinearMappedUnboundParametricQuantumCircuit, CONST
from quri_parts.core.state import quantum_state
n_qubits = 2
linear_param_circuit = LinearMappedUnboundParametricQuantumCircuit(n_qubits)
theta, phi = linear_param_circuit.add_parameters("theta", "phi")
linear_param_circuit.add_H_gate(0)
linear_param_circuit.add_CNOT_gate(0, 1)
linear_param_circuit.add_ParametricRX_gate(0, {theta: np.pi/2, phi: np.pi/3, CONST: np.pi/2})
linear_param_circuit.add_ParametricRY_gate(0, {theta: -np.pi/2, phi: np.pi/3})
linear_param_circuit.add_ParametricRZ_gate(1, {theta: np.pi/3, phi: -np.pi/2, CONST: -np.pi/2})
param_state = quantum_state(n_qubits, circuit=linear_param_circuit)
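Because every gate angle in linear_param_circuit is an affine function of (θ, φ), the gradient with respect to the circuit parameters follows from the gradient with respect to the gate parameters via the chain rule, ∂⟨O⟩/∂θ_j = Σ_k (∂⟨O⟩/∂Θ_k)(∂Θ_k/∂θ_j), where θ_j ranges over θ and φ. The plain-Python sketch below copies the circuit's coefficients into a hypothetical JACOBIAN table and OFFSET list. At (θ, φ) = (0.1, 0.2) the three gate angles come out as 111°, 3° and -102°.

```python
import math

# Coefficients copied from linear_param_circuit: one row per gate (RX, RY, RZ),
# one column per circuit parameter (theta, phi). Names are hypothetical.
JACOBIAN = [
    [math.pi / 2, math.pi / 3],
    [-math.pi / 2, math.pi / 3],
    [math.pi / 3, -math.pi / 2],
]
OFFSET = [math.pi / 2, 0.0, -math.pi / 2]  # the CONST terms


def gate_params(circuit_params):
    # Theta_k = sum_j J[k][j] * theta_j + offset_k
    return [sum(c * p for c, p in zip(row, circuit_params)) + off
            for row, off in zip(JACOBIAN, OFFSET)]


def circuit_gradient(gate_gradient):
    # Chain rule: d<O>/d theta_j = sum_k d<O>/d Theta_k * J[k][j]
    return [sum(gate_gradient[k] * JACOBIAN[k][j] for k in range(len(JACOBIAN)))
            for j in range(len(JACOBIAN[0]))]


print([round(math.degrees(t), 6) for t in gate_params([0.1, 0.2])])  # [111.0, 3.0, -102.0]
```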
Numerical gradient estimator
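The numerical gradient estimator computes the gradient according to the finite difference method, i.e.

$$\frac{\partial \langle O \rangle(\vec{\theta})}{\partial \theta_i} \approx \frac{\langle O \rangle\left(\vec{\theta} + \frac{\delta}{2}\hat{e}_i\right) - \langle O \rangle\left(\vec{\theta} - \frac{\delta}{2}\hat{e}_i\right)}{\delta},$$

where $\hat{e}_i$ is the unit vector along the $i$-th parameter and $\delta$ is a small number we can freely set. Thus, to create a numerical gradient estimator, we need to pass in $\delta$ (the delta argument) along with the concurrent parametric estimator.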
from quri_parts.core.estimator.gradient import create_numerical_gradient_estimator
numerical_gradient_estimator = create_numerical_gradient_estimator(
concurrent_parametric_estimator,
delta=1e-10
)
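The choice of delta is a trade-off. Assuming a central difference with shifts of ±delta/2, the truncation error shrinks roughly like delta², while floating-point rounding error grows roughly like machine epsilon divided by delta. The plain-Python sketch below is a toy with f = cos, not the QURI Parts estimator; it prints the error for several step sizes so both effects are visible.

```python
import math


def central_difference(f, x, delta):
    # Shift by +/- delta/2 so that the two evaluation points are delta apart.
    return (f(x + delta / 2) - f(x - delta / 2)) / delta


x = 0.3
exact = -math.sin(x)  # derivative of cos at x
for delta in (1e-1, 1e-5, 1e-10, 1e-13):
    err = abs(central_difference(math.cos, x, delta) - exact)
    print(f"delta={delta:.0e}  error={err:.1e}")
```

For the smallest steps the printed error depends on rounding details of the platform, so treat those rows as indicative only.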
Now, we can estimate the gradient of the expectation value of operator for param_state at (θ, φ) = (0.1, 0.2):
numerical_gradient_estimator(operator, param_state, [0.1, 0.2]).values
[(-0.3508326962275987+0j), (0.5306499684110122+0j)]
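As an independent cross-check, the sketch below re-implements this two-qubit example with a tiny pure-Python state-vector simulator and a central difference. It is not how QURI Parts or Qulacs compute the result. It assumes the rotation convention R_P(t) = exp(-itP/2) and treats qubit 0 as the least significant bit of the state index; helper names such as apply_1q and prepare_state are hypothetical. Under these assumptions its output should agree with the values printed above to within about 1e-4.

```python
import cmath
import math

# Assumed conventions: R_P(t) = exp(-i t P / 2); qubit 0 is the least significant bit.
SQRT_HALF = 1 / math.sqrt(2)
H = [[SQRT_HALF, SQRT_HALF], [SQRT_HALF, -SQRT_HALF]]
X = [[0, 1], [1, 0]]
Y = [[0, -1j], [1j, 0]]
Z = [[1, 0], [0, -1]]


def rx(t):
    c, s = math.cos(t / 2), math.sin(t / 2)
    return [[c, -1j * s], [-1j * s, c]]


def ry(t):
    c, s = math.cos(t / 2), math.sin(t / 2)
    return [[c, -s], [s, c]]


def rz(t):
    return [[cmath.exp(-0.5j * t), 0], [0, cmath.exp(0.5j * t)]]


def apply_1q(state, gate, qubit):
    """Apply a 2x2 matrix to one qubit of a state vector."""
    out = list(state)
    for i in range(len(state)):
        if not (i >> qubit) & 1:
            j = i | (1 << qubit)
            out[i] = gate[0][0] * state[i] + gate[0][1] * state[j]
            out[j] = gate[1][0] * state[i] + gate[1][1] * state[j]
    return out


def apply_cnot(state, control, target):
    # Amplitudes whose control bit is set are swapped across the target bit.
    return [state[i ^ (1 << target)] if (i >> control) & 1 else state[i]
            for i in range(len(state))]


def prepare_state(theta, phi):
    # Same gate sequence and linear mapping as linear_param_circuit.
    state = [1 + 0j, 0j, 0j, 0j]
    state = apply_1q(state, H, 0)
    state = apply_cnot(state, 0, 1)
    state = apply_1q(state, rx(math.pi / 2 * theta + math.pi / 3 * phi + math.pi / 2), 0)
    state = apply_1q(state, ry(-math.pi / 2 * theta + math.pi / 3 * phi), 0)
    state = apply_1q(state, rz(math.pi / 3 * theta - math.pi / 2 * phi - math.pi / 2), 1)
    return state


def pauli_term(state, paulis):
    """<state| P |state> for a Pauli string given as {qubit: matrix}."""
    p_state = state
    for qubit, matrix in paulis.items():
        p_state = apply_1q(p_state, matrix, qubit)
    return sum(a.conjugate() * b for a, b in zip(state, p_state))


def expectation(params):
    # <O> for O = 0.5 X0 Y1 + 0.2 Z0 X1, the operator defined earlier.
    state = prepare_state(*params)
    value = 0.5 * pauli_term(state, {0: X, 1: Y}) + 0.2 * pauli_term(state, {0: Z, 1: X})
    return value.real


def central_difference_gradient(f, params, delta=1e-6):
    grad = []
    for i in range(len(params)):
        plus, minus = list(params), list(params)
        plus[i] += delta / 2
        minus[i] -= delta / 2
        grad.append((f(plus) - f(minus)) / delta)
    return grad


print(central_difference_gradient(expectation, [0.1, 0.2]))
```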
Parameter shift gradient estimator
The parameter shift rule was introduced in the paper cited below [1]. As a very quick review, we can write the parameter shift rule as: