Skip to content

Commit

Permalink
added test
Browse files Browse the repository at this point in the history
  • Loading branch information
Coerulatus committed Dec 18, 2024
1 parent 6a0e4ec commit 78f5ed2
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 10 deletions.
131 changes: 131 additions & 0 deletions test/data/batching/test_neighbor_cells_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
import shutil
import tempfile

import rootutils
import torch
from hydra import compose

from topobenchmark.data.preprocessor import PreProcessor
from topobenchmark.data.utils.utils import load_manual_graph
from topobenchmark.data.batching import NeighborCellsLoader
from topobenchmark.run import initialize_hydra

# NOTE(review): everything below still executes at import time, exactly as the
# original flat script did, so a failing check surfaces while pytest imports
# this module rather than as a named test. Consider moving the two calls at
# the bottom into test_* functions once Hydra initialisation is handled per
# test.
initialize_hydra()

# Number of seed cells requested per batch.
BATCH_SIZE = 2
# Fraction of the cells of a rank (taken from the start) used as seed cells.
TRAIN_PROP = 0.5


def _scratch_dir(name):
    """Return a scratch directory path under the system temporary directory.

    Parameters
    ----------
    name : str
        Name of the directory to place under the temporary directory.

    Returns
    -------
    str
        The path, ending with a path separator like the literal paths this
        helper replaces.
    """
    # The previous literal paths lived under the filesystem root ("/temp/..."),
    # which is usually not writable; the system temp dir is.
    return os.path.join(tempfile.gettempdir(), name, "")


