diff --git a/omniglot/task_generator.py b/omniglot/task_generator.py index f0945d0d..82b89886 100644 --- a/omniglot/task_generator.py +++ b/omniglot/task_generator.py @@ -138,8 +138,8 @@ def __len__(self): def get_data_loader(task, num_per_class=1, split='train',shuffle=True,rotation=0): # NOTE: batch size here is # instances PER CLASS - normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]) - + # normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]) + normalize=transforms.Normalize(mean=[0.92206],std=[0.8426]) dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation),transforms.ToTensor(),normalize])) if split == 'train':