Skip to content

Commit

Permalink
More conservative caching in the CommutationChecker (#13600)
Browse files Browse the repository at this point in the history
* conservative commutation check

* tests and reno

* reno in the right location

* more tests for custom gates

(cherry picked from commit 93d796f)
  • Loading branch information
Cryoris authored and mergify[bot] committed Jan 16, 2025
1 parent 5c6cfe1 commit 7ea1f39
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 31 deletions.
48 changes: 29 additions & 19 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,29 @@ use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationF
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType};
use qiskit_circuit::operations::{Operation, OperationRef, Param, StandardGate};
use qiskit_circuit::operations::{
get_standard_gate_names, Operation, OperationRef, Param, StandardGate,
};
use qiskit_circuit::{BitType, Clbit, Qubit};

use crate::unitary_compose;
use crate::QiskitError;

const TWOPI: f64 = 2.0 * std::f64::consts::PI;

// These gates do not commute with other gates, we do not check them.
static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"];
static NO_CACHE_NAMES: [&str; 2] = ["annotated", "linear_function"];

// We keep a hash-set of operations eligible for commutation checking. This is because checking
// eligibility is not for free.
static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
HashSet::from([
"rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx",
"cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap",
])
});

const TWOPI: f64 = 2.0 * std::f64::consts::PI;

