July 2023

Lessons learnt from rewriting an AAE to PyTorch

Project: https://github.com/johncf/mnist-style

Regularization

  • Using one or more regularizers significantly improves the training stability, producing more consistent training outcomes despite (small) variations in initialization or training data.
  • One of the most basic regularizers, L2-regularization, is equivalent to "weight-decay" when using plain stochastic gradient descent (SGD) without momentum (with the decay coefficient rescaled by the learning rate).
  • However, with momentum-based or adaptive optimizers (SGD with momentum, RMSProp, Adam, etc.), weight-decay is no longer equivalent to L2-regularization (source).
    • In PyTorch, Adam's weight_decay parameter (zero by default) implements L2-regularization rather than true "weight-decay": the decay term is added to the gradient before the adaptive update is computed.
    • In contrast, AdamW implements weight-decay as described in the paper, applying the decay to the weights directly instead of folding it into the gradient.
    • There's a review paper on AdamW investigating its merits over Adam+L2.
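
The points above can be checked numerically. The following is a minimal sketch, assuming a toy quadratic loss and a hypothetical helper named `train` (neither is from the project); it compares each optimizer's built-in `weight_decay` against an explicit L2 penalty added to the loss.

```python
import torch


def train(opt_cls, weight_decay=0.0, l2_in_loss=0.0, steps=3):
    """Run a few optimizer steps on a toy quadratic and return the final weights.

    `train` is a hypothetical helper written for this illustration only.
    """
    w = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
    opt = opt_cls([w], lr=0.1, weight_decay=weight_decay)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((w - 0.5) ** 2).sum()
        # Explicit L2 penalty; its gradient with respect to w is l2_in_loss * w.
        loss = loss + 0.5 * l2_in_loss * (w ** 2).sum()
        loss.backward()
        opt.step()
    return w.detach().clone()


wd = 0.1
sgd_decay = train(torch.optim.SGD, weight_decay=wd)
sgd_l2 = train(torch.optim.SGD, l2_in_loss=wd)
adam_decay = train(torch.optim.Adam, weight_decay=wd)
adam_l2 = train(torch.optim.Adam, l2_in_loss=wd)
adamw_decay = train(torch.optim.AdamW, weight_decay=wd)

# Plain SGD: the weight_decay argument and an explicit L2 penalty coincide.
print("SGD decay == SGD + L2:  ", torch.allclose(sgd_decay, sgd_l2, atol=1e-6))
# Adam: the weight_decay argument behaves like an L2 penalty...
print("Adam decay == Adam + L2:", torch.allclose(adam_decay, adam_l2, atol=1e-5))
# ...whereas AdamW with the same coefficient ends up at different weights.
print("Adam decay == AdamW:    ", torch.allclose(adam_decay, adamw_decay, atol=1e-4))
```

The loss, learning rate and decay coefficient here are arbitrary; only the pattern of matches and mismatches is of interest.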

Computational Graphs in PyTorch

  • When we do a forward-pass on a model (i.e., from input to output), a computational graph is constructed for backpropagation (by default), containing nodes representing partial derivatives for every operation in the forward pass. (source)
  • Calling Tensor.backward() (usually on the loss tensor) performs backpropagation: it traverses the above graph, calculating gradients and accumulating them into the .grad attribute of every leaf tensor reachable from it.
  • However, we may want to avoid backpropagating gradients past a certain layer (or through a certain tensor) when training different parts of the model. For this, PyTorch provides Tensor.detach() to detach a tensor from the computational graph.
    • For example, when training the discriminator part of a Generative Adversarial Network, we don't want gradients from the discriminator loss to be backpropagated to the generator part. To accomplish this, during the discriminator's forward pass we must "detach" the generator's output tensor first, instead of feeding it in as is (which would keep it linked to the computational graph of the generator's forward pass). (example)
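
Both behaviours described above, gradient accumulation and detaching, can be seen in a small sketch. It assumes two single-layer stand-ins named `generator` and `discriminator` (hypothetical, much simpler than the project's models) and arbitrary tensor shapes.

```python
import torch
from torch import nn

torch.manual_seed(0)
generator = nn.Linear(4, 8)      # hypothetical stand-in for a generator
discriminator = nn.Linear(8, 1)  # hypothetical stand-in for a discriminator
bce = nn.BCEWithLogitsLoss()

noise = torch.randn(16, 4)
fake = generator(noise)          # still linked to the generator's graph

# Discriminator step: detach, so the loss cannot reach the generator.
d_loss = bce(discriminator(fake.detach()), torch.zeros(16, 1))
d_loss.backward()
gen_grad_after_d_step = generator.weight.grad         # no gradient arrived here
disc_grad_once = discriminator.weight.grad.clone()

# Gradients accumulate: a second identical backward pass adds to .grad.
d_loss_again = bce(discriminator(fake.detach()), torch.zeros(16, 1))
d_loss_again.backward()
disc_grad_twice = discriminator.weight.grad.clone()

# Generator step: no detach, so gradients flow back into the generator.
g_loss = bce(discriminator(fake), torch.ones(16, 1))
g_loss.backward()
gen_grad_after_g_step = generator.weight.grad

print("generator grad after discriminator step:", gen_grad_after_d_step)
print("accumulated grad is doubled:", torch.allclose(disc_grad_twice, 2 * disc_grad_once))
print("generator grad shape after generator step:", tuple(gen_grad_after_g_step.shape))
```

Because gradients accumulate, a real training loop would typically call `zero_grad()` on each optimizer before its backward pass; that is omitted here so the accumulation stays visible.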