-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
214 lines (178 loc) · 7.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import json
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import tools.Data as td
import tools.process as tp
from model.args import Args
from model.BERT import Seq2SeqModel as BART
from model.fusion import FusionModel
from tools.early import EarlyStopping
from tools.fit import Fit, predict
# Select GPU when available; every model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _collect_labeled(split):
    """Gather (texts, labels, images) lists from a labeled data split.

    Each entry contributes its text + tag only when "text" is truthy and
    its image only when "image" is truthy, mirroring the original
    filtering.  NOTE(review): an entry with only one modality present
    would desynchronise text/image ordering between the two loaders —
    assumed not to occur in this dataset; confirm against the data.
    """
    texts, labels, imgs = [], [], []
    for entry in split.values():
        if entry["text"]:
            texts.append(entry["text"])
            labels.append(entry["tag"])
        if entry["image"]:
            imgs.append(entry["image"])
    return texts, labels, imgs


def _save_results_to_json(results, filename="./results/training_results.json"):
    """Persist the training history returned by Fit() as indented JSON."""
    # Create ./results on first run instead of failing with FileNotFoundError.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w") as f:
        json.dump(results, f, indent=4)
    # BUG FIX: the original printed a literal placeholder, not the path.
    print(f"Results saved to {filename}")


if __name__ == "__main__":
    """---------- 1. Load data ----------"""
    train_file = "data/train.txt"
    test_file = "data/test_without_label.txt"
    data_dir = "data/data"
    train_data, test_data = td.load_data(train_file, test_file, data_dir, 1)

    # 80/20 train/validation split over the labeled guids; fixed seed for
    # reproducibility.
    tra_guids = list(train_data.keys())
    test_guids = list(test_data.keys())
    tra_guids, val_guids = train_test_split(
        tra_guids, test_size=0.2, random_state=42
    )
    tra_data = {guid: train_data.get(guid) for guid in tra_guids}
    val_data = {guid: train_data.get(guid) for guid in val_guids}

    # BUG FIX: the original iterated the FULL train_data here, so every
    # validation sample also ended up in the training set (data leakage)
    # and `tra_data` was never used. Train on the 80% split only.
    train_texts, train_labels, train_imgs = _collect_labeled(tra_data)
    val_texts, val_labels, val_imgs = _collect_labeled(val_data)

    # Test entries carry no labels and are filtered by key presence (the
    # original deliberately used `in` here rather than truthiness).
    test_texts = []
    test_imgs = []
    for entry in test_data.values():
        if "text" in entry:
            test_texts.append(entry["text"])
        if "image" in entry:
            test_imgs.append(entry["image"])

    # Resolve the run mode from the hyper-parameter flags.
    args = Args()
    if args.fusion:
        mode = 'fusion'
    elif args.image_only:
        mode = 'image_only'
    else:
        mode = 'text_only'

    """---------- 2. Build datasets and loaders ----------"""
    print("data processing...")
    # Text pipeline; the unlabeled test set gets None placeholder labels.
    train_text_dataset = tp.TextDataset(train_texts, train_labels, args)
    val_text_dataset = tp.TextDataset(val_texts, val_labels, args)
    test_text_dataset = tp.TextDataset(test_texts, [None] * len(test_texts), args)
    train_text_loader = DataLoader(
        train_text_dataset, batch_size=args.batch_size, shuffle=True
    )
    val_text_loader = DataLoader(
        val_text_dataset, batch_size=args.batch_size, shuffle=False
    )
    test_text_loader = DataLoader(
        test_text_dataset, batch_size=args.batch_size, shuffle=False
    )

    # Image pipeline, mirroring the text loaders.
    train_img_dataset = tp.ImageDataset(train_imgs, train_labels, args)
    val_img_dataset = tp.ImageDataset(val_imgs, val_labels, args)
    test_img_dataset = tp.ImageDataset(test_imgs, [None] * len(test_imgs), args)
    train_img_loader = DataLoader(
        train_img_dataset, batch_size=args.batch_size, shuffle=True
    )
    val_img_loader = DataLoader(
        val_img_dataset, batch_size=args.batch_size, shuffle=False
    )
    test_img_loader = DataLoader(
        test_img_dataset, batch_size=args.batch_size, shuffle=False
    )

    """---------- 3. Train ----------"""
    from torchvision import models

    # Image encoder: ImageNet-pretrained ResNet-18 with its classification
    # head replaced by Identity so it emits backbone features for fusion.
    # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
    # (use `weights=models.ResNet18_Weights.DEFAULT`); kept for
    # compatibility with the pinned environment.
    resnet_model = models.resnet18(pretrained=True)
    resnet_model.fc = nn.Identity()
    bart_model = BART(args).to(device)
    fusion_model = FusionModel(resnet_model, bart_model, num_classes=3).to(device)

    criterion = nn.CrossEntropyLoss()  # 3-class sentiment classification
    optimizer = optim.Adam(fusion_model.parameters(), lr=args.lr_rate)
    # Cut LR by 10x after 3 stagnant epochs; early-stop after 3 more.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
    early_stopping = EarlyStopping(patience=3, verbose=True)

    # NOTE(review): the train text/image loaders shuffle independently, so
    # cross-modal sample alignment must be handled inside Fit() — confirm.
    results = Fit(
        fusion_model,
        train_img_loader,
        train_text_loader,
        val_img_loader,
        val_text_loader,
        criterion,
        optimizer,
        device,
        args.epoch,
        scheduler,
        early_stopping,
        mode=mode
    )
    _save_results_to_json(results)

    """---------- 4. Predict ----------"""
    # pandas.to_csv does not create missing directories.
    os.makedirs("./results", exist_ok=True)
    # One prediction pass per enabled flag; the flags are independent, so
    # several modes may run in one invocation (matches original behavior).
    for enabled, pred_mode in (
        (args.fusion, "fusion"),
        (args.image_only, "image_only"),
        (args.text_only, "text_only"),
    ):
        if not enabled:
            continue
        predictions = predict(
            fusion_model, test_img_loader, test_text_loader, device, mode=pred_mode
        )
        frame = pd.DataFrame({"guid": test_guids, "tag": predictions})
        frame.to_csv(f"./results/predictions_{pred_mode}.csv", index=False)
        print(f"Predictions saved to predictions_{pred_mode}.csv")