// map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
// Map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
// represent the generator in Rust and store the commutation relation in the commutation dictionary
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
HashMap::from([
Expand Down Expand Up @@ -322,15 +327,17 @@ impl CommutationChecker {
(qargs1, qargs2)
};

let skip_cache: bool = NO_CACHE_NAMES.contains(&first_op.name()) ||
NO_CACHE_NAMES.contains(&second_op.name()) ||
// Skip params that do not evaluate to floats for caching and commutation library
first_params.iter().any(|p| !matches!(p, Param::Float(_))) ||
second_params.iter().any(|p| !matches!(p, Param::Float(_)))
&& !SUPPORTED_OP.contains(op1.name())
&& !SUPPORTED_OP.contains(op2.name());

if skip_cache {
// For our cache to work correctly, we require the gate's definition to only depend on the
// ``params`` attribute. This cannot be guaranteed for custom gates, so we only check
// the cache for our standard gates, which we know are defined by the ``params`` AND
// that the ``params`` are float-only at this point.
let whitelist = get_standard_gate_names();
let check_cache = whitelist.contains(&first_op.name())
&& whitelist.contains(&second_op.name())
&& first_params.iter().all(|p| matches!(p, Param::Float(_)))
&& second_params.iter().all(|p| matches!(p, Param::Float(_)));

if !check_cache {
return self.commute_matmul(
py,
first_op,
Expand Down Expand Up @@ -630,21 +637,24 @@ fn map_rotation<'a>(
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
let name = op.name();
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
// if the rotation angle is below the tolerance, the gate is assumed to
// If the rotation angle is below the tolerance, the gate is assumed to
// commute with everything, and we simply return the operation with the flag that
// it commutes trivially
// it commutes trivially.
if let Param::Float(angle) = params[0] {
if (angle % TWOPI).abs() < tol {
return (op, params, true);
};
};

// otherwise, we check if a generator is given -- if not, we'll just return the operation
// itself (e.g. RXX does not have a generator and is just stored in the commutations
// dictionary)
// Otherwise we need to cover two cases -- either a generator is given, in which case
// we return it, or we don't have a generator yet, but we know we have the operation
// stored in the commutation library. For example, RXX does not have a generator in Rust
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
// can strip the parameters and just return the gate.
if let Some(gate) = generator {
return (gate, &[], false);
};
return (op, &[], false);
}
(op, params, false)
}
Expand Down
5 changes: 5 additions & 0 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ static STANDARD_GATE_NAME: [&str; STANDARD_GATE_SIZE] = [
"rcccx", // 51 ("rc3x")
];

/// Get a slice of all standard gate names.
pub fn get_standard_gate_names() -> &'static [&'static str] {
&STANDARD_GATE_NAME
}

impl StandardGate {
pub fn create_py_op(
&self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
fixes:
- |
Commutation relations of :class:`~.circuit.Instruction`\ s with float-only ``params``
were eagerly cached by the :class:`.CommutationChecker`, using the ``params`` as key to
query the relation. This could lead to faulty results, if the instruction's definition
depended on additional information that just the :attr:`~.circuit.Instruction.params`
attribute, such as e.g. the case for :class:`.PauliEvolutionGate`.
This behavior is now fixed, and the commutation checker only conservatively caches
commutations for Qiskit-native standard gates. This can incur a performance cost if you were
relying on your custom gates being cached, however, we cannot guarantee safe caching for
custom gates, as they might rely on information beyond :attr:`~.circuit.Instruction.params`.
- |
Fixed a bug in the :class:`.CommmutationChecker`, where checking commutation of instruction
with non-numeric values in the :attr:`~.circuit.Instruction.params` attribute (such as the
:class:`.PauliGate`) could raise an error.
Fixed `#13570 <https://github.com/Qiskit/qiskit/issues/13570>`__.
77 changes: 65 additions & 12 deletions test/python/circuit/test_commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Parameter,
QuantumRegister,
Qubit,
QuantumCircuit,
)
from qiskit.circuit.commutation_library import SessionCommutationChecker as scc
from qiskit.circuit.library import (
Expand All @@ -37,9 +38,11 @@
CRYGate,
CRZGate,
CXGate,
CUGate,
LinearFunction,
MCXGate,
Measure,
PauliGate,
PhaseGate,
Reset,
RXGate,
Expand Down Expand Up @@ -82,6 +85,22 @@ def to_matrix(self):
return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex)


class MyEvilRXGate(Gate):
"""A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless)."""

def __init__(self, evil_input_not_in_param: float):
"""
Args:
evil_input_not_in_param: The RX rotation angle.
"""
self.value = evil_input_not_in_param
super().__init__("<evil laugh here>", 1, [])

def _define(self):
self.definition = QuantumCircuit(1)
self.definition.rx(self.value, 0)


@ddt
class TestCommutationChecker(QiskitTestCase):
"""Test CommutationChecker class."""
Expand Down Expand Up @@ -137,7 +156,7 @@ def test_standard_gates_commutations(self):
def test_caching_positive_results(self):
"""Check that hashing positive results in commutativity checker works as expected."""
scc.clear_cached_commutations()
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertTrue(scc.commute(ZGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
self.assertGreater(scc.num_cached_entries(), 0)

def test_caching_lookup_with_non_overlapping_qubits(self):
Expand All @@ -150,27 +169,29 @@ def test_caching_lookup_with_non_overlapping_qubits(self):
def test_caching_store_and_lookup_with_non_overlapping_qubits(self):
"""Check that commutations storing and lookup with non-overlapping qubits works as expected."""
scc_lenm = scc.num_cached_entries()
self.assertTrue(scc.commute(NewGateCX(), [0, 2], [], CXGate(), [0, 1], []))
self.assertFalse(scc.commute(NewGateCX(), [0, 1], [], CXGate(), [1, 2], []))
self.assertTrue(scc.commute(NewGateCX(), [1, 4], [], CXGate(), [1, 6], []))
self.assertFalse(scc.commute(NewGateCX(), [5, 3], [], CXGate(), [3, 1], []))
cx_like = CUGate(np.pi, 0, np.pi, 0)
self.assertTrue(scc.commute(cx_like, [0, 2], [], CXGate(), [0, 1], []))
self.assertFalse(scc.commute(cx_like, [0, 1], [], CXGate(), [1, 2], []))
self.assertTrue(scc.commute(cx_like, [1, 4], [], CXGate(), [1, 6], []))
self.assertFalse(scc.commute(cx_like, [5, 3], [], CXGate(), [3, 1], []))
self.assertEqual(scc.num_cached_entries(), scc_lenm + 2)

def test_caching_negative_results(self):
"""Check that hashing negative results in commutativity checker works as expected."""
scc.clear_cached_commutations()
self.assertFalse(scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertFalse(scc.commute(XGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
self.assertGreater(scc.num_cached_entries(), 0)

def test_caching_different_qubit_sets(self):
"""Check that hashing same commutativity results over different qubit sets works as expected."""
scc.clear_cached_commutations()
# All the following should be cached in the same way
# though each relation gets cached twice: (A, B) and (B, A)
scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], [])
scc.commute(XGate(), [10], [], NewGateCX(), [10, 20], [])
scc.commute(XGate(), [10], [], NewGateCX(), [10, 5], [])
scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], [])
cx_like = CUGate(np.pi, 0, np.pi, 0)
scc.commute(XGate(), [0], [], cx_like, [0, 1], [])
scc.commute(XGate(), [10], [], cx_like, [10, 20], [])
scc.commute(XGate(), [10], [], cx_like, [10, 5], [])
scc.commute(XGate(), [5], [], cx_like, [5, 7], [])
self.assertEqual(scc.num_cached_entries(), 1)

def test_zero_rotations(self):
Expand Down Expand Up @@ -377,12 +398,14 @@ def test_serialization(self):
"""Test that the commutation checker is correctly serialized"""
import pickle

cx_like = CUGate(np.pi, 0, np.pi, 0)

scc.clear_cached_commutations()
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertTrue(scc.commute(ZGate(), [0], [], cx_like, [0, 1], []))
cc2 = pickle.loads(pickle.dumps(scc))
self.assertEqual(cc2.num_cached_entries(), 1)
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
dop2 = DAGOpNode(NewGateCX(), qargs=[0, 1], cargs=[])
dop2 = DAGOpNode(cx_like, qargs=[0, 1], cargs=[])
cc2.commute_nodes(dop1, dop2)
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
dop2 = DAGOpNode(CXGate(), qargs=[0, 1], cargs=[])
Expand Down Expand Up @@ -430,6 +453,36 @@ def test_rotation_mod_2pi(self, gate_cls):
scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), [])
)

def test_custom_gate(self):
"""Test a custom gate."""
my_cx = NewGateCX()

self.assertTrue(scc.commute(my_cx, [0, 1], [], XGate(), [1], []))
self.assertFalse(scc.commute(my_cx, [0, 1], [], XGate(), [0], []))
self.assertTrue(scc.commute(my_cx, [0, 1], [], ZGate(), [0], []))

self.assertFalse(scc.commute(my_cx, [0, 1], [], my_cx, [1, 0], []))
self.assertTrue(scc.commute(my_cx, [0, 1], [], my_cx, [0, 1], []))

def test_custom_gate_caching(self):
"""Test a custom gate is correctly handled on consecutive runs."""

all_commuter = MyEvilRXGate(0) # this will commute with anything
some_rx = MyEvilRXGate(1.6192) # this should not commute with H

# the order here is important: we're testing whether the gate that commutes with
# everything is used after the first commutation check, regardless of the internal
# gate parameters
self.assertTrue(scc.commute(all_commuter, [0], [], HGate(), [0], []))
self.assertFalse(scc.commute(some_rx, [0], [], HGate(), [0], []))

def test_nonfloat_param(self):
"""Test commutation-checking on a gate that has non-float ``params``."""
pauli_gate = PauliGate("XX")
rx_gate_theta = RXGate(Parameter("Theta"))
self.assertTrue(scc.commute(pauli_gate, [0, 1], [], rx_gate_theta, [0], []))
self.assertTrue(scc.commute(rx_gate_theta, [0], [], pauli_gate, [0, 1], []))


if __name__ == "__main__":
unittest.main()

0 comments on commit 7ea1f39

Please sign in to comment.