Skip to content

Commit

Permalink
PureT verbose, add some extra files, such as multi-gpus training, ens…
Browse files Browse the repository at this point in the history
…emble testing, online testing
  • Loading branch information
232525 committed May 30, 2024
1 parent 0c57f4a commit cb3e16d
Show file tree
Hide file tree
Showing 11 changed files with 2,492 additions and 1 deletion.
106 changes: 106 additions & 0 deletions cal_flops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
import sys
import pprint
import random
import time
import tqdm
import logging
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist

import losses
import models
import datasets
import lib.utils as utils
from lib.utils import AverageMeter
from optimizer.optimizer import Optimizer
from evaluation.evaler import Evaler
from scorer.scorer import Scorer
from lib.config import cfg, cfg_from_file

class Tester(object):
def __init__(self, args):
super(Tester, self).__init__()
self.args = args
self.device = torch.device("cuda")

self.setup_logging()
self.setup_network()
self.evaler = Evaler(
eval_ids = cfg.DATA_LOADER.TEST_ID,
gv_feat = cfg.DATA_LOADER.TEST_GV_FEAT,
att_feats = cfg.DATA_LOADER.TEST_ATT_FEATS,
eval_annfile = cfg.INFERENCE.TEST_ANNFILE
)

def setup_logging(self):
self.logger = logging.getLogger(cfg.LOGGER_NAME)
self.logger.setLevel(logging.INFO)

ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.INFO)
formatter = logging.Formatter("[%(levelname)s: %(asctime)s] %(message)s")
ch.setFormatter(formatter)
self.logger.addHandler(ch)

if not os.path.exists(cfg.ROOT_DIR):
os.makedirs(cfg.ROOT_DIR)

fh = logging.FileHandler(os.path.join(cfg.ROOT_DIR, 'OfflineTest_' + cfg.LOGGER_NAME + '.txt'))
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
self.logger.addHandler(fh)

def setup_network(self):
model = models.create(cfg.MODEL.TYPE)
print(model)
self.model = torch.nn.DataParallel(model).cuda()
if self.args.resume > 0:
self.model.load_state_dict(
torch.load(self.snapshot_path("caption_model", self.args.resume),
map_location=lambda storage, loc: storage)
)

def eval(self, epoch):
"""
res = self.evaler(self.model, 'test_' + str(epoch))
self.logger.info('######## Epoch ' + str(epoch) + ' ########')
self.logger.info(str(res))
"""
print(self.model.module.flops())

def snapshot_path(self, name, epoch):
snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")

def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--folder', dest='folder', default=None, type=str)
parser.add_argument("--resume", type=int, default=-1)

if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)

args = parser.parse_args()
return args

if __name__ == '__main__':
args = parse_args()
print('Called with args:')
print(args)

if args.folder is not None:
cfg_from_file(os.path.join(args.folder, 'config.yml'))
cfg.ROOT_DIR = args.folder

tester = Tester(args)
tester.eval(args.resume)
11 changes: 10 additions & 1 deletion evaluation/online_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ def __init__(
# './mscoco/txt/coco_val_image_id.txt' Karpathy验证集 5K张图像
# './mscoco/txt/coco_test_image_id.txt' Karpathy测试集 5K张图像
# './mscoco/txt/coco_test4w_image_id.txt' MSCOCO在线测试集 4W张图像
self.eval_ids = np.array(utils.load_ids(eval_ids))

# 读取txt文件,读取的为image_ids的list
# self.eval_ids = np.array(utils.load_ids(eval_ids))

# 端到端训练时,直接读取annotation的json文件,其中包含了图像id和路径
# 读取json文件,读取的为{image_id: image_path}的dict
with open(eval_ids, 'r') as f:
self.ids2path = json.load(f) # dict {image_id: image_path}
self.eval_ids = np.array(list(self.ids2path.keys())) # array of str
self.eval_loader = data_loader.load_val(eval_ids, gv_feat, att_feats)

def make_kwargs(self, indices, ids, gv_feat, att_feats, att_mask):
Expand Down Expand Up @@ -55,6 +63,7 @@ def __call__(self, model, rname):
# 构造模型验证结果 {'image_id': ***, 'caption': 'word1 word2 word3 ...'}
result = {cfg.INFERENCE.ID_KEY: int(ids[sid]), cfg.INFERENCE.CAP_KEY: sent}
results.append(result)
print(result)

# 在线测试不需要评估,直接保存模型输出结果即可
result_folder = os.path.join(cfg.ROOT_DIR, 'result')
Expand Down
Loading

0 comments on commit cb3e16d

Please sign in to comment.