-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathmain.py
58 lines (42 loc) · 1.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from script.dataset import Dataset
from script.utils import load_config
from script.pipeline import Pipeline
def train_and_test(model_name, data, dataset_root, config):
"""模型训练和测试
使用给定数据和配置训练并测试模型
Inputs:
-------
data: string, 使用的数据集名称, ['DD', 'NCI1', 'PROTEINS']
dataset_root: string, 数据集保存根文件夹路径
config: dict, 参数配置
"""
# 数据获取和预处理
dataset = Dataset(data, dataset_root, **config[data])
# 训练模型
pipeline = Pipeline(model_name, **config[data])
pipeline.train(dataset)
# 测试集准确率
test_loss, test_acc = pipeline.predict(dataset, 'test')
print('[{}]-[{}]-[TestLoss:{:.4f}]-[TestAcc:{:.3f}]\n'.format(
model_name, data.upper(), test_loss, test_acc))
return
if __name__ == '__main__':
# 数据集根目录
dataset_root = '../../Dataset'
# 加载全局配置
config = load_config(config_file='config.yaml')
# 使用DD数据集训练和测试模型
train_and_test('SAGPoolG', 'DD', dataset_root, config)
# [SAGPoolG] DD Test Accuracy: 0.723
train_and_test('SAGPoolH', 'DD', dataset_root, config)
# [SAGPoolH] DD Test Accuracy: 0.745
# 使用NCI1数据集训练和测试模型
train_and_test('SAGPoolG', 'NCI1', dataset_root, config)
# [SAGPoolG] NCI1 Test Accuracy: 0.763
train_and_test('SAGPoolH', 'NCI1', dataset_root, config)
# [SAGPoolH] NCI1 Test Accuracy: 0.648
# 使用PROTEINS数据集训练和测试模型
train_and_test('SAGPoolG', 'PROTEINS', dataset_root, config)
# [SAGPoolG] PROTEINS Test Accuracy: 0.757
train_and_test('SAGPoolH', 'PROTEINS', dataset_root, config)
# [SAGPoolH] PROTEINS Test Accuracy: 0.743