-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More conservative caching in the CommutationChecker
#13600
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
--- | ||
fixes: | ||
- | | ||
Commutation relations of :class:`~.circuit.Instruction`\ s with float-only ``params`` | ||
were eagerly cached by the :class:`.CommutationChecker`, using the ``params`` as key to | ||
query the relation. This could lead to faulty results, if the instruction's definition | ||
depended on additional information that just the :attr:`~.circuit.Instruction.params` | ||
attribute, such as e.g. the case for :class:`.PauliEvolutionGate`. | ||
This behavior is now fixed, and the commutation checker only conservatively caches | ||
commutations for Qiskit-native standard gates. This can incur a performance cost if you were | ||
relying on your custom gates being cached, however, we cannot guarantee safe caching for | ||
custom gates, as they might rely on information beyond :attr:`~.circuit.Instruction.params`. | ||
- | | ||
Fixed a bug in the :class:`.CommmutationChecker`, where checking commutation of instruction | ||
with non-numeric values in the :attr:`~.circuit.Instruction.params` attribute (such as the | ||
:class:`.PauliGate`) could raise an error. | ||
Fixed `#13570 <https://github.com/Qiskit/qiskit/issues/13570>`__. | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
Parameter, | ||
QuantumRegister, | ||
Qubit, | ||
QuantumCircuit, | ||
) | ||
from qiskit.circuit.commutation_library import SessionCommutationChecker as scc | ||
from qiskit.circuit.library import ( | ||
|
@@ -37,9 +38,11 @@ | |
CRYGate, | ||
CRZGate, | ||
CXGate, | ||
CUGate, | ||
LinearFunction, | ||
MCXGate, | ||
Measure, | ||
PauliGate, | ||
PhaseGate, | ||
Reset, | ||
RXGate, | ||
|
@@ -82,6 +85,22 @@ def to_matrix(self): | |
return np.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=complex) | ||
|
||
|
||
class MyEvilRXGate(Gate): | ||
"""A RX gate designed to annoy the caching mechanism (but a realistic gate nevertheless).""" | ||
|
||
def __init__(self, evil_input_not_in_param: float): | ||
""" | ||
Args: | ||
evil_input_not_in_param: The RX rotation angle. | ||
""" | ||
self.value = evil_input_not_in_param | ||
super().__init__("<evil laugh here>", 1, []) | ||
|
||
def _define(self): | ||
self.definition = QuantumCircuit(1) | ||
self.definition.rx(self.value, 0) | ||
|
||
|
||
@ddt | ||
class TestCommutationChecker(QiskitTestCase): | ||
"""Test CommutationChecker class.""" | ||
|
@@ -137,7 +156,7 @@ def test_standard_gates_commutations(self): | |
def test_caching_positive_results(self): | ||
"""Check that hashing positive results in commutativity checker works as expected.""" | ||
scc.clear_cached_commutations() | ||
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], [])) | ||
self.assertTrue(scc.commute(ZGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], [])) | ||
self.assertGreater(scc.num_cached_entries(), 0) | ||
|
||
def test_caching_lookup_with_non_overlapping_qubits(self): | ||
|
@@ -150,27 +169,29 @@ def test_caching_lookup_with_non_overlapping_qubits(self): | |
def test_caching_store_and_lookup_with_non_overlapping_qubits(self): | ||
"""Check that commutations storing and lookup with non-overlapping qubits works as expected.""" | ||
scc_lenm = scc.num_cached_entries() | ||
self.assertTrue(scc.commute(NewGateCX(), [0, 2], [], CXGate(), [0, 1], [])) | ||
self.assertFalse(scc.commute(NewGateCX(), [0, 1], [], CXGate(), [1, 2], [])) | ||
self.assertTrue(scc.commute(NewGateCX(), [1, 4], [], CXGate(), [1, 6], [])) | ||
self.assertFalse(scc.commute(NewGateCX(), [5, 3], [], CXGate(), [3, 1], [])) | ||
cx_like = CUGate(np.pi, 0, np.pi, 0) | ||
self.assertTrue(scc.commute(cx_like, [0, 2], [], CXGate(), [0, 1], [])) | ||
self.assertFalse(scc.commute(cx_like, [0, 1], [], CXGate(), [1, 2], [])) | ||
self.assertTrue(scc.commute(cx_like, [1, 4], [], CXGate(), [1, 6], [])) | ||
self.assertFalse(scc.commute(cx_like, [5, 3], [], CXGate(), [3, 1], [])) | ||
Comment on lines
-153
to
+176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please tell me if I understand correctly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct to all of the above 👍🏻 We don't have a test using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, having an additional test for a benign custom gate makes sense. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added in 44457f7 👍🏻 |
||
self.assertEqual(scc.num_cached_entries(), scc_lenm + 2) | ||
|
||
def test_caching_negative_results(self): | ||
"""Check that hashing negative results in commutativity checker works as expected.""" | ||
scc.clear_cached_commutations() | ||
self.assertFalse(scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], [])) | ||
self.assertFalse(scc.commute(XGate(), [0], [], CUGate(1, 2, 3, 0), [0, 1], [])) | ||
self.assertGreater(scc.num_cached_entries(), 0) | ||
|
||
def test_caching_different_qubit_sets(self): | ||
"""Check that hashing same commutativity results over different qubit sets works as expected.""" | ||
scc.clear_cached_commutations() | ||
# All the following should be cached in the same way | ||
# though each relation gets cached twice: (A, B) and (B, A) | ||
scc.commute(XGate(), [0], [], NewGateCX(), [0, 1], []) | ||
scc.commute(XGate(), [10], [], NewGateCX(), [10, 20], []) | ||
scc.commute(XGate(), [10], [], NewGateCX(), [10, 5], []) | ||
scc.commute(XGate(), [5], [], NewGateCX(), [5, 7], []) | ||
cx_like = CUGate(np.pi, 0, np.pi, 0) | ||
scc.commute(XGate(), [0], [], cx_like, [0, 1], []) | ||
scc.commute(XGate(), [10], [], cx_like, [10, 20], []) | ||
scc.commute(XGate(), [10], [], cx_like, [10, 5], []) | ||
scc.commute(XGate(), [5], [], cx_like, [5, 7], []) | ||
self.assertEqual(scc.num_cached_entries(), 1) | ||
|
||
def test_zero_rotations(self): | ||
|
@@ -377,12 +398,14 @@ def test_serialization(self): | |
"""Test that the commutation checker is correctly serialized""" | ||
import pickle | ||
|
||
cx_like = CUGate(np.pi, 0, np.pi, 0) | ||
|
||
scc.clear_cached_commutations() | ||
self.assertTrue(scc.commute(ZGate(), [0], [], NewGateCX(), [0, 1], [])) | ||
self.assertTrue(scc.commute(ZGate(), [0], [], cx_like, [0, 1], [])) | ||
cc2 = pickle.loads(pickle.dumps(scc)) | ||
self.assertEqual(cc2.num_cached_entries(), 1) | ||
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[]) | ||
dop2 = DAGOpNode(NewGateCX(), qargs=[0, 1], cargs=[]) | ||
dop2 = DAGOpNode(cx_like, qargs=[0, 1], cargs=[]) | ||
cc2.commute_nodes(dop1, dop2) | ||
dop1 = DAGOpNode(ZGate(), qargs=[0], cargs=[]) | ||
dop2 = DAGOpNode(CXGate(), qargs=[0, 1], cargs=[]) | ||
|
@@ -430,6 +453,36 @@ def test_rotation_mod_2pi(self, gate_cls): | |
scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), []) | ||
) | ||
|
||
def test_custom_gate(self): | ||
"""Test a custom gate.""" | ||
my_cx = NewGateCX() | ||
|
||
self.assertTrue(scc.commute(my_cx, [0, 1], [], XGate(), [1], [])) | ||
self.assertFalse(scc.commute(my_cx, [0, 1], [], XGate(), [0], [])) | ||
self.assertTrue(scc.commute(my_cx, [0, 1], [], ZGate(), [0], [])) | ||
|
||
self.assertFalse(scc.commute(my_cx, [0, 1], [], my_cx, [1, 0], [])) | ||
self.assertTrue(scc.commute(my_cx, [0, 1], [], my_cx, [0, 1], [])) | ||
|
||
def test_custom_gate_caching(self): | ||
"""Test a custom gate is correctly handled on consecutive runs.""" | ||
|
||
all_commuter = MyEvilRXGate(0) # this will commute with anything | ||
some_rx = MyEvilRXGate(1.6192) # this should not commute with H | ||
|
||
# the order here is important: we're testing whether the gate that commutes with | ||
# everything is used after the first commutation check, regardless of the internal | ||
# gate parameters | ||
self.assertTrue(scc.commute(all_commuter, [0], [], HGate(), [0], [])) | ||
self.assertFalse(scc.commute(some_rx, [0], [], HGate(), [0], [])) | ||
|
||
def test_nonfloat_param(self): | ||
"""Test commutation-checking on a gate that has non-float ``params``.""" | ||
pauli_gate = PauliGate("XX") | ||
rx_gate_theta = RXGate(Parameter("Theta")) | ||
self.assertTrue(scc.commute(pauli_gate, [0, 1], [], rx_gate_theta, [0], [])) | ||
self.assertTrue(scc.commute(rx_gate_theta, [0], [], pauli_gate, [0, 1], [])) | ||
|
||
alexanderivrii marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I love this. This is what I meant earlier by saying that you are taking the art of writing tests to a whole new level 😄.