Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More conservative caching in the CommutationChecker #13600

Merged
merged 4 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,29 @@ use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationF
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::OperationRef::{Gate as PyGateType, Operation as PyOperationType};
use qiskit_circuit::operations::{Operation, OperationRef, Param, StandardGate};
use qiskit_circuit::operations::{
get_standard_gate_names, Operation, OperationRef, Param, StandardGate,
};
use qiskit_circuit::{BitType, Clbit, Qubit};

use crate::unitary_compose;
use crate::QiskitError;

const TWOPI: f64 = 2.0 * std::f64::consts::PI;

// These gates do not commute with other gates, we do not check them.
static SKIPPED_NAMES: [&str; 4] = ["measure", "reset", "delay", "initialize"];
static NO_CACHE_NAMES: [&str; 2] = ["annotated", "linear_function"];

// We keep a hash-set of operations eligible for commutation checking. This is because checking
// eligibility is not for free.
static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
HashSet::from([
"rxx", "ryy", "rzz", "rzx", "h", "x", "y", "z", "sx", "sxdg", "t", "tdg", "s", "sdg", "cx",
"cy", "cz", "swap", "iswap", "ecr", "ccx", "cswap",
])
});

const TWOPI: f64 = 2.0 * std::f64::consts::PI;

// map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
// Map rotation gates to their generators, or to ``None`` if we cannot currently efficiently
// represent the generator in Rust and store the commutation relation in the commutation dictionary
static SUPPORTED_ROTATIONS: Lazy<HashMap<&str, Option<OperationRef>>> = Lazy::new(|| {
HashMap::from([
Expand Down Expand Up @@ -322,15 +327,17 @@ impl CommutationChecker {
(qargs1, qargs2)
};

let skip_cache: bool = NO_CACHE_NAMES.contains(&first_op.name()) ||
NO_CACHE_NAMES.contains(&second_op.name()) ||
// Skip params that do not evaluate to floats for caching and commutation library
first_params.iter().any(|p| !matches!(p, Param::Float(_))) ||
second_params.iter().any(|p| !matches!(p, Param::Float(_)))
&& !SUPPORTED_OP.contains(op1.name())
&& !SUPPORTED_OP.contains(op2.name());

if skip_cache {
// For our cache to work correctly, we require the gate's definition to only depend on the
// ``params`` attribute. This cannot be guaranteed for custom gates, so we only check
// the cache for our standard gates, which we know are defined by the ``params`` AND
// that the ``params`` are float-only at this point.
let whitelist = get_standard_gate_names();
let check_cache = whitelist.contains(&first_op.name())
&& whitelist.contains(&second_op.name())
&& first_params.iter().all(|p| matches!(p, Param::Float(_)))
&& second_params.iter().all(|p| matches!(p, Param::Float(_)));

if !check_cache {
return self.commute_matmul(
py,
first_op,
Expand Down Expand Up @@ -630,21 +637,24 @@ fn map_rotation<'a>(
) -> (&'a OperationRef<'a>, &'a [Param], bool) {
let name = op.name();
if let Some(generator) = SUPPORTED_ROTATIONS.get(name) {
// if the rotation angle is below the tolerance, the gate is assumed to
// If the rotation angle is below the tolerance, the gate is assumed to
// commute with everything, and we simply return the operation with the flag that
// it commutes trivially
// it commutes trivially.
if let Param::Float(angle) = params[0] {
if (angle % TWOPI).abs() < tol {
return (op, params, true);
};
};

// otherwise, we check if a generator is given -- if not, we'll just return the operation
// itself (e.g. RXX does not have a generator and is just stored in the commutations
// dictionary)
// Otherwise we need to cover two cases -- either a generator is given, in which case
// we return it, or we don't have a generator yet, but we know we have the operation
// stored in the commutation library. For example, RXX does not have a generator in Rust
// yet (PauliGate is not in Rust currently), but it is stored in the library, so we
// can strip the parameters and just return the gate.
if let Some(gate) = generator {
return (gate, &[], false);
};
return (op, &[], false);
}
(op, params, false)
}
Expand Down
5 changes: 5 additions & 0 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ static STANDARD_GATE_NAME: [&str; STANDARD_GATE_SIZE] = [
"rcccx", // 51 ("rc3x")
];

/// Get a slice of all standard gate names.
pub fn get_standard_gate_names() -> &'static [&'static str] {
&STANDARD_GATE_NAME
}

impl StandardGate {
pub fn create_py_op(
&self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
fixes:
- |
Commutation relations of :class:`~.circuit.Instruction`\ s with float-only ``params``
were eagerly cached by the :class:`.CommutationChecker`, using the ``params`` as key to
query the relation. This could lead to faulty results, if the instruction's definition
depended on additional information that just the :attr:`~.circuit.Instruction.params`
attribute, such as e.g. the case for :class:`.PauliEvolutionGate`.
This behavior is now fixed, and the commutation checker only conservatively caches
commutations for Qiskit-native standard gates. This can incur a performance cost if you were
relying on your custom gates being cached, however, we cannot guarantee safe caching for
custom gates, as they might rely on information beyond :attr:`~.circuit.Instruction.params`.
- |
Fixed a bug in the :class:`.CommmutationChecker`, where checking commutation of instruction
with non-numeric values in the :attr:`~.circuit.Instruction.params` attribute (such as the
:class:`.PauliGate`) could raise an error.
Fixed `#13570 <https://github.com/Qiskit/qiskit/issues/13570>`__.
77 changes: 65 additions & 12 deletions test/python/circuit/test_commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Parameter,
QuantumRegister,
Qubit,
QuantumCircuit,
)
from qiskit.circuit.commutation_library import SessionCommutationChecker as scc
from qiskit.circuit.library import (
Expand All @@ -37,9 +38,11 @@
CRYGate,
CRZGate,
CXGate,
CUGate,
LinearFunction,
MCXGate,
Measure,
PauliGate,
PhaseGate,
Reset,
RXGate,
Expand Down Expand Up @@ -82,6 +85,22 @@ def to_matrix(self):
return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex)


class MyEvilRXGate(Gate):
"""A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless)."""

def __init__(self, evil_input_not_in_param: float):
"""
Args:
evil_input_not_in_param: The RX rotation angle.
"""
self.value = evil_input_not_in_param
super().__init__("<evil laugh here>", 1, [])
Comment on lines +91 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love this. This is what I meant earlier by saying that you are taking the art of writing tests to a whole new level 😄.


def _define(self):
self.definition = QuantumCircuit(1)
self.definition.rx(self.value, 0)


@ddt
class TestCommutationChecker(QiskitTestCase):
"""Test CommutationChecker class."""
Expand Down Expand Up @@ -137,7 +156,7 @@ def test_standard_gates_commutations(self):
def test_caching_positive_results(self):
"""Check that hashing positive results in commutativity checker works as expected."""
scc.clear_cached_commutations()
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertTrue(scc.commute(ZGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
self.assertGreater(scc.num_cached_entries(), 0)

def test_caching_lookup_with_non_overlapping_qubits(self):
Expand All @@ -150,27 +169,29 @@ def test_caching_lookup_with_non_overlapping_qubits(self):
def test_caching_store_and_lookup_with_non_overlapping_qubits(self):
"""Check that commutations storing and lookup with non-overlapping qubits works as expected."""
scc_lenm = scc.num_cached_entries()
self.assertTrue(scc.commute(NewGateCX(), [0, 2], [], CXGate(), [0, 1], []))
self.assertFalse(scc.commute(NewGateCX(), [0, 1], [], CXGate(), [1, 2], []))
self.assertTrue(scc.commute(NewGateCX(), [1, 4], [], CXGate(), [1, 6], []))
self.assertFalse(scc.commute(NewGateCX(), [5, 3], [], CXGate(), [3, 1], []))
cx_like = CUGate(np.pi, 0, np.pi, 0)
self.assertTrue(scc.commute(cx_like, [0, 2], [], CXGate(), [0, 1], []))
self.assertFalse(scc.commute(cx_like, [0, 1], [], CXGate(), [1, 2], []))
self.assertTrue(scc.commute(cx_like, [1, 4], [], CXGate(), [1, 6], []))
self.assertFalse(scc.commute(cx_like, [5, 3], [], CXGate(), [3, 1], []))
Comment on lines -153 to +176
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please tell me if I understand correctly. NewGateCX is not a standard gate, so with this PR we will no longer cache its commutation relations. And cx_like is a standard gate (a CUGate). This is why it makes sense to update the tests to reason about cx_like instead of NewGateCX, correct? However, the scc.commute method should still be able to handle NewGateCX. So maybe it makes sense to leave a test that deals with custom gates. Do we still have such a test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct to all of the above 👍🏻

We don't have a test using NewGateCX concretely, only with the EvilRXGate, which should test the same thing. I can add a test with NewGateCX though just to be safe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, having an additional test for a benign custom gate makes sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 44457f7 👍🏻

self.assertEqual(scc.num_cached_entries(), scc_lenm + 2)

def test_caching_negative_results(self):
"""Check that hashing negative results in commutativity checker works as expected."""
scc.clear_cached_commutations()
self.assertFalse(scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertFalse(scc.commute(XGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], []))
self.assertGreater(scc.num_cached_entries(), 0)

def test_caching_different_qubit_sets(self):
"""Check that hashing same commutativity results over different qubit sets works as expected."""
scc.clear_cached_commutations()
# All the following should be cached in the same way
# though each relation gets cached twice: (A, B) and (B, A)
scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], [])
scc.commute(XGate(), [10], [], NewGateCX(), [10, 20], [])
scc.commute(XGate(), [10], [], NewGateCX(), [10, 5], [])
scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], [])
cx_like = CUGate(np.pi, 0, np.pi, 0)
scc.commute(XGate(), [0], [], cx_like, [0, 1], [])
scc.commute(XGate(), [10], [], cx_like, [10, 20], [])
scc.commute(XGate(), [10], [], cx_like, [10, 5], [])
scc.commute(XGate(), [5], [], cx_like, [5, 7], [])
self.assertEqual(scc.num_cached_entries(), 1)

def test_zero_rotations(self):
Expand Down Expand Up @@ -377,12 +398,14 @@ def test_serialization(self):
"""Test that the commutation checker is correctly serialized"""
import pickle

cx_like = CUGate(np.pi, 0, np.pi, 0)

scc.clear_cached_commutations()
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], []))
self.assertTrue(scc.commute(ZGate(), [0], [], cx_like, [0, 1], []))
cc2 = pickle.loads(pickle.dumps(scc))
self.assertEqual(cc2.num_cached_entries(), 1)
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
dop2 = DAGOpNode(NewGateCX(), qargs=[0, 1], cargs=[])
dop2 = DAGOpNode(cx_like, qargs=[0, 1], cargs=[])
cc2.commute_nodes(dop1, dop2)
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[])
dop2 = DAGOpNode(CXGate(), qargs=[0, 1], cargs=[])
Expand Down Expand Up @@ -430,6 +453,36 @@ def test_rotation_mod_2pi(self, gate_cls):
scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), [])
)

