Skip to content

Commit

Permalink
use more descriptive script and variable names (#21)
Browse files Browse the repository at this point in the history
* rename file

Signed-off-by: fred@pegasus <[email protected]>

* update file name

Signed-off-by: fred@pegasus <[email protected]>

* update file and model name

Signed-off-by: fred@pegasus <[email protected]>
  • Loading branch information
fredzzhang authored Dec 5, 2021
1 parent 41e58fb commit 38cb045
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import DataFactory
from detector import build_detector
from upt import build_detector

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -133,27 +133,27 @@ def main(args):
actions = dataset.dataset.verbs if args.dataset == 'hicodet' else \
dataset.dataset.actions

detector = build_detector(args, conversion)
detector.eval()
upt = build_detector(args, conversion)
upt.eval()

if os.path.exists(args.resume):
print(f"=> Continue from saved checkpoint {args.resume}")
checkpoint = torch.load(args.resume, map_location='cpu')
detector.load_state_dict(checkpoint['model_state_dict'])
upt.load_state_dict(checkpoint['model_state_dict'])
else:
print(f"=> Start from a randomly initialised model")

if args.image_path is None:
image, _ = dataset[args.index]
output = detector([image])
output = upt([image])
image = dataset.dataset.load_image(
os.path.join(dataset.dataset._root,
dataset.dataset.filename(args.index)
))
else:
image = dataset.dataset.load_image(args.image_path)
image_tensor, _ = dataset.transforms(image, None)
output = detector([image_tensor])
output = upt([image_tensor])

visualise_entire_image(image, output[0], actions, args.action, args.action_score_thresh)

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler

from detector import build_detector
from upt import build_detector
from utils import custom_collate, CustomisedDLE, DataFactory

warnings.filterwarnings("ignore")
Expand Down
2 changes: 1 addition & 1 deletion detector.py → upt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Human-object interaction detector
Unary-pairwise transformer for human-object interaction detection
Fred Zhang <[email protected]>
Expand Down

0 comments on commit 38cb045

Please sign in to comment.