Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved Action Space incl. Layouting and BQSKit Integration #204

Merged
merged 20 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Install MQT Predictor
run: pip install .[coverage]
- name: Generate Report
run: pytest -v --cov --cov-config=pyproject.toml --cov-report=xml --ignore=tests/compilation/test_pretrained_models.py
run: pytest -v --cov --cov-config=pyproject.toml --cov-report=xml --ignore=tests/test_pretrained_models.py
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload coverage to Codecov
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ dynamic = ["version"]

dependencies = [
"mqt.bench @ git+https://github.com/cda-tum/mqt-bench.git",
"sb3_contrib>=2.0.0",
"scikit-learn>=1.3.0, <1.3.3",
"sb3_contrib>=2.0.0, <2.2.2",
"scikit-learn>=1.4.0,<1.4.2",
"importlib_metadata>=4.4; python_version < '3.10'",
"importlib_resources>=5.0; python_version < '3.10'",
"tensorboard>=2.11.0"
"tensorboard>=2.11.0, <2.16.3",
"bqskit>=1.1.0, <1.1.2",
]

classifiers = [
Expand Down Expand Up @@ -99,7 +100,7 @@ implicit_reexport = true
# recent versions of `gym` are typed, but stable-baselines3 pins a very old version of gym.
# qiskit is not yet marked as typed, but is mostly typed.
# the other libraries do not have type stubs.
module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*"]
module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*"]
ignore_missing_imports = true

[tool.ruff]
Expand Down
2 changes: 1 addition & 1 deletion src/mqt/predictor/ml/Predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def generate_eval_all_datapoints(
tmp_res = scores_filtered_sorted_accordingly[i]
max_score = max(tmp_res)
for j in range(len(tmp_res)):
plt.plot(i, tmp_res[j] / max_score, "b.", alpha=1.0, markersize=1.7, color=color_all)
plt.plot(i, tmp_res[j] / max_score, alpha=1.0, markersize=1.7, color=color_all)

plt.plot(
i,
Expand Down
9 changes: 3 additions & 6 deletions src/mqt/predictor/rl/Predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, figure_of_merit: reward.figure_of_merit, device_name: str, lo
logger.setLevel(logger_level)

self.model = None
self.env = rl.PredictorEnv(figure_of_merit, device_name)
self.env = rl.PredictorEnv(reward_function=figure_of_merit, device_name=device_name)
self.device_name = device_name
self.figure_of_merit = figure_of_merit

Expand All @@ -50,7 +50,7 @@ def compile_as_predicted(
raise RuntimeError(msg) from e

assert self.model
obs, _ = self.env.reset(qc) # type: ignore[unreachable]
obs, _ = self.env.reset(qc, seed=0) # type: ignore[unreachable]

used_compilation_passes = []
terminated = False
Expand All @@ -62,7 +62,6 @@ def compile_as_predicted(
action_item = self.env.action_set[action]
used_compilation_passes.append(action_item["name"])
obs, reward_val, terminated, truncated, info = self.env.step(action)
self.env.state._layout = self.env.layout # noqa: SLF001

if not self.env.error_occured:
return self.env.state, used_compilation_passes
Expand Down Expand Up @@ -94,11 +93,9 @@ def train_model(
progress_bar = True

logger.debug("Start training for: " + self.figure_of_merit + " on " + self.device_name)
env = rl.PredictorEnv(reward_function=self.figure_of_merit, device_name=self.device_name)

model = MaskablePPO(
MaskableMultiInputActorCriticPolicy,
env,
self.env,
verbose=verbose,
tensorboard_log="./" + model_name + "_" + self.figure_of_merit + "_" + self.device_name,
gamma=0.98,
Expand Down
140 changes: 116 additions & 24 deletions src/mqt/predictor/rl/PredictorEnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from pathlib import Path

import numpy as np
from bqskit.ext import bqskit_to_qiskit, qiskit_to_bqskit
from gymnasium import Env
from gymnasium.spaces import Box, Dict, Discrete
from pytket.circuit import Qubit
from pytket.extensions.qiskit import qiskit_to_tk, tk_to_qiskit
from qiskit import QuantumCircuit
from qiskit.transpiler import CouplingMap, PassManager
from qiskit.transpiler import CouplingMap, PassManager, TranspileLayout
from qiskit.transpiler.passes import CheckMap, GatesInBasis
from qiskit.transpiler.runningpassmanager import TranspileLayout

from mqt.bench.devices import get_device_by_name
from mqt.predictor import reward, rl
Expand All @@ -35,6 +36,7 @@
self.actions_routing_indices = []
self.actions_mapping_indices = []
self.actions_opt_indices = []
self.actions_final_optimization_indices = []
self.used_actions: list[str] = []
self.device = get_device_by_name(device_name)

Expand All @@ -60,14 +62,21 @@
self.action_set[index] = elem
self.actions_mapping_indices.append(index)
index += 1
for elem in rl.helper.get_actions_final_optimization():
self.action_set[index] = elem
self.actions_final_optimization_indices.append(index)
index += 1

self.action_set[index] = rl.helper.get_action_terminate()
self.action_terminate_index = index

self.reward_function = reward_function
self.action_space = Discrete(len(self.action_set.keys()))
self.num_steps = 0
self.layout = None
self.layout: TranspileLayout | None = None
self.num_qubits_uncompiled_circuit = 0

self.has_parametrized_gates = False

spaces = {
"num_qubits": Discrete(128),
Expand Down Expand Up @@ -113,6 +122,7 @@
if self.state.count_ops().get("unitary"):
self.state = self.state.decompose(gates_to_decompose="unitary")

self.state._layout = self.layout # noqa: SLF001
obs = rl.helper.create_feature_dict(self.state)
return obs, reward_val, done, False, {}

Expand Down Expand Up @@ -160,33 +170,55 @@
self.valid_actions = self.actions_opt_indices + self.actions_synthesis_indices

self.error_occured = False

self.num_qubits_uncompiled_circuit = self.state.num_qubits
self.has_parametrized_gates = len(self.state.parameters) > 0
return rl.helper.create_feature_dict(self.state), {}

def action_masks(self) -> list[bool]:
"""Returns a list of valid actions for the current state."""
return [action in self.valid_actions for action in self.action_set]
action_mask = [action in self.valid_actions for action in self.action_set]

# It is not clear how TKET handles an existing layout, so once a layout is set we mask out all actions whose "origin" is "tket".
if self.layout is not None:
action_mask = [
action_mask[i] and self.action_set[i].get("origin") != "tket" for i in range(len(action_mask))
]

if self.has_parametrized_gates or self.layout is not None:
# Mask out all actions whose "origin" is "bqskit": they are not supported for parametrized gates,
# and using BQSKit after a layout has been set may result in an error.
action_mask = [
action_mask[i] and self.action_set[i].get("origin") != "bqskit" for i in range(len(action_mask))
]

# only allow VF2PostLayout if "ibm" is in the device name
if "ibm" not in self.device.name:
action_mask = [
action_mask[i] and self.action_set[i].get("name") != "VF2PostLayout" for i in range(len(action_mask))
]
return action_mask

def apply_action(self, action_index: int) -> QuantumCircuit | None:
"""Applies the given action to the current state and returns the altered state."""
if action_index in self.action_set:
action = self.action_set[action_index]
if action["name"] == "terminate":
return self.state
if (
action_index
in self.actions_layout_indices + self.actions_routing_indices + self.actions_mapping_indices
):
transpile_pass = action["transpile_pass"](self.device.coupling_map)
elif action_index in self.actions_synthesis_indices:
transpile_pass = action["transpile_pass"](self.device.basis_gates)
else:
if action_index in self.actions_opt_indices:
transpile_pass = action["transpile_pass"]
else:
transpile_pass = action["transpile_pass"](self.device)

if action["origin"] == "qiskit":
try:
if action["name"] == "QiskitO3":
pm = PassManager()
pm.append(
action["transpile_pass"](self.device.basis_gates, CouplingMap(self.device.coupling_map)),
action["transpile_pass"](
self.device.basis_gates,
CouplingMap(self.device.coupling_map) if self.layout is not None else None,
),
do_while=action["do_while"],
)
else:
Expand All @@ -201,20 +233,55 @@

self.error_occured = True
return None
if action_index in self.actions_layout_indices + self.actions_mapping_indices:
assert pm.property_set["layout"]
self.layout = TranspileLayout(
initial_layout=pm.property_set["layout"],
input_qubit_mapping=pm.property_set["original_qubit_indices"],
final_layout=pm.property_set["final_layout"],
)
if (
action_index
in self.actions_layout_indices
+ self.actions_mapping_indices
+ self.actions_final_optimization_indices
):
if action["name"] == "VF2Layout":
if pm.property_set["layout"]:
altered_qc, pm = rl.helper.postprocess_VF2Layout(
altered_qc,
pm.property_set["layout"],
pm.property_set["original_qubit_indices"],
pm.property_set["final_layout"],
self.device,
)
elif action["name"] == "VF2PostLayout":
assert pm.property_set["VF2PostLayout_stop_reason"] is not None
post_layout = pm.property_set["post_layout"]
if post_layout:
altered_qc, pm = rl.helper.postprocess_VF2PostLayout(altered_qc, post_layout, self.layout)

Check warning on line 255 in src/mqt/predictor/rl/PredictorEnv.py

View check run for this annotation

Codecov / codecov/patch

src/mqt/predictor/rl/PredictorEnv.py#L252-L255

Added lines #L252 - L255 were not covered by tests
else:
assert pm.property_set["layout"]

if pm.property_set["layout"]:
self.layout = TranspileLayout(
initial_layout=pm.property_set["layout"],
input_qubit_mapping=pm.property_set["original_qubit_indices"],
final_layout=pm.property_set["final_layout"],
_output_qubit_list=altered_qc.qubits,
_input_qubit_count=self.num_qubits_uncompiled_circuit,
)

elif action_index in self.actions_routing_indices:
assert self.layout is not None
self.layout.final_layout = pm.property_set["final_layout"]

elif action["origin"] == "tket":
try:
tket_qc = qiskit_to_tk(self.state)
tket_qc = qiskit_to_tk(self.state, preserve_param_uuid=True)
for elem in transpile_pass:
elem.apply(tket_qc)
qbs = tket_qc.qubits
qubit_map = {qbs[i]: Qubit("q", i) for i in range(len(qbs))}
tket_qc.rename_units(qubit_map) # type: ignore[arg-type]
altered_qc = tk_to_qiskit(tket_qc)
if action_index in self.actions_routing_indices:
assert self.layout is not None
self.layout.final_layout = rl.helper.final_layout_pytket_to_qiskit(tket_qc, altered_qc)

Check warning on line 283 in src/mqt/predictor/rl/PredictorEnv.py

View check run for this annotation

Codecov / codecov/patch

src/mqt/predictor/rl/PredictorEnv.py#L282-L283

Added lines #L282 - L283 were not covered by tests

except Exception:
logger.exception(
"Error in executing TKET transpile pass for {action} at step {i} for {filename}".format(
Expand All @@ -224,6 +291,28 @@
self.error_occured = True
return None

elif action["origin"] == "bqskit":
try:
bqskit_qc = qiskit_to_bqskit(self.state)
if action_index in self.actions_opt_indices + self.actions_synthesis_indices:
bqskit_compiled_qc = transpile_pass(bqskit_qc)
altered_qc = bqskit_to_qiskit(bqskit_compiled_qc)
elif action_index in self.actions_mapping_indices:
bqskit_compiled_qc, initial_layout, final_layout = transpile_pass(bqskit_qc)
altered_qc = bqskit_to_qiskit(bqskit_compiled_qc)
layout = rl.helper.final_layout_bqskit_to_qiskit(
initial_layout, final_layout, altered_qc, self.state
)
self.layout = layout
except Exception:
logger.exception(

Check warning on line 308 in src/mqt/predictor/rl/PredictorEnv.py

View check run for this annotation

Codecov / codecov/patch

src/mqt/predictor/rl/PredictorEnv.py#L307-L308

Added lines #L307 - L308 were not covered by tests
"Error in executing BQSKit transpile pass for {action} at step {i} for {filename}".format(
action=action["name"], i=self.num_steps, filename=self.filename
)
)
self.error_occured = True
return None

Check warning on line 314 in src/mqt/predictor/rl/PredictorEnv.py

View check run for this annotation

Codecov / codecov/patch

src/mqt/predictor/rl/PredictorEnv.py#L313-L314

Added lines #L313 - L314 were not covered by tests

else:
error_msg = f"Origin {action['origin']} not supported."
raise ValueError(error_msg)
Expand All @@ -241,16 +330,19 @@
only_nat_gates = check_nat_gates.property_set["all_gates_in_basis"]

if not only_nat_gates:
return self.actions_synthesis_indices + self.actions_opt_indices
actions = self.actions_synthesis_indices + self.actions_opt_indices
if self.layout is not None:
actions += self.actions_routing_indices
return actions

check_mapping = CheckMap(coupling_map=CouplingMap(self.device.coupling_map))
check_mapping(self.state)
mapped = check_mapping.property_set["is_swap_mapped"]

if mapped and self.layout is not None:
return [self.action_terminate_index, *self.actions_opt_indices] # type: ignore[unreachable]
return [self.action_terminate_index, *self.actions_opt_indices]

if self.state._layout is not None: # noqa: SLF001
if self.layout is not None:
return self.actions_routing_indices

# No layout applied yet
Expand Down
Loading
Loading