Commit

cleanup
asdf committed Feb 14, 2024
1 parent 2571dcb commit 2c6e2e7
Showing 5 changed files with 48 additions and 55 deletions.
8 changes: 4 additions & 4 deletions src/mqt/predictor/ml/GNN.py
@@ -33,7 +33,7 @@ class Net(nn.Module): # type: ignore[misc]
     beta: bool = False
     bias: bool = True
     root_weight: bool = True
-    model: str = "TransformerConv"
+    model: str = "GCN"
     jk: str = "last"
     v2: bool = True
 
@@ -46,10 +46,10 @@ def __init__(self, **kwargs: object) -> None:
         if self.edge_embedding_dim and self.edge_embedding_dim > 1:
             self.edge_embedding = nn.Embedding(self.num_edge_categories, self.edge_embedding_dim)
 
-        if self.node_embedding_dim and self.node_embedding_dim == 1:
+        if self.node_embedding_dim and self.node_embedding_dim == 1:  # one-hot encoding
             self.node_embedding = lambda x: F.one_hot(x, num_classes=self.num_node_categories).float()
             self.node_embedding_dim = self.num_node_categories
-        if self.edge_embedding_dim and self.edge_embedding_dim == 1:
+        if self.edge_embedding_dim and self.edge_embedding_dim == 1:  # one-hot encoding
             self.edge_embedding = lambda x: F.one_hot(x, num_classes=self.num_edge_categories).float()
             self.edge_embedding_dim = self.num_edge_categories
 
@@ -115,7 +115,7 @@ def __init__(self, **kwargs: object) -> None:
                 act=self.activation_func,
                 norm=self.batch_norm_layer if self.batch_norm else None,
                 edge_dim=self.edge_embedding_dim if self.edge_embedding_dim else 1,
-                v2=True,
+                v2=self.v2,
             )
         ]
        last_hidden_dim = self.output_dim
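For context on the `# one-hot encoding` comments above: when an embedding dimension is set to 1, the lookup is replaced by a one-hot encoding and the effective width becomes the number of categories, while a dimension greater than 1 uses a learned `nn.Embedding`. A minimal standalone sketch with hypothetical sizes (the real values come from the `Net` kwargs); the same pattern is mirrored for the edge features:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, not taken from the repository.
num_node_categories = 5
node_embedding_dim = 1  # 1 selects the one-hot path, >1 a learned nn.Embedding

if node_embedding_dim > 1:
    node_embedding = nn.Embedding(num_node_categories, node_embedding_dim)
else:
    # One-hot encoding: the effective embedding width becomes num_node_categories.
    node_embedding = lambda x: F.one_hot(x, num_classes=num_node_categories).float()
    node_embedding_dim = num_node_categories

x = torch.tensor([0, 3, 4])     # integer node categories
print(node_embedding(x).shape)  # torch.Size([3, 5]) for the one-hot path
```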
9 changes: 3 additions & 6 deletions src/mqt/predictor/ml/GNN_zx.py
@@ -44,7 +44,7 @@ def __init__(self, **kwargs: object) -> None:
         if self.node_embedding_dim and self.node_embedding_dim > 1:
             self.node_embedding = nn.Embedding(self.num_node_categories, self.node_embedding_dim)
 
-        if self.node_embedding_dim and self.node_embedding_dim == 1:
+        if self.node_embedding_dim and self.node_embedding_dim == 1:  # one-hot encoding
             self.node_embedding = lambda x: F.one_hot(x, num_classes=self.num_node_categories).float()
             self.node_embedding_dim = self.num_node_categories
 
@@ -141,13 +141,10 @@ def __init__(self, **kwargs: object) -> None:
         elif self.readout == "max":
             self.pooling = Sequential("x, batch", [(self.out_nn, "x -> x"), (global_max_pool, "x, batch -> x")])
 
-        self.mu_layer = nn.Linear(last_hidden_dim, self.output_dim)
-        self.logstd_layer = nn.Linear(last_hidden_dim, self.output_dim)
-
     def forward(self, data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        x, edge_index, _edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
+        x, edge_index, batch = data.x, data.edge_index, data.batch
 
-        # Apply the node and edge embeddings
+        # Apply the node embedding
         if self.node_embedding_dim:
             x_0 = self.node_embedding(x[:, 0].long()).squeeze()
             x_1 = x[:, 1].float().unsqueeze(1)
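The simplified `forward` above splits the node feature matrix into a categorical column (`x[:, 0]`, embedded) and a numeric column (`x[:, 1]`). A rough sketch of that split with made-up values; combining the two parts by concatenation is an assumption here, the actual combination happens in code not shown in this hunk:

```python
import torch
import torch.nn.functional as F

# Hypothetical ZX-graph node features: column 0 is a categorical node type,
# column 1 a numeric attribute (e.g. a phase); the sizes are invented.
num_node_categories = 4
x = torch.tensor([[0, 0.5], [2, 1.0], [3, 0.0]])

x_0 = F.one_hot(x[:, 0].long(), num_classes=num_node_categories).float().squeeze()
x_1 = x[:, 1].float().unsqueeze(1)

# One plausible way to combine both parts into a single node feature matrix.
x_combined = torch.cat([x_0, x_1], dim=1)
print(x_combined.shape)  # torch.Size([3, 5])
```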
7 changes: 4 additions & 3 deletions src/mqt/predictor/ml/GNNclassifier.py
@@ -98,8 +98,9 @@ def fit(self, dataset: Dataset) -> None:
         for batch in loader:
             self.optim.zero_grad()
             out = self.gnn.forward(batch)
-            target = batch.y.view(-1, self.output_dim)
-            loss = torch.nn.MSELoss()(out, target)  # compute the MSE loss
+            target = batch.y.view(-1, len(self.output_mask))
+            masked_target = target[:, self.output_mask]
+            loss = torch.nn.MSELoss()(out, masked_target)  # compute the MSE loss
             loss.backward()
             self.optim.step()
         return
@@ -111,7 +112,7 @@ def predict(self, dataset: Dataset) -> torch.Tensor:
 
     def score(self, dataset: Dataset) -> float:
         pred = self.predict(dataset)
-        labels = torch.stack([data.y for data in dataset]).argmax(dim=1)
+        labels = torch.stack([data.y for data in dataset])[:, self.output_mask].argmax(dim=1)
         correct = pred.eq(labels).sum().item()
         total = len(dataset)
         return int(correct) / total
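The `fit` and `score` changes restrict the targets to the columns selected by `self.output_mask`. A small sketch of the masking pattern with hypothetical shapes; treating `output_mask` as a boolean mask over four target columns is an assumption:

```python
import torch

# Hypothetical: 4 target columns overall, of which output_mask selects 3.
output_mask = torch.tensor([True, False, True, True])

out = torch.rand(8, int(output_mask.sum()))  # model predictions for a batch of 8
y = torch.rand(8 * len(output_mask))         # flattened targets, like batch.y

target = y.view(-1, len(output_mask))        # shape (8, 4)
masked_target = target[:, output_mask]       # shape (8, 3), selected columns only
loss = torch.nn.MSELoss()(out, masked_target)
print(loss.item())
```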
53 changes: 20 additions & 33 deletions src/mqt/predictor/ml/Predictor.py
@@ -313,11 +313,7 @@ def generate_training_sample(
         if num_not_empty_entries == 0:
             logger.warning("no compiled circuits found for:" + str(file))
 
-        try:
-            feature_vec = ml.helper.create_feature_dict(str(path_uncompiled_circuit / file))
-        except Exception as e:
-            print(e, file)
-            return (([], -1), "", scores)
+        feature_vec = ml.helper.create_feature_dict(str(path_uncompiled_circuit / file))
 
         training_sample = (list(feature_vec.values()), np.argmax(scores))
         circuit_name = str(file).split(".")[0]
@@ -392,41 +388,32 @@ def get_prepared_training_data(
scores_list: list[list[float]] = [[] for _ in range(len(raw_scores_list))]
X_raw = list(unzipped_training_data_X)
X_list: list[list[float]] = [[] for _ in range(len(X_raw))]
if graph_only:
X_graph: list[Data] = [None for _ in range(len(X_raw))]
X_graph: list[list[Data]] = [[] for _ in range(len(X_raw))]
y_list = list(unzipped_training_data_Y)
for i in range(len(X_raw)):
if not X_raw[i]:
continue
X_list[i] = list(X_raw[i][:-2]) # all but graphs
if graph_only:
X_graph[i] = X_raw[i][-1] # graph feature
X_graph[i] = list(X_raw[i][:2]) # only graphs
else:
X_list[i] = list(X_raw[i][2:]) # all but graphs
scores_list[i] = list(raw_scores_list[i])

# remove all empty (erroneous) files
X_list = [x for x, raw in zip(X_list, X_raw) if raw]
y_list = [y for y, raw in zip(y_list, X_raw) if raw]
names_list = [name for name, raw in zip(names_list, X_raw) if raw]
scores_list = [score for score, raw in zip(scores_list, X_raw) if raw]
if graph_only:
X_graph = [x for x, raw in zip(X_graph, X_raw) if raw]
y, indices = np.array(y_list), np.array(range(len(y_list)))

X, y, indices = np.array(X_list), np.array(y_list), np.array(range(len(y_list)))

# Store all non zero feature indices
non_zero_indices = [i for i in range(len(X[0])) if sum(X[:, i]) > 0]
X = X[:, non_zero_indices]

if save_non_zero_indices:
data = np.asarray(non_zero_indices)
np.save(
ml.helper.get_path_trained_model(figure_of_merit, return_non_zero_indices=True),
data,
)

# Overwrite X with graph features only
if graph_only:
X = X_graph # type: ignore[assignment]
X = X_graph
else:
X = np.array(X_list)
# Store all non zero feature indices
non_zero_indices = [i for i in range(len(X[0])) if sum(X[:, i]) > 0]
X = X[:, non_zero_indices]

if save_non_zero_indices:
data = np.asarray(non_zero_indices)
np.save(
ml.helper.get_path_trained_model(figure_of_merit, return_non_zero_indices=True),
data,
)

(
X_train,
X_test,
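Per the new comments in this hunk, each raw training row apparently starts with two graph-valued entries (`X_raw[i][:2]`, used when `graph_only` is set) followed by numeric circuit features (`X_raw[i][2:]`), and the numeric path additionally drops columns that are zero for every sample. A toy sketch of that split and filtering, with invented data:

```python
import numpy as np

# Hypothetical raw feature rows: the first two entries stand in for graph
# objects, the rest are numeric circuit features.
raw = [
    ["graph_0", "zx_graph_0", 1.0, 0.0, 3.0],
    ["graph_1", "zx_graph_1", 2.0, 0.0, 1.0],
]

graph_only = False
if graph_only:
    X = [row[:2] for row in raw]  # keep only the graph features
else:
    X = np.array([row[2:] for row in raw], dtype=float)
    # Drop feature columns that are zero for every sample.
    non_zero_indices = [i for i in range(X.shape[1]) if X[:, i].sum() > 0]
    X = X[:, non_zero_indices]
    print(X)  # the all-zero column (index 1) is removed
```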
26 changes: 17 additions & 9 deletions src/mqt/predictor/ml/helper.py
@@ -199,15 +199,25 @@ def create_feature_dict(qc: str | QuantumCircuit) -> dict[str, Any]:
 
     ops_list = qc.count_ops()
     ops_list_dict = dict_to_featurevector(ops_list)
+    feature_dict = {}
 
-    # operations/gates encoding for graph feature creation
-    ops_list_encoding = ops_list_dict.copy()
-    ops_list_encoding["measure"] = len(ops_list_encoding)  # add extra gate
-    # unique number for each gate {'measure': 0, 'cx': 1, ...}
-    for i, key in enumerate(ops_list_dict):
-        ops_list_encoding[key] = i
+    try:
+        # operations/gates encoding for graph feature creation
+        ops_list_encoding = ops_list_dict.copy()
+        ops_list_encoding["measure"] = len(ops_list_encoding)  # add extra gate
+        # unique number for each gate {'measure': 0, 'cx': 1, ...}
+        for i, key in enumerate(ops_list_dict):
+            ops_list_encoding[key] = i
+
+        feature_dict["graph"] = circuit_to_graph(qc, ops_list_encoding)
+    except Exception:
+        feature_dict["graph"] = None
+
+    try:
+        feature_dict["zx_graph"] = qasm_to_zx(qc.qasm())
+    except Exception:  # e.g. zx-calculus not supported for all circuits
+        feature_dict["zx_graph"] = None
 
-    feature_dict = {}
     for key in ops_list_dict:
         feature_dict[key] = float(ops_list_dict[key])
 
@@ -220,8 +230,6 @@ def create_feature_dict(qc: str | QuantumCircuit) -> dict[str, Any]:
     feature_dict["entanglement_ratio"] = supermarq_features.entanglement_ratio
     feature_dict["parallelism"] = supermarq_features.parallelism
     feature_dict["liveness"] = supermarq_features.liveness
-    feature_dict["graph"] = circuit_to_graph(qc, ops_list_encoding)
-    feature_dict["zx_graph"] = qasm_to_zx(qc.qasm())
     return feature_dict
