Merge branch 'pytorch' into dsprites

Conflicts:
    notebooks/visualize_model_weights.ipynb
    train_model.py
    utils/dataset_utils.py
Showing 62 changed files with 6,458 additions and 1,764 deletions.
@@ -1,3 +1,6 @@
*.npz
*.pkl

# Memristor data
memristor_data/
@@ -0,0 +1,123 @@
import os
import sys

# Walk up from the current directory until we are above the DeepSparseCoding repo root
ROOT_DIR = os.getcwd()
while 'DeepSparseCoding' in ROOT_DIR:
    ROOT_DIR = os.path.dirname(ROOT_DIR)
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import numpy as np
import proplot as plot
import torch

from DeepSparseCoding.utils.file_utils import Logger
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
import DeepSparseCoding.utils.plot_functions as pf

import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa
log_files = [
    os.path.join(*[ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log']),
    os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log'])
]

cp_latest_filenames = [
    os.path.join(*[ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'checkpoints', 'mlp_768_mnist_latest_checkpoint_v0.pt']),
    os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints', 'lca_768_mlp_mnist_latest_checkpoint_v0.pt'])
]

attack_params = {
    'linfPGD': {
        'abs_stepsize': 0.01,
        'steps': 5000
    }
}

attacks = [
    #fa.FGSM(),
    fa.LinfPGD(**attack_params['linfPGD']),
    #fa.LinfBasicIterativeAttack(),
    #fa.LinfAdditiveUniformNoiseAttack(),
    #fa.LinfDeepFoolAttack(),
]

epsilons = [  # allowed perturbation size
    0.0,
    0.05,
    0.1,
    0.15,
    0.2,
    0.25,
    0.3,
    0.35,
    #0.4,
    0.5,
    #0.8,
    1.0
]
num_models = len(log_files)
for model_index in range(num_models):
    logger = Logger(log_files[model_index], overwrite=False)
    log_text = logger.load_file()
    params = logger.read_params(log_text)[-1]
    params.cp_latest_filename = cp_latest_filenames[model_index]
    train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
    for key, value in data_params.items():
        setattr(params, key, value)
    model = loaders.load_model(params.model_type)
    model.setup(params, logger)
    model.params.analysis_out_dir = os.path.join(
        *[model.params.model_out_dir, 'analysis', model.params.version])
    model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
    if not os.path.exists(model.params.analysis_save_dir):
        os.makedirs(model.params.analysis_save_dir)
    model.to(params.device)
    model.load_checkpoint()
    fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
    print('\n', '~' * 79)
    num_batches = len(test_loader.dataset) // model.params.batch_size
    attack_success = np.zeros(
        (len(attacks), len(epsilons), num_batches, model.params.batch_size), dtype=bool)  # np.bool is deprecated
    for batch_index, (data, target) in enumerate(test_loader):
        data = model.preprocess_data(data.to(model.params.device))
        target = target.to(model.params.device)
        images, labels = ep.astensors(*(data, target))
        del data; del target
        print(f'Model type: {model.params.model_type} [{model_index+1} out of {len(log_files)}]')
        print(f'Batch {batch_index+1} out of {num_batches}')
        print(f'accuracy {accuracy(fmodel, images, labels)}')
        for attack_index, attack in enumerate(attacks):
            # the attack returns raw adversarials, clipped adversarial inputs, and a
            # boolean success mask of shape [num_epsilons, batch_size]
            advs, inputs, success = attack(fmodel, images, labels, epsilons=epsilons)
            assert success.shape == (len(epsilons), len(images))
            success_ = success.numpy()
            assert success_.dtype == bool
            attack_success[attack_index, :, batch_index, :] = success_
            print('\n', attack)
            print(' ', 1.0 - success_.mean(axis=-1).round(2))
            np.savez('tmp_perturbations.npz', data=advs[0].numpy())
            np.savez('tmp_images.npz', data=images.numpy())
            np.savez('tmp_inputs.npz', data=inputs[0].numpy())
            #import IPython; IPython.embed(); raise SystemExit  # debug breakpoint; disabled so the full analysis completes
        robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1)
        print('\n', '-' * 79, '\n')
        print('worst case (best attack per-sample)')
        print(' ', robust_accuracy.round(2))
        print('-' * 79)
    attack_success = attack_success.reshape(
        (len(attacks), len(epsilons), num_batches*model.params.batch_size))
    attack_types = [str(type(attack)).split('.')[-1][:-2] for attack in attacks]
    output_filename = os.path.join(model.params.analysis_save_dir,
        'linf_adversarial_analysis.npz')
    out_dict = {
        'adversarial_analysis': attack_success,
        'attack_types': attack_types,
        'epsilons': epsilons,
        'attack_params': attack_params}
    np.savez(output_filename, data=out_dict)
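
Not part of the commit, but for reference: a minimal sketch of reading the saved analysis back. Because np.savez is handed a Python dict, NumPy stores it as a zero-dimensional object array, so loading requires allow_pickle=True and unpacking with .item(). The path below is an assumption; substitute the file written to analysis_save_dir above.

import numpy as np

# hypothetical path; point this at the actual analysis_save_dir from the run
analysis_file = 'linf_adversarial_analysis.npz'
analysis = np.load(analysis_file, allow_pickle=True)['data'].item()

attack_success = analysis['adversarial_analysis']  # [num_attacks, num_epsilons, num_examples]
epsilons = analysis['epsilons']

# robust accuracy per epsilon, counting the most successful attack per example
worst_case_success = attack_success.max(axis=0)
robust_accuracy = 1.0 - worst_case_success.mean(axis=-1)
for eps, acc in zip(epsilons, robust_accuracy):
    print(f'epsilon={eps:.2f}  robust accuracy={acc:.2f}')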
@@ -0,0 +1,106 @@
import os
import sys
import warnings  # needed for the deprecation warnings below

import numpy as np
from scipy.stats import norm
from PIL import Image
import torch
import torchvision

ROOT_DIR = os.path.dirname(os.getcwd())
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import DeepSparseCoding.utils.data_processing as dp
class SyntheticImages(torchvision.datasets.vision.VisionDataset):
    """Synthetic dataset of square images with pixel values drawn from a specified distribution
    Inputs:
        epoch_size [int] number of datapoints in the dataset
        data_edge_size [int] length of the edge of a square datapoint
        dist_type [str] one of {'gaussian', 'laplacian'}
            ('hierarchical_sparse' is intended but not yet supported by generate_synthetic_data)
        rand_state [np.random.RandomState] a numpy random state to generate data from
        num_classes [int, optional] number of classes for random supervised labels
        transform [callable, optional] a function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.RandomCrop``
        target_transform [callable, optional] a function/transform that takes in the
            target and transforms it
    """

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data
||
def __init__(self, epoch_size, data_edge_size, dist_type, rand_state, num_classes=None, | ||
transform=None, target_transform=None): | ||
root = './' # no need for a root directory because the data is never on disc | ||
if(target_transform): | ||
assert num_classes is not None, ( | ||
'Num classes must be specified if target_transform is not None.') | ||
super(SyntheticImages, self).__init__(root, transform=transform, | ||
target_transform=target_transform) # transforms get set to member variables | ||
self.data = torch.tensor( | ||
self.generate_synthetic_data(epoch_size, data_edge_size, dist_type, rand_state)) | ||
if(num_classes): | ||
self.targets = self.generate_labels(epoch_size, num_classes, rand_state) | ||
else: | ||
self.targets = -1 * torch.ones(len(self.data)) | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Inputs: | ||
index (int): Index | ||
Outputs: | ||
tuple: (image, target) where target is index of the target class. | ||
""" | ||
img = self.data[index] | ||
target = self.targets[index] | ||
# doing this so that it is consistent with all other datasets | ||
# to return a PIL Image | ||
img = Image.fromarray(np.squeeze(img.numpy()), mode='L') | ||
if self.transform is not None: | ||
img = self.transform(img) | ||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
return img, target | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def generate_synthetic_data(self, epoch_size, data_edge_size, dist_type, rand_state): | ||
""" | ||
Function for generating synthetic data of shape [epoch_size, num_edge, num_edge] | ||
Inputs: | ||
dist_type [str] one of {'gaussian', 'laplacian'}, | ||
otherwise returns zeros | ||
epoch_size [int] number of datapoints in an epoch | ||
data_edge_size [int] size of the edge of the square synthetic image | ||
""" | ||
data_shape = (epoch_size, data_edge_size, data_edge_size, 1) | ||
if dist_type == 'gaussian': | ||
data = rand_state.normal(loc=0.0, scale=1.0, size=data_shape) | ||
elif dist_type == 'laplacian': | ||
data = rand_state.laplace(loc=0.0, scale=1.0, size=data_shape) | ||
else: | ||
assert False, (f'Data dist_type must be "gaussian" or "laplace", not {dist_type}') | ||
return data | ||
|
||
def generate_labels(self, epoch_size, num_classes, rand_state): | ||
labels = torch.tensor(rand_state.randint(num_classes, size=epoch_size)) | ||
#labels = dp.dense_to_one_hot(labels, num_classes) | ||
return labels |
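
Again not from the commit: a minimal usage sketch, assuming the class is importable as defined above. It builds the dataset from a seeded NumPy random state and inspects the generated tensors directly (indexing individual items goes through the PIL 'L'-mode conversion noted above, which presumes 8-bit data).

import numpy as np

rand_state = np.random.RandomState(seed=0)

# 100 Gaussian-noise images of edge length 16 with 10 random class labels
dataset = SyntheticImages(
    epoch_size=100,
    data_edge_size=16,
    dist_type='gaussian',
    rand_state=rand_state,
    num_classes=10)

print(len(dataset))           # 100
print(dataset.data.shape)     # torch.Size([100, 16, 16, 1])
print(dataset.targets.shape)  # torch.Size([100])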