- Here we train a simple diffusion model to generate handwritten digits.
- The dataset is MNIST. It is downloaded into the `dataset` folder using torchvision, and the resulting folder structure looks like this:
dataset
├── mnist
│   └── MNIST
│       └── raw
│           ├── t10k-images-idx3-ubyte
│           ├── t10k-images-idx3-ubyte.gz
│           ├── t10k-labels-idx1-ubyte
│           ├── t10k-labels-idx1-ubyte.gz
│           ├── train-images-idx3-ubyte
│           ├── train-images-idx3-ubyte.gz
│           ├── train-labels-idx1-ubyte
│           └── train-labels-idx1-ubyte.gz
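As a rough illustration of how this layout comes about, the sketch below assumes the root passed to torchvision is `dataset/mnist`; the helper names `expected_raw_files` and `download_mnist` are hypothetical and not taken from this repository. `torchvision.datasets.MNIST` creates the `MNIST/raw` subfolders itself and keeps both the downloaded `.gz` archives and the extracted files.

```python
from pathlib import Path

# Assumption: the root handed to torchvision is "dataset/mnist".
ROOT = Path("dataset") / "mnist"

RAW_NAMES = [
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
]


def expected_raw_files(root=ROOT):
    """Return the file paths expected under <root>/MNIST/raw after a download."""
    raw = Path(root) / "MNIST" / "raw"
    files = []
    for name in RAW_NAMES:
        files.append(raw / name)            # extracted file
        files.append(raw / (name + ".gz"))  # downloaded archive
    return files


def download_mnist(root=ROOT):
    """Download MNIST if it is missing and return the training set."""
    # Imported lazily so the path helper above works without torchvision.
    from torchvision import datasets, transforms

    return datasets.MNIST(
        root=str(root),
        train=True,
        download=True,
        transform=transforms.ToTensor(),  # images become float tensors in [0, 1]
    )
```

Calling `download_mnist()` once should be enough to populate the folder; later runs reuse the files already on disk.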
- I first tried the simplest possible model, an `MLP`, for this task, but it did not work.
- Instead, we build a simple U-Net that uses convolutions and residual connections.
- I trained on an NVIDIA GeForce RTX 3090; each epoch takes about 15 seconds.
- To train from scratch, you don't have to modify anything. Once training is finished and you want to generate digit images, modify `mode`, then simply run the program and wait for your generated digits:
python run.py
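The U-Net mentioned above is described as combining convolutions with residual connections. The block below is only a minimal sketch of that idea, not the model used in this repository: the class name `ResidualConvBlock`, the channel counts, and the SiLU activation are all assumptions chosen for illustration.

```python
import torch
from torch import nn


class ResidualConvBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection around them (illustrative)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 keeps the spatial size unchanged for 3x3 kernels.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act = nn.SiLU()  # assumption: any smooth activation would do here
        # A 1x1 convolution lets the skip path match the output channel count.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        h = self.act(self.conv1(x))
        h = self.conv2(h)
        # Residual connection: add the (possibly projected) input back in.
        return self.act(h + self.skip(x))


if __name__ == "__main__":
    block = ResidualConvBlock(1, 32)
    x = torch.randn(4, 1, 28, 28)  # a batch of MNIST-sized images
    print(block(x).shape)
```

A U-Net would typically stack several such blocks with downsampling and upsampling in between; the skip connection is one reason a convolutional model can be easier to train than a plain MLP on images.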
- Of course, you can modify the model architecture or try other hyperparameters; feel free to experiment.
- This dataset is very easy, so the results are already pretty good after only 100 epochs of training.
- First, we use random Gaussian noise to sample some images. Here are 256 examples:
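Sampling from Gaussian noise can be sketched as a DDPM-style reverse loop. Everything specific here is an assumption rather than this repository's code: the linear beta schedule, the 1000 steps, the function names, and the `model(x, t)` interface that returns the predicted noise.

```python
import torch


def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed noise schedule: betas rising linearly over the diffusion steps."""
    return torch.linspace(beta_start, beta_end, num_steps)


@torch.no_grad()
def sample(model, num_images=256, image_shape=(1, 28, 28), num_steps=1000,
           generator=None):
    """Start from pure Gaussian noise and denoise step by step."""
    betas = linear_beta_schedule(num_steps)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    # x_T: one random Gaussian noise image per sample.
    x = torch.randn(num_images, *image_shape, generator=generator)
    for t in reversed(range(num_steps)):
        t_batch = torch.full((num_images,), t, dtype=torch.long)
        eps = model(x, t_batch)  # assumed interface: predicted noise at step t
        # Posterior mean of x_{t-1} given x_t and the predicted noise.
        mean = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            # Add fresh noise with variance beta_t on every step but the last.
            noise = torch.randn(x.shape, generator=generator)
            x = mean + torch.sqrt(betas[t]) * noise
        else:
            x = mean
    return x
```

With a trained noise-prediction network passed as `model`, `sample(model)` would return a tensor of 256 images of shape `(1, 28, 28)`.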