PureT verbose, add some extra files, such as multi-GPU training, ensemble testing, online testing
Showing 11 changed files with 2,492 additions and 1 deletion.
@@ -0,0 +1,106 @@
import os
import sys
import pprint
import random
import time
import tqdm
import logging
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist

import losses
import models
import datasets
import lib.utils as utils
from lib.utils import AverageMeter
from optimizer.optimizer import Optimizer
from evaluation.evaler import Evaler
from scorer.scorer import Scorer
from lib.config import cfg, cfg_from_file


class Tester(object):
    def __init__(self, args):
        super(Tester, self).__init__()
        self.args = args
        self.device = torch.device("cuda")

        self.setup_logging()
        self.setup_network()
        self.evaler = Evaler(
            eval_ids = cfg.DATA_LOADER.TEST_ID,
            gv_feat = cfg.DATA_LOADER.TEST_GV_FEAT,
            att_feats = cfg.DATA_LOADER.TEST_ATT_FEATS,
            eval_annfile = cfg.INFERENCE.TEST_ANNFILE
        )

    def setup_logging(self):
        # Log to stdout and to an offline-test log file under cfg.ROOT_DIR
        self.logger = logging.getLogger(cfg.LOGGER_NAME)
        self.logger.setLevel(logging.INFO)

        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter("[%(levelname)s: %(asctime)s] %(message)s")
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        if not os.path.exists(cfg.ROOT_DIR):
            os.makedirs(cfg.ROOT_DIR)

        fh = logging.FileHandler(os.path.join(cfg.ROOT_DIR, 'OfflineTest_' + cfg.LOGGER_NAME + '.txt'))
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)

    def setup_network(self):
        # Build the captioning model and optionally restore a saved snapshot
        model = models.create(cfg.MODEL.TYPE)
        print(model)
        self.model = torch.nn.DataParallel(model).cuda()
        if self.args.resume > 0:
            self.model.load_state_dict(
                torch.load(self.snapshot_path("caption_model", self.args.resume),
                           map_location=lambda storage, loc: storage)
            )

    def eval(self, epoch):
        # The metric evaluation below is commented out; this version only prints the model's FLOPs.
        """
        res = self.evaler(self.model, 'test_' + str(epoch))
        self.logger.info('######## Epoch ' + str(epoch) + ' ########')
        self.logger.info(str(res))
        """
        print(self.model.module.flops())

    def snapshot_path(self, name, epoch):
        # Checkpoints are stored as <ROOT_DIR>/snapshot/<name>_<epoch>.pth
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Image Captioning')
    parser.add_argument('--folder', dest='folder', default=None, type=str)
    parser.add_argument("--resume", type=int, default=-1)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)

    # Load the experiment config from the given folder and write outputs there
    if args.folder is not None:
        cfg_from_file(os.path.join(args.folder, 'config.yml'))
        cfg.ROOT_DIR = args.folder

    tester = Tester(args)
    tester.eval(args.resume)
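For orientation only, a minimal usage sketch: assuming this new file is saved as main_test.py (the filename is not visible in this view) and that an experiment folder contains a config.yml together with checkpoints named snapshot/caption_model_<epoch>.pth, the folder name and epoch below are placeholders:

python main_test.py --folder experiments/PureT_XE --resume 20

Here --folder makes cfg_from_file load <folder>/config.yml and sets cfg.ROOT_DIR to that folder, so the OfflineTest_*.txt log is written next to the config, while --resume selects the checkpoint epoch that setup_network loads before eval() runs.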