-
Notifications
You must be signed in to change notification settings - Fork 115
/
twitter-train-test-splits.py
31 lines (24 loc) · 1.04 KB
/
twitter-train-test-splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import networkx as nx
import pandas as pd
import pickle
import numpy as np
from gae.preprocessing import mask_test_edges_directed
# Build and pickle one train/val/test edge split of the combined Twitter
# adjacency matrix for every "fraction of edges hidden" setting below.

RANDOM_SEED = 0

# NOTE: pickle.load executes arbitrary code embedded in the file; only point
# this at a trusted, locally generated adjacency pickle.
# The context manager closes the input handle as soon as loading finishes.
with open('./twitter/twitter-combined-adj.pkl', 'rb') as adj_file:
    twitter_adj = pickle.load(adj_file)

FRAC_EDGES_HIDDEN = [0.25, 0.5, 0.75]
TRAIN_TEST_SPLIT_DIR = './train-test-splits/'

# Generate 1 train/test split for each frac_edges_hidden setting
for frac_hidden in FRAC_EDGES_HIDDEN:
    # The hidden fraction is shared between validation and test edges:
    # validation always takes 0.1 and the test set takes the remainder.
    val_frac = 0.1
    test_frac = frac_hidden - val_frac

    # Re-seed NumPy's global RNG before every split so each split starts
    # from the same RNG state regardless of the settings processed before it.
    np.random.seed(RANDOM_SEED)

    # Generate train_test_split:
    # (adj_train, train_edges, train_edges_false, val_edges, val_edges_false,
    #  test_edges, test_edges_false)
    train_test_split = mask_test_edges_directed(
        twitter_adj,
        test_frac=test_frac, val_frac=val_frac,
        verbose=True, prevent_disconnect=False, false_edge_sampling='random')

    # Save split. TRAIN_TEST_SPLIT_DIR must already exist; it is not created
    # here. Pickle protocol 2 is the newest protocol Python 2 can read
    # (presumably chosen for that compatibility -- confirm before raising it).
    file_name = TRAIN_TEST_SPLIT_DIR + 'twitter-combined-{}-hidden.pkl'.format(frac_hidden)
    with open(file_name, 'wb') as f:
        pickle.dump(train_test_split, f, protocol=2)