You must first install diffq, then the requirements for this example. To do so, run from the root of the repository:
pip install .
cd examples/cifar
pip install -r requirements.txt
To train a model, run:
./{DATASET} model={MODEL}
with DATASET either cifar10 or cifar100, and MODEL one of (ResNet 18), mobilenet (MobileNet), or w_resnet (Wide ResNet).
The datasets will be automatically downloaded to the ./data folder, and the checkpoints stored in the ./outputs folder.
To train with QAT (quantization-aware training), run
./{DATASET} model={MODEL} quant.bits={BITS} quant.qat=True
for instance
./ model=mobilenet quant.bits=3 quant.qat=True
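To give an intuition for what quant.bits controls, here is a minimal sketch of uniform min-max quantization in plain Python. It is an illustration only, not the code used in this repository: the function name `fake_quantize` and the choice of a min-max range are assumptions made for this example. In QAT, a rounding step of this kind is typically applied to the weights in the forward pass while the full-precision weights keep receiving gradient updates.

```python
def fake_quantize(weights, bits):
    """Snap each weight to the nearest of 2**bits evenly spaced levels
    spanning [min(weights), max(weights)] and return the resulting floats.

    Hypothetical helper for illustration; not part of this repository.
    """
    lo, hi = min(weights), max(weights)
    intervals = 2 ** bits - 1  # number of gaps between adjacent levels
    if hi == lo or intervals == 0:
        # Degenerate range or a single level: everything maps to `lo`.
        return [lo for _ in weights]
    scale = (hi - lo) / intervals
    return [lo + round((w - lo) / scale) * scale for w in weights]


weights = [-1.0, -0.4, 0.0, 0.3, 1.0]
quantized = fake_quantize(weights, bits=3)  # at most 2**3 = 8 distinct values
print(quantized)
```

With fewer bits the levels are further apart, so the rounding error per weight grows; it is bounded by half the spacing between two levels.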
To train with DiffQ, with a given model size penalty and group size, run
./{DATASET} model={MODEL} quant.penalty={PENALTY} quant.group_size={GROUP_SIZE}
for instance
./ model=w_resnet quant.penalty=5 quant.group_size=16
See the Supplementary Material, Section A.4, and Table B.2 for more information on the hyper-parameters used.
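As a rough illustration of how the two options interact, the sketch below estimates a model's storage size when every group of weights has its own bit width, then adds that size to a training loss scaled by the penalty. Everything here is an assumption made for the example: the function names are hypothetical, the size is expressed in megabytes, per-group metadata is ignored, and the bit widths are given as fixed numbers, whereas DiffQ learns them during training.

```python
def model_size_mb(num_weights, group_size, bits_per_group):
    """Approximate storage in megabytes when each group of `group_size`
    weights is stored with its own number of bits.

    Hypothetical helper for illustration; ignores per-group metadata.
    """
    num_groups = -(-num_weights // group_size)  # ceiling division
    if len(bits_per_group) != num_groups:
        raise ValueError(f"expected {num_groups} bit widths, got {len(bits_per_group)}")
    total_bits = 0
    remaining = num_weights
    for bits in bits_per_group:
        in_group = min(group_size, remaining)  # the last group may be partial
        total_bits += in_group * bits
        remaining -= in_group
    return total_bits / 8 / 2 ** 20  # bits -> bytes -> megabytes


def penalized_loss(task_loss, penalty, size_mb):
    """Task loss plus the size term weighted by the penalty (assumed form)."""
    return task_loss + penalty * size_mb


# 64 weights in groups of 16, with a different bit width per group.
size = model_size_mb(num_weights=64, group_size=16, bits_per_group=[8, 4, 4, 2])
loss = penalized_loss(task_loss=1.0, penalty=5, size_mb=size)
print(size, loss)
```

Under these assumptions, a larger penalty pushes training toward fewer bits per group, and a smaller group size allows finer-grained bit allocation at the cost of more per-group bookkeeping.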
In order to run experiments with LSQ, you will first need to train a baseline model:
./ model=resnet
Then you can fine-tune with LSQ:
./ model=resnet dummy=ft \
'continue_from=",model=resnet"' continue_best=true \
quant.lsq=true quant.bits=4 lr=0.01
To run experiments with a ResNet-20 on CIFAR-10:
./ preset=res20 quant.penalty=10 quant.group_size=16
To run experiments with a Vision Transformer on CIFAR-10:
./ model=vit quant.penalty=5 quant.group_size=16
./ model=vit_timm continue_best=true \
quant.lsq=true quant.bits=4 lr=0.01
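All the commands above pass options as key=value overrides. When running several configurations, it can help to generate those overrides programmatically. The helper below is a hypothetical convenience, not part of this repository; it only builds the argument list and leaves the training script path to you.

```python
def build_overrides(**options):
    """Turn keyword options into key=value override strings.

    A double underscore in a keyword stands for a dot, so quant__bits=3
    becomes "quant.bits=3". Hypothetical helper for illustration.
    """
    return [f"{key.replace('__', '.')}={value}" for key, value in options.items()]


# Overrides matching the QAT example above.
qat_args = build_overrides(model="mobilenet", quant__bits=3, quant__qat=True)
print(" ".join(qat_args))

# A small sweep over penalties for a DiffQ run.
sweep = [
    build_overrides(model="w_resnet", quant__penalty=penalty, quant__group_size=16)
    for penalty in (1, 5, 10)
]
```

Each list can then be appended to the training command, for example with `subprocess.run([script, *args])`, where `script` is the path to the training entry point.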
See the file ../../LICENSE for more details.
The files src/ and src/ are taken from kuangliu/pytorch-cifar, released as MIT.
The file src/ is taken from meliketoy/wide-resnet, released as MIT.
The file src/ is taken from akamaster/pytorch_resnet_cifar10, released as BSD 2-Clause "Simplified".