def _assert_seed_cells_are_loaded(data, rank, feature_key, num_neighbors):
    """Check that every seed cell of ``rank`` shows up in some loader batch.

    The first ``TRAIN_PROP`` fraction of the cells stored under
    ``feature_key`` is marked as seed cells, ``data.y`` is overwritten with a
    zero label per cell, and a ``NeighborCellsLoader`` is iterated without
    shuffling.

    Parameters
    ----------
    data : object
        Preprocessed data object; it is mutated (``data.y`` is replaced).
    rank : int
        Rank passed to ``NeighborCellsLoader``.
    feature_key : str
        Key of the feature tensor whose first dimension gives the number of
        cells of this rank (e.g. ``"x_0"`` or ``"x_hyperedges"``).
    num_neighbors : list of int
        Forwarded unchanged to ``NeighborCellsLoader``.

    Raises
    ------
    AssertionError
        If at least one seed cell index never appears among the first
        ``BATCH_SIZE`` entries of ``batch.n_id`` of any batch.
    """
    n_cells = data[feature_key].shape[0]
    n_train = int(TRAIN_PROP * n_cells)
    train_mask = torch.zeros(n_cells, dtype=torch.bool)
    train_mask[:n_train] = True

    data.y = torch.zeros(n_cells, dtype=torch.long)

    loader = NeighborCellsLoader(
        data,
        rank=rank,
        num_neighbors=num_neighbors,
        input_nodes=train_mask,
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    # Collect plain ints so membership is a set lookup instead of comparing an
    # int against a list of 0-dim tensors.
    # NOTE(review): presumably the seed cells come first in ``n_id`` - confirm
    # against the loader implementation.
    loaded = set()
    for batch in loader:
        loaded.update(int(cell) for cell in batch.n_id[:BATCH_SIZE])

    missing = sorted(set(range(n_train)) - loaded)
    assert not missing, (
        f"rank-{rank} seed cells never returned by the loader: {missing}"
    )


def _run_lifting_check(dir_name, model_override, rank_specs):
    """Preprocess the manual graph for one model and check each rank.

    Parameters
    ----------
    dir_name : str
        Name of the scratch directory used by ``PreProcessor``.
    model_override : str
        Hydra override selecting the model, e.g. ``"model=simplicial/san"``.
    rank_specs : list of tuple
        ``(rank, feature_key, num_neighbors)`` triples, checked in order on
        the same data object.
    """
    path = _scratch_dir(dir_name)
    # Start from a clean directory so stale processed data is never reused.
    if os.path.isdir(path):
        shutil.rmtree(path)
    try:
        cfg = compose(
            config_name="run.yaml",
            overrides=["dataset=graph/manual_dataset", model_override],
            return_hydra_config=True,
        )
        data = load_manual_graph()
        preprocessed_dataset = PreProcessor(data, path, cfg['transforms'])
        data = preprocessed_dataset[0]

        for rank, feature_key, num_neighbors in rank_specs:
            _assert_seed_cells_are_loaded(data, rank, feature_key, num_neighbors)
    finally:
        # Always clean up, and never let a cleanup error hide the real failure.
        shutil.rmtree(path, ignore_errors=True)


_run_lifting_check(
    "graph2simplicial_lifting",
    "model=simplicial/san",
    [(0, "x_0", [-1]), (1, "x_1", [-1, -1])],
)

_run_lifting_check(
    "graph2hypergraph_lifting",
    "model=hypergraph/allsettransformer",
    [(0, "x_0", [-1]), (1, "x_hyperedges", [-1, -1])],
)
7 changes: 7 additions & 0 deletions topobenchmark/data/batching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
""" Init file for batching module. """

from .neighbor_cells_loader import NeighborCellsLoader

# Names exported by `from topobenchmark.data.batching import *`.
__all__ = [
"NeighborCellsLoader",
]
2 changes: 1 addition & 1 deletion topobenchmark/data/batching/cell_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def filter_fn(
out = self.transform_sampler_output(out)

if isinstance(out, SamplerOutput) and isinstance(self.data, Data):
data = filter_data( #
data = filter_data(
self.data, out.node, self.rank)
else:
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
Expand Down
17 changes: 12 additions & 5 deletions topobenchmark/data/batching/neighbor_cells_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from topobenchmark.data.batching.cell_loader import CellLoader
from topobenchmark.data.batching.utils import get_sampled_neighborhood
from topobenchmark.dataloader import DataloadDataset

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData

Expand Down Expand Up @@ -121,24 +122,30 @@ def __init__(
is_sorted: bool = False,
filter_per_worker: Optional[bool] = None,
neighbor_sampler: Optional[NeighborSampler] = None,
directed: bool = True, # Deprecated.
directed: bool = True,
**kwargs,
):
if input_time is not None and time_attr is None:
raise ValueError("Received conflicting 'input_time' and "
"'time_attr' arguments: 'input_time' is set "
"while 'time_attr' is not set.")

is_hypergraph = hasattr(data, 'incidence_hyperedges')
data_obj = Data()
if isinstance(data, DataloadDataset):
for tensor, name in zip(data[0][0], data[0][1]):
setattr(data_obj, name, tensor)
else:
data_obj = data
is_hypergraph = hasattr(data_obj, 'incidence_hyperedges')
n_hops = len(num_neighbors)
data = get_sampled_neighborhood(data, rank, n_hops, is_hypergraph)
data_obj = get_sampled_neighborhood(data_obj, rank, n_hops, is_hypergraph)
self.rank = rank
if self.rank != 0:
# When rank is not 0, get_sampled_neighborhood already connects cells that are up to n_hops away, so the NeighborSampler only needs to consider a single hop.
num_neighbors = [num_neighbors[0]]
if neighbor_sampler is None:
neighbor_sampler = NeighborSampler(
data,
data_obj,
num_neighbors=num_neighbors,
replace=replace,
subgraph_type=subgraph_type,
Expand All @@ -152,7 +159,7 @@ def __init__(
)

super().__init__(
data=data,
data=data_obj,
cell_sampler=neighbor_sampler,
input_cells=input_nodes,
input_time=input_time,
Expand Down
14 changes: 10 additions & 4 deletions topobenchmark/data/batching/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ def reduce_higher_ranks_incidences(batch, cells_ids, rank, max_rank, is_hypergra
incidence = torch.index_select(incidence, 0, cells_ids[i-1])
cells_ids[i] = torch.where(torch.sum(incidence, dim=0).to_dense() > 1)[0]
incidence = torch.index_select(incidence, 1, cells_ids[i])
batch[f"incidence_{i}"] = incidence
if is_hypergraph:
batch.incidence_hyperedges = incidence
else:
batch[f"incidence_{i}"] = incidence

return batch, cells_ids

Expand Down Expand Up @@ -76,7 +79,10 @@ def reduce_lower_ranks_incidences(batch, cells_ids, rank, is_hypergraph=False):
incidence = torch.index_select(incidence, 1, cells_ids[i])
cells_ids[i-1] = torch.where(torch.sum(incidence, dim=1).to_dense() > 0)[0]
incidence = torch.index_select(incidence, 0, cells_ids[i-1])
batch[f"incidence_{i}"] = incidence
if is_hypergraph:
batch.incidence_hyperedges = incidence
else:
batch[f"incidence_{i}"] = incidence

if not is_hypergraph:
incidence = batch[f"incidence_0"]
Expand Down Expand Up @@ -275,8 +281,8 @@ def get_sampled_neighborhood(data, rank=0, n_hops=1, is_hypergraph=False):
# P = torch.sparse.mm(data[f"incidence_{i}"], P)
# Q = torch.sparse.mm(P.T,P)
# edges = torch.cat((edges, Q.indices()), dim=1)
edges = A_sum.coalesce().indices()

edges = A_sum.coalesce().indices()
# Remove self edges
mask = edges[0, :] != edges[1, :]
edges = edges[:, mask]
Expand Down

0 comments on commit 78f5ed2

Please sign in to comment.