def test_custom_gate(self):
"""Test a custom gate."""
my_cx = NewGateCX()

self.assertTrue(scc.commute(my_cx, [0, 1], [], XGate(), [1], []))
self.assertFalse(scc.commute(my_cx, [0, 1], [], XGate(), [0], []))
self.assertTrue(scc.commute(my_cx, [0, 1], [], ZGate(), [0], []))

self.assertFalse(scc.commute(my_cx, [0, 1], [], my_cx, [1, 0], []))
self.assertTrue(scc.commute(my_cx, [0, 1], [], my_cx, [0, 1], []))

def test_custom_gate_caching(self):
"""Test a custom gate is correctly handled on consecutive runs."""

all_commuter = MyEvilRXGate(0) # this will commute with anything
some_rx = MyEvilRXGate(1.6192) # this should not commute with H

# the order here is important: we're testing whether the gate that commutes with
# everything is used after the first commutation check, regardless of the internal
# gate parameters
self.assertTrue(scc.commute(all_commuter, [0], [], HGate(), [0], []))
self.assertFalse(scc.commute(some_rx, [0], [], HGate(), [0], []))

def test_nonfloat_param(self):
"""Test commutation-checking on a gate that has non-float ``params``."""
pauli_gate = PauliGate("XX")
rx_gate_theta = RXGate(Parameter("Theta"))
self.assertTrue(scc.commute(pauli_gate, [0, 1], [], rx_gate_theta, [0], []))
self.assertTrue(scc.commute(rx_gate_theta, [0], [], pauli_gate, [0, 1], []))

alexanderivrii marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
unittest.main()
Loading