
Add: can save labels as json
nik1806 committed Jul 17, 2021
1 parent 45b73a4 commit d24b6b2
Showing 1 changed file with 28 additions and 5 deletions.
eval_model.py: 28 additions & 5 deletions

@@ -11,8 +11,12 @@
 from torchvision.utils import save_image
 from datetime import datetime
 
+import json
+
 # loader = torch.utils.data.DataLoader(data, batch_size=500)
 
+noisy_labels = []  # global list
+
 def save_images(imgs, pred, save_dir):
     for i in range(len(imgs)):
         img = imgs[i]
@@ -21,8 +25,15 @@ def save_images(imgs, pred, save_dir):
         n = now.strftime("%H_%M_%S_%f")
         save_image(img, save_dir+'/' + l+'/'+n+'.png')  # maybe scaling is a problem
 
+def add_to_json(pred):
+    # collect noisy labels; get_accuracy dumps them to a .json file
+    for i in range(len(pred)):
+        l = int(pred[i].cpu()[0].numpy())
+        global noisy_labels
+        noisy_labels.append(l)
+
 
-def get_accuracy(model, dataloader, save_image=False, save_dir=None):
+def get_accuracy(model, dataloader, save_image=False, save_json=False, save_dir=None):
 
     correct, total = 0, 0
     for xs, ts in dataloader:
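
Editor's note (not part of the commit): add_to_json converts labels one element at a time with int(pred[i].cpu()[0].numpy()), which suggests pred has shape (batch, 1); the line that computes pred sits outside the shown hunks, so that shape is an assumption. Under that assumption, a batched sketch of the helper could be:

    import torch

    noisy_labels = []  # mirrors the module-level list in the diff

    def add_to_json(pred):
        # flatten (batch, 1) -> (batch,), copy to host once, extend with Python ints
        global noisy_labels
        noisy_labels.extend(pred.view(-1).cpu().tolist())

    # e.g. add_to_json(torch.tensor([[3], [7]])) leaves noisy_labels == [3, 7]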
@@ -33,10 +44,20 @@ def get_accuracy(model, dataloader, save_image=False, save_dir=None):
 
         if save_image:
             save_images(xs, pred, save_dir)
+
+        if save_json:  # adding to the list
+            add_to_json(pred)
 
         correct += pred.eq(ts.view_as(pred)).sum().item()
         total += int(ts.shape[0])
 
+    # dumping to json file
+    global noisy_labels
+    jsonlabels = json.dumps(noisy_labels)
+    jsonFile = open(save_dir+'.json', "w")
+    jsonFile.write(jsonlabels)
+    jsonFile.close()
+
     return correct / total
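
Editor's note (a sketch, not the committed code): the dump block above runs unconditionally, so calling get_accuracy with save_json=False and save_dir=None raises a TypeError at save_dir+'.json'; the global statement is also unnecessary for a read-only use of noisy_labels. A guarded variant, using a hypothetical helper name dump_labels:

    import json

    def dump_labels(labels, save_dir):
        # a context manager closes the file even if the write fails
        with open(save_dir + '.json', 'w') as f:
            json.dump(labels, f)

    # inside get_accuracy, after the batch loop:
    #     if save_json:
    #         dump_labels(noisy_labels, save_dir)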


@@ -74,8 +95,10 @@ def get_accuracy(model, dataloader, save_image=False, save_dir=None):
     parser.add_argument('--source', type=str, default='svhn', metavar='N',
                         help='source dataset')
     parser.add_argument('--target', type=str, default='mnist', metavar='N', help='target dataset')
-    parser.add_argument('--save_infer', action='store_true', default=False,
-                        help='save dataset or not after inference with new labels')
+    parser.add_argument('--save_img', action='store_true', default=False,
+                        help='save dataset after inference with new labels')
+    parser.add_argument('--save_json', action='store_true', default=False,
+                        help='save labels after inference in json file')
     # parser.add_argument('--use_abs_diff', action='store_true', default=False,
     #                     help='use absolute difference value as a measurement')

@@ -116,7 +139,7 @@ def get_accuracy(model, dataloader, save_image=False, save_dir=None):
 
 
     # create directory to save dataset after inference
-    if args.save_infer:
+    if args.save_img:
         print("Creating new labeled data directory!!")
         if not os.path.exists(source+target):
             os.mkdir(source+target)
@@ -137,6 +160,6 @@ def get_accuracy(model, dataloader, save_image=False, save_dir=None):
     #
     model.load_state_dict(torch.load(f'{args.checkpoint_dir}/{source+target}_epoch_{args.load_epoch}.pth'))
 
-    accuracy = get_accuracy(model.cuda(), train_loader, args.save_infer, source+target)
+    accuracy = get_accuracy(model.cuda(), train_loader, args.save_img, args.save_json, source+target)
 
     print(f"Accuracy (on {target})= {accuracy}")
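
Editor's note (usage assumed from the diff; --checkpoint_dir and --load_epoch are referenced above but defined outside the shown hunks): with the default svhn -> mnist pair, the new flag would presumably be invoked as

    python eval_model.py --save_json --load_epoch <epoch>

and, since get_accuracy receives save_dir=source+target, the labels land in svhnmnist.json, a flat JSON list readable with:

    import json

    with open('svhnmnist.json') as f:  # source+target under the default arguments
        labels = json.load(f)          # one predicted int label per sample in the loader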
