Commit: Merge branch 'rllm-team:main' into main
JianwuZheng413 authored Sep 5, 2024
2 parents cfef218 + 72d231c commit 3d31b39
Showing 3 changed files with 362 additions and 11 deletions.
33 changes: 22 additions & 11 deletions examples/rect.py
@@ -1,3 +1,18 @@
+# The RECT method from the
+# "Network Embedding with Completely-imbalanced Labels" paper.
+# ArXiv: https://arxiv.org/abs/2007.03545
+
+# RECT focuses on the zero-shot setting,
+# where only parts of classes have labeled samples.
+# As such, "unseen" classes are first removed from the training set.
+# Then we train a RECT (or more specifically its supervised part RECT-L) model.
+# Lastly, we use the Logistic Regression model to evaluate the performance
+# of the resulted embeddings based on the original balanced labels.
+
+# Datasets          Citeseer              | Cora                  | Pubmed
+# Unseen Classes    [1, 2, 5]  [3, 4]     | [1, 2, 3]  [3, 4, 6]  | [2]
+# RECT-L            66.50      68.40      | 74.80      72.20      | 75.30
+
 import argparse
 import copy
 import os.path as osp
@@ -11,27 +26,23 @@
 from rllm.datasets.planetoid import PlanetoidDataset
 from rllm.nn.models.rect import RECT_L
 
-# RECT focuses on the zero-shot setting,
-# where only parts of classes have labeled samples.
-# As such, "unseen" classes are first removed from the training set.
-# Then we train a RECT (or more specifically its supervised part RECT-L) model.
-# Lastly, we use the Logistic Regression model to evaluate the performance
-# of the resulted embeddings based on the original balanced labels.
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--dataset', type=str, default='Cora',
                     choices=['Cora', 'CiteSeer', 'PubMed'])
 parser.add_argument('--unseen-classes', type=int, nargs='*', default=[1, 2, 3])
+parser.add_argument('--epochs', type=int, default=200, help="Training epochs")
 args = parser.parse_args()
 
 transform = T.Compose([
     T.NormalizeFeatures('l2'),
     T.SVDFeatureReduction(200),
-    T.GCNNorm()
+    T.GDC()
 ])
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '../data')
-dataset = PlanetoidDataset(path, args.dataset, transform=transform, force_reload=True)
+dataset = PlanetoidDataset(
+    path, args.dataset, transform=transform, force_reload=True)
 data = dataset[0]
 
 zs_data = T.RemoveTrainingClasses(args.unseen_classes)(copy.deepcopy(data))
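
The RemoveTrainingClasses transform above is what realises the zero-shot setting described in the header comments; its implementation is not part of this diff. A rough sketch of the intended effect, assuming boolean node masks and integer labels (the helper below is hypothetical, not rllm's actual code):

import torch


def remove_training_classes(train_mask, y, unseen_classes):
    # Hypothetical helper: drop every node whose label is in the "unseen"
    # set from the training mask, so RECT-L never trains on those classes.
    mask = train_mask.clone()
    for c in unseen_classes:
        mask = mask & (y != c)
    return mask
    # e.g. zs_train_mask = remove_training_classes(data.train_mask, data.y, [1, 2, 3])
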
@@ -48,18 +59,18 @@
 
 model, zs_data = model.to(device), zs_data.to(device)
 
-criterion = torch.nn.MSELoss(reduction='mean')
+criterion = torch.nn.MSELoss(reduction='sum')
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
 
 model.train()
 st = time.time()
-for epoch in range(1, 51):
+for epoch in range(1, args.epochs+1):
     optimizer.zero_grad()
     out = model(zs_data.x, zs_data.adj)
     loss = criterion(out[zs_data.train_mask], zs_data.y)
     loss.backward()
     optimizer.step()
-    # print(f'Epoch {epoch:03d}, Loss {loss:.4f}')
+    print(f'Epoch {epoch:03d}, Loss {loss:.4f}')
 et = time.time()
 model.eval()
 with torch.no_grad():
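
The evaluation step described in the header comments, a Logistic Regression model fitted on the frozen embeddings against the original balanced labels, sits in the collapsed remainder of this hunk. A minimal sketch of that procedure, assuming scikit-learn and NumPy inputs (the function below is illustrative, not the example's hidden code):

from sklearn.linear_model import LogisticRegression


def eval_embeddings(embeds, y, train_mask, test_mask):
    # Fit on the original (balanced) labels of the training nodes,
    # then report accuracy on the held-out test nodes.
    clf = LogisticRegression(max_iter=1000)
    clf.fit(embeds[train_mask], y[train_mask])
    return clf.score(embeds[test_mask], y[test_mask])
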
3 changes: 3 additions & 0 deletions rllm/transforms/__init__.py
@@ -9,6 +9,8 @@
 from .knn_graph import KNNGraph  # noqa
 from .gcn_norm import GCNNorm  # noqa
 from .build_homo_graph import build_homo_graph  # noqa
+from .gdc import GDC
+
 
 general_transforms = [
     'BaseTransform',
@@ -23,6 +25,7 @@
     'RemoveSelfLoops',
     'KNNGraph',
     'GCNNorm',
+    'GDC',
 ]
 
 graph_builders = [
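
With GDC exported here, the transform becomes reachable through the package namespace the way the updated example uses it. A usage sketch, assuming the example's T alias points at rllm.transforms and that GDC's defaults (defined in the new rllm/transforms/gdc.py, not shown) are acceptable:

import rllm.transforms as T

# Same composition as examples/rect.py after this commit.
transform = T.Compose([
    T.NormalizeFeatures('l2'),
    T.SVDFeatureReduction(200),
    T.GDC(),
])
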
