diff --git a/.gitignore b/.gitignore index ee0c297d..96e7c120 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ **/__pycache__ data -checkpoint runs \ No newline at end of file diff --git a/README.md b/README.md index bcbbfdb4..5b02ff57 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,40 @@ # Pytorch-cifar100 -practice on cifar100 using pytorch +In this work, I use a novel approach to address GPU memory constraints in the implementation of the ResNet50 model for image classification on the CIFAR-100 dataset. By employing the Group Pruner technique, I selectively prune redundant filters in ResNet50, significantly reducing its memory usage while preserving comparable accuracy. Additionally, I leverage knowledge distillation to transfer the knowledge from the base model (ResNet50) to a smaller ResNet18 model, enabling us to obtain a memory-efficient architecture without sacrificing performance. My experimental results demonstrate that the combination of Group Pruner and knowledge distillation provides an effective solution to the challenges of deep CNN architecture deployment on resource-limited platforms, making it feasible to utilize sophisticated models like ResNet50 in real-world applications with memory constraints. -## Requirements +## Results +The result I can get from a certain model, since I use the same hyperparameters to train all the networks, some networks might not get the best result from these hyperparameters, you could try yourself by finetuning the hyperparameters to get +better result. -This is my experiment eviroument -- python3.6 -- pytorch1.6.0+cu101 -- tensorboard 2.2.2(optional) +|Dataset|Network|Params|Inference time (ms)|Runtime Memory on CPU (ms)||Runtime Memory on CUDA (ms)|Top-1 Error|Top-5 Error|Checkpoint path| +|:-----:|:-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +|Cifar100|Resnet50 (base model)|23705252|167|15.933|6.481|21.24|5.33|?| +|Cifar100|Resnet50 (pruned model)|14959153|146|15.538|5.456|27.04|8.05|?| +|Cifar100|Resnet18 (KD from Resnet50)|11220132|61|7.019|2.735|23.95|7.01|?| +## Requirements + +This is my experiment eviroment +- Ubuntu 20.04.6 LTS +- Python 3.8 +- PyTorch 1.13.1+cu117 +- CuDNN 8500 + ## Usage -### 1. enter directory +### 1. Enter directory ```bash $ cd pytorch-cifar100 ``` -### 2. dataset -I will use cifar100 dataset from torchvision since it's more convenient, but I also -kept the sample code for writing your own dataset module in dataset folder, as an -example for people don't know how to write it. +### 2. Dataset +I will use cifar100 dataset from torchvision since it's more convenient, but I also kept the sample code for writing your own dataset module in dataset folder, as an example for people don't know how to write it. -### 3. run tensorbard(optional) -Install tensorboard -```bash -$ pip install tensorboard -$ mkdir runs -Run tensorboard -$ tensorboard --logdir='runs' --port=6006 --host='localhost' -``` - -### 4. train the model +### 3. Train the model You need to specify the net you want to train using arg -net - ```bash -# use gpu to train vgg16 -$ python train.py -net vgg16 -gpu +$ python train.py -net resnet50 -gpu ``` sometimes, you might want to use warmup training by set ```-warm``` to 1 or 2, to prevent network @@ -91,94 +89,44 @@ stochasticdepth101 Normally, the weights file with the best accuracy would be written to the disk with name suffix 'best'(default in checkpoint folder). -### 5. test the model +### 4. Test the model Test the model using test.py ```bash -$ python test.py -net vgg16 -weights path_to_vgg16_weights_file +$ python test.py -net resnet50 -weights path_to_resnet50_weights_file +``` +For this work, you can easily test the model by the trained base resnet50 model by using +```bash +$ python test.py -net resnet50 -weights checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth ``` -## Implementated NetWork - -- vgg [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556v6) -- googlenet [Going Deeper with Convolutions](https://arxiv.org/abs/1409.4842v1) -- inceptionv3 [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567v3) -- inceptionv4, inception_resnet_v2 [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](https://arxiv.org/abs/1602.07261) -- xception [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357) -- resnet [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385v1) -- resnext [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/abs/1611.05431v2) -- resnet in resnet [Resnet in Resnet: Generalizing Residual Architectures](https://arxiv.org/abs/1603.08029v1) -- densenet [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993v5) -- shufflenet [ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices](https://arxiv.org/abs/1707.01083v2) -- shufflenetv2 [ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design](https://arxiv.org/abs/1807.11164v1) -- mobilenet [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) -- mobilenetv2 [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) -- residual attention network [Residual Attention Network for Image Classification](https://arxiv.org/abs/1704.06904) -- senet [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507) -- squeezenet [SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size](https://arxiv.org/abs/1602.07360v4) -- nasnet [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012v4) -- wide residual network[Wide Residual Networks](https://arxiv.org/abs/1605.07146) -- stochastic depth networks[Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382) - -## Training Details -I didn't use any training tricks to improve accuray, if you want to learn more about training tricks, -please refer to my another [repo](https://github.com/weiaicunzai/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks), contains -various common training tricks and their pytorch implementations. - - -I follow the hyperparameter settings in paper [Improved Regularization of Convolutional Neural Networks with Cutout](https://arxiv.org/abs/1708.04552v2), which is init lr = 0.1 divide by 5 at 60th, 120th, 160th epochs, train for 200 -epochs with batchsize 128 and weight decay 5e-4, Nesterov momentum of 0.9. You could also use the hyperparameters from paper [Regularizing Neural Networks by Penalizing Confident Output Distributions](https://arxiv.org/abs/1701.06548v1) and [Random Erasing Data Augmentation](https://arxiv.org/abs/1708.04896v2), which is initial lr = 0.1, lr divied by 10 at 150th and 225th epochs, and training for 300 epochs with batchsize 128, this is more commonly used. You could decrese the batchsize to 64 or whatever suits you, if you dont have enough gpu memory. - -You can choose whether to use TensorBoard to visualize your training procedure -## Results -The result I can get from a certain model, since I use the same hyperparameters to train all the networks, some networks might not get the best result from these hyperparameters, you could try yourself by finetuning the hyperparameters to get -better result. - -|dataset|network|params|top1 err|top5 err|epoch(lr = 0.1)|epoch(lr = 0.02)|epoch(lr = 0.004)|epoch(lr = 0.0008)|total epoch| -|:-----:|:-----:|:----:|:------:|:------:|:-------------:|:--------------:|:---------------:|:----------------:|:---------:| -|cifar100|mobilenet|3.3M|34.02|10.56|60|60|40|40|200| -|cifar100|mobilenetv2|2.36M|31.92|09.02|60|60|40|40|200| -|cifar100|squeezenet|0.78M|30.59|8.36|60|60|40|40|200| -|cifar100|shufflenet|1.0M|29.94|8.35|60|60|40|40|200| -|cifar100|shufflenetv2|1.3M|30.49|8.49|60|60|40|40|200| -|cifar100|vgg11_bn|28.5M|31.36|11.85|60|60|40|40|200| -|cifar100|vgg13_bn|28.7M|28.00|9.71|60|60|40|40|200| -|cifar100|vgg16_bn|34.0M|27.07|8.84|60|60|40|40|200| -|cifar100|vgg19_bn|39.0M|27.77|8.84|60|60|40|40|200| -|cifar100|resnet18|11.2M|24.39|6.95|60|60|40|40|200| -|cifar100|resnet34|21.3M|23.24|6.63|60|60|40|40|200| -|cifar100|resnet50|23.7M|22.61|6.04|60|60|40|40|200| -|cifar100|resnet101|42.7M|22.22|5.61|60|60|40|40|200| -|cifar100|resnet152|58.3M|22.31|5.81|60|60|40|40|200| -|cifar100|preactresnet18|11.3M|27.08|8.53|60|60|40|40|200| -|cifar100|preactresnet34|21.5M|24.79|7.68|60|60|40|40|200| -|cifar100|preactresnet50|23.9M|25.73|8.15|60|60|40|40|200| -|cifar100|preactresnet101|42.9M|24.84|7.83|60|60|40|40|200| -|cifar100|preactresnet152|58.6M|22.71|6.62|60|60|40|40|200| -|cifar100|resnext50|14.8M|22.23|6.00|60|60|40|40|200| -|cifar100|resnext101|25.3M|22.22|5.99|60|60|40|40|200| -|cifar100|resnext152|33.3M|22.40|5.58|60|60|40|40|200| -|cifar100|attention59|55.7M|33.75|12.90|60|60|40|40|200| -|cifar100|attention92|102.5M|36.52|11.47|60|60|40|40|200| -|cifar100|densenet121|7.0M|22.99|6.45|60|60|40|40|200| -|cifar100|densenet161|26M|21.56|6.04|60|60|60|40|200| -|cifar100|densenet201|18M|21.46|5.9|60|60|40|40|200| -|cifar100|googlenet|6.2M|21.97|5.94|60|60|40|40|200| -|cifar100|inceptionv3|22.3M|22.81|6.39|60|60|40|40|200| -|cifar100|inceptionv4|41.3M|24.14|6.90|60|60|40|40|200| -|cifar100|inceptionresnetv2|65.4M|27.51|9.11|60|60|40|40|200| -|cifar100|xception|21.0M|25.07|7.32|60|60|40|40|200| -|cifar100|seresnet18|11.4M|23.56|6.68|60|60|40|40|200| -|cifar100|seresnet34|21.6M|22.07|6.12|60|60|40|40|200| -|cifar100|seresnet50|26.5M|21.42|5.58|60|60|40|40|200| -|cifar100|seresnet101|47.7M|20.98|5.41|60|60|40|40|200| -|cifar100|seresnet152|66.2M|20.66|5.19|60|60|40|40|200| -|cifar100|nasnet|5.2M|22.71|5.91|60|60|40|40|200| -|cifar100|wideresnet-40-10|55.9M|21.25|5.77|60|60|40|40|200| -|cifar100|stochasticdepth18|11.22M|31.40|8.84|60|60|40|40|200| -|cifar100|stochasticdepth34|21.36M|27.72|7.32|60|60|40|40|200| -|cifar100|stochasticdepth50|23.71M|23.35|5.76|60|60|40|40|200| -|cifar100|stochasticdepth101|42.69M|21.28|5.39|60|60|40|40|200| +### 5. Train the pruned model +Train the pruned model using train_prune.py. Considering an exist resnet50_weights_file is sparse model +```bash +$ python train_prune.py -net resnet50 -sl_weights path_to_sparse_resnet50_weights_file +``` +For this work, you can train the model by the sparse resnet50 model by using +```bash +$ python train_prune.py -net resnet50 -sl_weights checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth +``` +### 6. Test the pruned model +Train the pruned model using train.py. Considering an exist resnet50_weights_file is sparse model +```bash +$ python test_prune.py -net resnet50 -sl_weights path_to_sparse_resnet50_weights_file -weights path_to_optimized_resnet50_weights_file +``` +For this work, you can easily test the model by the trained pruned resnet50 model by using +```bash +$ python test_prune.py -net resnet50 -sl_weights checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth -weights checkpoint/resnet50/Sunday_30_July_2023_11h_34m_52s/resnet50-200-regular.pth +``` +### 7. Train the KD model +Train the pruned model using train_prune.py. Considering an exist resnet50_weights_file is sparse model +```bash +$ python train_KD.py -net-teacher resnet50 -teacher-weights path_to_teacher_resnet50_weights_file -net-student resnet18 +``` +For this work, you can train the model by the teacher resnet50 model by using +```bash +$ python train_KD.py -net-teacher resnet50 -teacher-weights checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth -net-student resnet18 +``` diff --git a/checkpoint/resnet18/Sunday_30_July_2023_14h_29m_06s/resnet18-200-regular.pth b/checkpoint/resnet18/Sunday_30_July_2023_14h_29m_06s/resnet18-200-regular.pth new file mode 100644 index 00000000..733a51e0 Binary files /dev/null and b/checkpoint/resnet18/Sunday_30_July_2023_14h_29m_06s/resnet18-200-regular.pth differ diff --git a/checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth b/checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth new file mode 100644 index 00000000..3787ee73 Binary files /dev/null and b/checkpoint/resnet50/Saturday_29_July_2023_05h_48m_46s/resnet50-200-regular.pth differ diff --git a/checkpoint/resnet50/Sunday_30_July_2023_11h_34m_52s/resnet50-200-regular.pth b/checkpoint/resnet50/Sunday_30_July_2023_11h_34m_52s/resnet50-200-regular.pth new file mode 100644 index 00000000..98ce5422 Binary files /dev/null and b/checkpoint/resnet50/Sunday_30_July_2023_11h_34m_52s/resnet50-200-regular.pth differ diff --git a/test.py b/test.py index dab61a05..66cd8672 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ author baiyu """ - +import numpy as np import argparse from matplotlib import pyplot as plt @@ -18,6 +18,8 @@ from conf import settings from utils import get_network, get_test_dataloader +from torch.profiler import profile, record_function, ProfilerActivity + if __name__ == '__main__': @@ -46,6 +48,10 @@ correct_5 = 0.0 total = 0 + starter, ender = starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = len(cifar100_test_loader) + timings=np.zeros((repetitions,1)) + with torch.no_grad(): for n_iter, (image, label) in enumerate(cifar100_test_loader): print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader))) @@ -53,11 +59,17 @@ if args.gpu: image = image.cuda() label = label.cuda() - print('GPU INFO.....') - print(torch.cuda.memory_summary(), end='') - - - output = net(image) + # print('GPU INFO.....') + # print(torch.cuda.memory_summary(), end='') + + starter.record() + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: + output = net(image) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[n_iter] = curr_time _, pred = output.topk(5, 1, largest=True, sorted=True) label = label.view(label.size(0), -1).expand_as(pred) @@ -74,6 +86,8 @@ print(torch.cuda.memory_summary(), end='') print() + print("Average inference time (ms)/image: ", np.sum(timings) / repetitions) print("Top 1 err: ", 1 - correct_1 / len(cifar100_test_loader.dataset)) print("Top 5 err: ", 1 - correct_5 / len(cifar100_test_loader.dataset)) print("Parameter numbers: {}".format(sum(p.numel() for p in net.parameters()))) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) \ No newline at end of file diff --git a/test_prune.py b/test_prune.py new file mode 100644 index 00000000..85cd4858 --- /dev/null +++ b/test_prune.py @@ -0,0 +1,140 @@ +#test.py +#!/usr/bin/env python3 + +""" test neuron network performace +print top1 and top5 err on test dataset +of a model + +author baiyu +""" +import numpy as np +import argparse + +from matplotlib import pyplot as plt + +import torch +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from conf import settings +from utils import get_network, get_test_dataloader +from torch.profiler import profile, record_function, ProfilerActivity +import torch_pruning as tp +from train_prune import progressive_pruning +from functools import partial + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-net', type=str, required=True, help='net type') + parser.add_argument('-sl_weights', type=str, required=True, help='the weights file of the pretrained sparsity model') + parser.add_argument('-weights', type=str, required=True, help='the weights file you want to test') + parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not') + parser.add_argument('-b', type=int, default=16, help='batch size for dataloader') + args = parser.parse_args() + + net = get_network(args) + + cifar100_test_loader = get_test_dataloader( + settings.CIFAR100_TRAIN_MEAN, + settings.CIFAR100_TRAIN_STD, + #settings.CIFAR100_PATH, + num_workers=4, + batch_size=args.b, + ) + net.load_state_dict(torch.load(args.sl_weights)) + net.eval() + + correct_1 = 0.0 + correct_5 = 0.0 + total = 0 + + starter, ender = starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = len(cifar100_test_loader) + timings=np.zeros((repetitions,1)) + + for inputs in cifar100_test_loader: + example_inputs, _= inputs + if args.gpu: + example_inputs = example_inputs.to('cuda') + break + + # Pruning + # 1. Pruning + ignored_layers = [] + for m in net.modules(): + if isinstance(m, torch.nn.Linear) and m.out_features == 100: + ignored_layers.append(m) + elif isinstance(m, torch.nn.modules.conv._ConvNd) and m.out_channels == 100: + ignored_layers.append(m) + + for inputs in cifar100_test_loader: + example_inputs, _= inputs + if args.gpu: + example_inputs = example_inputs.to('cuda') + break + + imp = tp.importance.GroupNormImportance(p=2) + pruner_entry = partial(tp.pruner.GroupNormPruner, reg=5e-4, global_pruning=True) + pruner = pruner_entry( + net, + example_inputs, + importance=imp, + iterative_steps=200, + ch_sparsity=1.0, + ch_sparsity_dict={}, + max_ch_sparsity=1.0, + ignored_layers=ignored_layers, + unwrapped_parameters=[], + ) + + print("Pruning...") + progressive_pruning(pruner, net, speed_up=2.11, example_inputs=example_inputs) + del pruner # remove reference + pruned_ops, pruned_size = tp.utils.count_ops_and_params(net, example_inputs=example_inputs) + + net.load_state_dict(torch.load(args.weights)) + print(net) + + with torch.no_grad(): + for n_iter, (image, label) in enumerate(cifar100_test_loader): + print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(cifar100_test_loader))) + + if args.gpu: + image = image.cuda() + label = label.cuda() + # print('GPU INFO.....') + # print(torch.cuda.memory_summary(), end='') + + starter.record() + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof: + output = net(image) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[n_iter] = curr_time + _, pred = output.topk(5, 1, largest=True, sorted=True) + + label = label.view(label.size(0), -1).expand_as(pred) + correct = pred.eq(label).float() + + #compute top 5 + correct_5 += correct[:, :5].sum() + + #compute top1 + correct_1 += correct[:, :1].sum() + + if args.gpu: + print('GPU INFO.....') + print(torch.cuda.memory_summary(), end='') + + print() + + print("Average inference time (ms)/image: ", np.sum(timings) / repetitions) + print("Top 1 err: ", 1 - correct_1 / len(cifar100_test_loader.dataset)) + print("Top 5 err: ", 1 - correct_5 / len(cifar100_test_loader.dataset)) + print("Parameter numbers: ", pruned_size) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + diff --git a/train.py b/train.py index c5034606..acf110ba 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,6 @@ import torchvision.transforms as transforms from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter from conf import settings from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \ @@ -45,11 +44,6 @@ def train(epoch): n_iter = (epoch - 1) * len(cifar100_training_loader) + batch_index + 1 last_layer = list(net.children())[-1] - for name, para in last_layer.named_parameters(): - if 'weight' in name: - writer.add_scalar('LastLayerGradients/grad_norm2_weights', para.grad.norm(), n_iter) - if 'bias' in name: - writer.add_scalar('LastLayerGradients/grad_norm2_bias', para.grad.norm(), n_iter) print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format( loss.item(), @@ -59,16 +53,12 @@ def train(epoch): total_samples=len(cifar100_training_loader.dataset) )) - #update training loss for each iteration - writer.add_scalar('Train/loss', loss.item(), n_iter) - if epoch <= args.warm: warmup_scheduler.step() for name, param in net.named_parameters(): layer, attr = os.path.splitext(name) attr = attr[1:] - writer.add_histogram("{}/{}".format(layer, attr), param, epoch) finish = time.time() @@ -109,10 +99,6 @@ def eval_training(epoch=0, tb=True): )) print() - #add informations to tensorboard - if tb: - writer.add_scalar('Test/Average loss', test_loss / len(cifar100_test_loader.dataset), epoch) - writer.add_scalar('Test/Accuracy', correct.float() / len(cifar100_test_loader.dataset), epoch) return correct.float() / len(cifar100_test_loader.dataset) @@ -166,14 +152,10 @@ def eval_training(epoch=0, tb=True): if not os.path.exists(settings.LOG_DIR): os.mkdir(settings.LOG_DIR) - #since tensorboard can't overwrite old values - #so the only way is to create a new tensorboard log - writer = SummaryWriter(log_dir=os.path.join( - settings.LOG_DIR, args.net, settings.TIME_NOW)) + input_tensor = torch.Tensor(1, 3, 32, 32) if args.gpu: input_tensor = input_tensor.cuda() - writer.add_graph(net, input_tensor) #create checkpoint folder to save model if not os.path.exists(checkpoint_path): @@ -224,5 +206,3 @@ def eval_training(epoch=0, tb=True): weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular') print('saving weights file to {}'.format(weights_path)) torch.save(net.state_dict(), weights_path) - - writer.close() diff --git a/train_KD.py b/train_KD.py new file mode 100644 index 00000000..07c7579b --- /dev/null +++ b/train_KD.py @@ -0,0 +1,217 @@ +# train.py +#!/usr/bin/env python3 + +""" train network using pytorch + +author baiyu +""" + +import os +import sys +import argparse +import time +from datetime import datetime + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms + +from torch.utils.data import DataLoader + +from conf import settings +from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \ + most_recent_folder, most_recent_weights, last_epoch, best_acc_weights + +def train(epoch, alpha=0.5): + start = time.time() + net_student.train() + for batch_index, (images, labels) in enumerate(cifar100_training_loader): + + if args.gpu: + labels = labels.cuda() + images = images.cuda() + + optimizer_student.zero_grad() + outputs_student = net_student(images) + outputs_teacher = net_teacher(images).detach() # Detach teacher outputs from the computation graph + outputs_teacher = outputs_teacher.argmax(dim=1) + # Compute knowledge distillation loss + loss_student = loss_function(outputs_student, labels) + loss_teacher = loss_function(outputs_student, outputs_teacher) + loss = (1 - alpha) * loss_student + alpha * loss_teacher + + loss.backward() + optimizer_student.step() + + print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format( + loss.item(), + optimizer_student.param_groups[0]['lr'], + epoch=epoch, + trained_samples=batch_index * args.b + len(images), + total_samples=len(cifar100_training_loader.dataset) + )) + + if epoch <= args.warm: + warmup_scheduler_student.step() + + finish = time.time() + print('epoch {} training time consumed: {:.2f}s'.format(epoch, finish - start)) + +@torch.no_grad() +def eval_training(epoch=0, tb=True): + + start = time.time() + net_student.eval() + + test_loss = 0.0 # cost function error + correct = 0.0 + + for (images, labels) in cifar100_test_loader: + + if args.gpu: + images = images.cuda() + labels = labels.cuda() + + + outputs = net_student(images) + loss = loss_function(outputs, labels) + + test_loss += loss.item() + _, preds = outputs.max(1) + correct += preds.eq(labels).sum() + + finish = time.time() + if args.gpu: + print('GPU INFO.....') + print(torch.cuda.memory_summary(), end='') + print('Evaluating Network.....') + print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format( + epoch, + test_loss / len(cifar100_test_loader.dataset), + correct.float() / len(cifar100_test_loader.dataset), + finish - start + )) + print() + + + return correct.float() / len(cifar100_test_loader.dataset) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-net-teacher', type=str, required=True, help='net type') + parser.add_argument('-net-student', type=str, required=True, help='net type') + parser.add_argument('-teacher-weights',type=str, required=True, help='the weights file of the trained model') + parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not') + parser.add_argument('-b', type=int, default=128, help='batch size for dataloader') + parser.add_argument('-warm', type=int, default=1, help='warm up training phase') + parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate') + parser.add_argument('-resume', action='store_true', default=False, help='resume training') + args = parser.parse_args() + + # net = get_network(args) + if args.net_student == 'resnet18': + from models.resnet import resnet18 + net_student = resnet18() + + if args.net_teacher == 'resnet50': + from models.resnet import resnet50 + net_teacher = resnet50() + + #data preprocessing: + cifar100_training_loader = get_training_dataloader( + settings.CIFAR100_TRAIN_MEAN, + settings.CIFAR100_TRAIN_STD, + num_workers=4, + batch_size=args.b, + shuffle=True + ) + + cifar100_test_loader = get_test_dataloader( + settings.CIFAR100_TRAIN_MEAN, + settings.CIFAR100_TRAIN_STD, + num_workers=4, + batch_size=args.b, + shuffle=True + ) + device = torch.device("cuda" if args.gpu else "cpu") + net_teacher.load_state_dict(torch.load(args.teacher_weights, map_location=device)) + net_student.to(device) + net_teacher.to(device) + loss_function = nn.CrossEntropyLoss() + optimizer_student = optim.SGD(net_student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) + train_scheduler_student = optim.lr_scheduler.MultiStepLR(optimizer_student, milestones=settings.MILESTONES, gamma=0.2) #learning rate decay + iter_per_epoch = len(cifar100_training_loader) + warmup_scheduler_student = WarmUpLR(optimizer_student, iter_per_epoch * args.warm) + + if args.resume: + recent_folder = most_recent_folder(os.path.join(settings.CHECKPOINT_PATH, args.net), fmt=settings.DATE_FORMAT) + if not recent_folder: + raise Exception('no recent folder were found') + + checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder) + + else: + checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net_student, settings.TIME_NOW) + + #use tensorboard + if not os.path.exists(settings.LOG_DIR): + os.mkdir(settings.LOG_DIR) + + + input_tensor = torch.Tensor(1, 3, 32, 32) + if args.gpu: + input_tensor = input_tensor.cuda() + + #create checkpoint folder to save model + if not os.path.exists(checkpoint_path): + os.makedirs(checkpoint_path) + checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth') + + best_acc = 0.0 + if args.resume: + best_weights = best_acc_weights(os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder)) + if best_weights: + weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder, best_weights) + print('found best acc weights file:{}'.format(weights_path)) + print('load best training file to test acc...') + net_student.load_state_dict(torch.load(weights_path)) + best_acc = eval_training(tb=False) + print('best acc is {:0.2f}'.format(best_acc)) + + recent_weights_file = most_recent_weights(os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder)) + if not recent_weights_file: + raise Exception('no recent weights file were found') + weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder, recent_weights_file) + print('loading weights file {} to resume training.....'.format(weights_path)) + net_student.load_state_dict(torch.load(weights_path)) + + resume_epoch = last_epoch(os.path.join(settings.CHECKPOINT_PATH, args.net_student, recent_folder)) + + + for epoch in range(1, settings.EPOCH + 1): + if epoch > args.warm: + train_scheduler_student.step(epoch) + + if args.resume: + if epoch <= resume_epoch: + continue + + train(epoch) + acc = eval_training(epoch) + + #start to save best performance model after learning rate decay to 0.01 + if epoch > settings.MILESTONES[1] and best_acc < acc: + weights_path = checkpoint_path.format(net=args.net_student, epoch=epoch, type='best') + print('saving weights file to {}'.format(weights_path)) + torch.save(net_student.state_dict(), weights_path) + best_acc = acc + continue + + if not epoch % settings.SAVE_EPOCH: + weights_path = checkpoint_path.format(net=args.net_student, epoch=epoch, type='regular') + print('saving weights file to {}'.format(weights_path)) + torch.save(net_student.state_dict(), weights_path) diff --git a/train_prune.py b/train_prune.py new file mode 100644 index 00000000..f03180fd --- /dev/null +++ b/train_prune.py @@ -0,0 +1,262 @@ +# train.py +#!/usr/bin/env python3 + +""" train network using pytorch + +author baiyu +""" + +import os +import sys +import argparse +import time +from datetime import datetime + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from functools import partial + + +from torch.utils.data import DataLoader +from conf import settings +from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \ + most_recent_folder, most_recent_weights, last_epoch, best_acc_weights +import torch_pruning as tp + +def progressive_pruning(pruner, model, speed_up, example_inputs): + model.eval() + base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) + current_speed_up = 1 + while current_speed_up < speed_up: + pruner.step() + pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs) + current_speed_up = float(base_ops) / pruned_ops + if pruner.current_step == pruner.iterative_steps: + break + #print(current_speed_up) + return current_speed_up + +def train(epoch): + + start = time.time() + net.train() + for batch_index, (images, labels) in enumerate(cifar100_training_loader): + + if args.gpu: + labels = labels.cuda() + images = images.cuda() + + optimizer.zero_grad() + outputs = net(images) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + + n_iter = (epoch - 1) * len(cifar100_training_loader) + batch_index + 1 + + last_layer = list(net.children())[-1] + + print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format( + loss.item(), + optimizer.param_groups[0]['lr'], + epoch=epoch, + trained_samples=batch_index * args.b + len(images), + total_samples=len(cifar100_training_loader.dataset) + )) + + + if epoch <= args.warm: + warmup_scheduler.step() + + for name, param in net.named_parameters(): + layer, attr = os.path.splitext(name) + attr = attr[1:] + + finish = time.time() + + print('epoch {} training time consumed: {:.2f}s'.format(epoch, finish - start)) + +@torch.no_grad() +def eval_training(epoch=0, tb=True): + + start = time.time() + net.eval() + + test_loss = 0.0 # cost function error + correct = 0.0 + + for (images, labels) in cifar100_test_loader: + + if args.gpu: + images = images.cuda() + labels = labels.cuda() + + outputs = net(images) + loss = loss_function(outputs, labels) + + test_loss += loss.item() + _, preds = outputs.max(1) + correct += preds.eq(labels).sum() + + finish = time.time() + if args.gpu: + print('GPU INFO.....') + print(torch.cuda.memory_summary(), end='') + print('Evaluating Network.....') + print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format( + epoch, + test_loss / len(cifar100_test_loader.dataset), + correct.float() / len(cifar100_test_loader.dataset), + finish - start + )) + print() + + return correct.float() / len(cifar100_test_loader.dataset) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('-net', type=str, required=True, help='net type') + parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not') + parser.add_argument('-b', type=int, default=128, help='batch size for dataloader') + parser.add_argument('-warm', type=int, default=1, help='warm up training phase') + parser.add_argument('-lr', type=float, default=0.1, help='initial learning rate') + parser.add_argument('-resume', action='store_true', default=False, help='resume training') + parser.add_argument('-sl_weights', type=str, required=True, help='the weights file of the pretrained sparsity learning') + args = parser.parse_args() + + net = get_network(args) + + device = torch.device("cuda" if args.gpu else "cpu") + #data preprocessing: + cifar100_training_loader = get_training_dataloader( + settings.CIFAR100_TRAIN_MEAN, + settings.CIFAR100_TRAIN_STD, + num_workers=4, + batch_size=args.b, + shuffle=True + ) + + cifar100_test_loader = get_test_dataloader( + settings.CIFAR100_TRAIN_MEAN, + settings.CIFAR100_TRAIN_STD, + num_workers=4, + batch_size=args.b, + shuffle=True + ) + + loss_function = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) + train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2) #learning rate decay + iter_per_epoch = len(cifar100_training_loader) + warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm) + + if args.resume: + recent_folder = most_recent_folder(os.path.join(settings.CHECKPOINT_PATH, args.net), fmt=settings.DATE_FORMAT) + if not recent_folder: + raise Exception('no recent folder were found') + + checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder) + + else: + checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW) + + #use tensorboard + if not os.path.exists(settings.LOG_DIR): + os.mkdir(settings.LOG_DIR) + + input_tensor = torch.Tensor(1, 3, 32, 32) + if args.gpu: + input_tensor = input_tensor.cuda() + #create checkpoint folder to save model + if not os.path.exists(checkpoint_path): + os.makedirs(checkpoint_path) + checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth') + best_acc = 0.0 + if args.resume: + best_weights = best_acc_weights(os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder)) + if best_weights: + weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder, best_weights) + print('found best acc weights file:{}'.format(weights_path)) + print('load best training file to test acc...') + net.load_state_dict(torch.load(weights_path)) + best_acc = eval_training(tb=False) + print('best acc is {:0.2f}'.format(best_acc)) + + recent_weights_file = most_recent_weights(os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder)) + if not recent_weights_file: + raise Exception('no recent weights file were found') + weights_path = os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder, recent_weights_file) + print('loading weights file {} to resume training.....'.format(weights_path)) + net.load_state_dict(torch.load(weights_path)) + + resume_epoch = last_epoch(os.path.join(settings.CHECKPOINT_PATH, args.net, recent_folder)) + # 0. Load Sparse model + net.load_state_dict(torch.load(args.sl_weights, map_location=device)) + # 1. Pruning + ignored_layers = [] + for m in net.modules(): + if isinstance(m, torch.nn.Linear) and m.out_features == 100: + ignored_layers.append(m) + elif isinstance(m, torch.nn.modules.conv._ConvNd) and m.out_channels == 100: + ignored_layers.append(m) + + for inputs in cifar100_test_loader: + example_inputs, _= inputs + example_inputs = example_inputs.to(device) + break + + imp = tp.importance.GroupNormImportance(p=2) + pruner_entry = partial(tp.pruner.GroupNormPruner, reg=5e-4, global_pruning=True) + pruner = pruner_entry( + net, + example_inputs, + importance=imp, + iterative_steps=200, + ch_sparsity=1.0, + ch_sparsity_dict={}, + max_ch_sparsity=1.0, + ignored_layers=ignored_layers, + unwrapped_parameters=[], + ) + + ori_ops, ori_size = tp.utils.count_ops_and_params(net, example_inputs=example_inputs) + ori_acc = eval_training(net) + print('Accuracy befor pruning', ori_acc) + print('Model size before pruning', ori_size) + print("Pruning...") + progressive_pruning(pruner, net, speed_up=2.11, example_inputs=example_inputs) + del pruner # remove reference + pruned_ops, pruned_size = tp.utils.count_ops_and_params(net, example_inputs=example_inputs) + pruned_acc = eval_training(net) + print('Model size after pruning', pruned_size) + print('Accuracy after pruning', pruned_acc) + + #2. Finetune pruned model + for epoch in range(1, settings.EPOCH + 1): + if epoch > args.warm: + train_scheduler.step(epoch) + + if args.resume: + if epoch <= resume_epoch: + continue + + train(epoch) + acc = eval_training(epoch) + + #start to save best performance model after learning rate decay to 0.01 + if epoch > settings.MILESTONES[1] and best_acc < acc: + weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='best') + print('saving weights file to {}'.format(weights_path)) + torch.save(net.state_dict(), weights_path) + best_acc = acc + continue + + if not epoch % settings.SAVE_EPOCH: + weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular') + print('saving weights file to {}'.format(weights_path)) + torch.save(net.state_dict(), weights_path) \ No newline at end of file