Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes Wires objects as wire labels bug #6933

Merged
merged 15 commits into from
Feb 11, 2025
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@

<h3>Internal changes ⚙️</h3>

* Fix `qml.wires.Wires` initialization to disallow `Wires` objects as wires labels.
JerryChen97 marked this conversation as resolved.
Show resolved Hide resolved
[(#6933)](https://github.com/PennyLaneAI/pennylane/pull/6933)

* Remove `QNode.get_gradient_fn` from source code.
[(#6898)](https://github.com/PennyLaneAI/pennylane/pull/6898)

Expand Down
12 changes: 10 additions & 2 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _process(wires):
if len(set_of_wires) != len(tuple_of_wires):
raise WireError(f"Wires must be unique; got {wires}.")

return tuple_of_wires
return tuple(itertools.chain(*(_flatten_wires_object(x) for x in tuple_of_wires)))
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved


class Wires(Sequence):
Expand All @@ -120,7 +120,7 @@ class Wires(Sequence):
"""

def _flatten(self):
"""Serialize Wires into a flattened representation according to the PyTree convension."""
"""Serialize Wires into a flattened representation according to the PyTree convention."""
return self._labels, ()

@classmethod
Expand Down Expand Up @@ -731,5 +731,13 @@ def __rxor__(self, other):

WiresLike = Union[Wires, Iterable[Hashable], Hashable]


def _flatten_wires_object(wire_label):
"""Converts the input to a tuple of wire labels."""
if isinstance(wire_label, Wires):
return wire_label.labels
return [wire_label]


# Register Wires as a PyTree-serializable class
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access
11 changes: 8 additions & 3 deletions tests/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
class TestWires:
"""Tests for the ``Wires`` class."""

def test_wires_object_as_label(self):
"""Tests that a Wires object can be used as a label for another Wires object."""
assert Wires([0, 1]) == Wires([Wires([0]), Wires([1])])
assert Wires(["a", "b", 1]) == Wires([Wires(["a", "b"]), Wires([1])])

def test_error_if_wires_none(self):
"""Tests that a TypeError is raised if None is given as wires."""
with pytest.raises(TypeError, match="Must specify a set of wires."):
Expand Down Expand Up @@ -74,7 +79,7 @@ def test_creation_from_wires_lists(self):
"""Tests that a Wires object can be created from a list of Wires."""

wires = Wires([Wires([0]), Wires([1]), Wires([2])])
assert wires.labels == (Wires([0]), Wires([1]), Wires([2]))
assert wires.labels == (0, 1, 2)

@pytest.mark.parametrize(
"iterable", [[1, 0, 4], ["a", "b", "c"], [0, 1, None], ["a", 1, "ancilla"]]
Expand Down Expand Up @@ -148,7 +153,7 @@ def test_contains(
wires = Wires([0, 1, 2, 3, Wires([4, 5]), None])

assert 0 in wires
assert Wires([4, 5]) in wires
assert Wires([4, 5]) not in wires
assert None in wires
assert Wires([1]) not in wires
assert Wires([0, 3]) not in wires
Expand All @@ -170,7 +175,7 @@ def test_contains_wires(

assert not wires.contains_wires(0) # wrong type
assert not wires.contains_wires([0, 1]) # wrong type
assert not wires.contains_wires(
assert wires.contains_wires(
Wires([4, 5])
) # looks up 4 and 5 in wires, which are not present

Expand Down