This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[20220319] ValueChoice in Oneshot Lightning

Yuge Zhang edited this page Mar 23, 2022 · 2 revisions

This is the design of ValueChoice in Oneshot Lightning.

Pull request: https://github.com/microsoft/nni/pull/4602

Background

  • NAS = search space + evaluator + strategy.
  • We are turning one-shot algorithms into pure strategies.
  • But training and searching are coupled in one-shot algorithms. How do we disaggregate them?
    • If the training part (evaluator) is written in PyTorch Lightning, we can break the LightningModule into smaller parts, such as configure_optimizers, training_step, validation_step and trainer_loop, plug searching recipes in between these parts, and reassemble them into a LightningModule.

The following image (not reproduced here) is from the official Lightning project:

Thus, the flow of a one-shot strategy is:

  1. The strategy wraps the model (base model) with the evaluator to make it a LightningModule.
  2. The strategy sends the wrapped LightningModule to the algorithm implementation, which converts it into another LightningModule in which search and training are both ready.
  3. The strategy trains the converted LightningModule with a trainer, which is taken from the evaluator.
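The three steps can be sketched in plain Python. This is a toy under stated assumptions: WrappedModel, RandomSamplingAlgoModule and train are hypothetical stand-ins, not NNI or PyTorch Lightning APIs. The only point is how the algorithm module reassembles the evaluator's hooks with a searching recipe (here, resampling an architecture before every training step).

```python
import random
from typing import Callable, List, Optional


class WrappedModel:
    """Step 1 (hypothetical): evaluator logic wrapped around a base model."""

    def __init__(self, model: Callable[[int, Optional[int]], float]):
        self.model = model
        self.arch: Optional[int] = None  # filled in by the algorithm module

    def training_step(self, batch: int) -> float:
        # Pretend the model output is the loss.
        return self.model(batch, self.arch)


class RandomSamplingAlgoModule:
    """Step 2 (hypothetical): reassembles the hooks with a search recipe."""

    def __init__(self, inner: WrappedModel, choices: List[int], seed: int = 0):
        self.inner = inner
        self.choices = choices
        self.rng = random.Random(seed)
        self.history: List[int] = []

    def training_step(self, batch: int) -> float:
        # Search recipe: sample an architecture before each step ...
        self.inner.arch = self.rng.choice(self.choices)
        self.history.append(self.inner.arch)
        # ... then run the evaluator's original training logic unchanged.
        return self.inner.training_step(batch)


def train(module, batches: List[int]) -> List[float]:
    """Step 3 (hypothetical): a minimal stand-in for the trainer loop."""
    return [module.training_step(b) for b in batches]


# Usage: a "model" whose loss is simply batch * width.
algo = RandomSamplingAlgoModule(WrappedModel(lambda batch, width: batch * width), [1, 2, 3])
losses = train(algo, [1, 1, 1, 1])
```

The evaluator's training_step is never edited; the search logic lives entirely in the wrapping module, which is what lets the algorithm act as a pure strategy.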

In the second step, the "algorithm implementation", i.e., XXXAlgoModule, replaces the mutation primitives (e.g., LayerChoice) with algorithm-specific modules (e.g., DifferentiableLayerChoice or PathSamplingLayerChoice).

Mutation hooks

This PR includes two main changes. First, it introduces a more formal kind of "module replacement", called "mutation hooks". Mutation hooks are responsible for replacing the mutation primitives with the specific super-net modules defined by the algorithm.

We call these "specific super-net modules" BaseSuperNetModule. A super-net module defines the construction and the forward/backward propagation logic of a mutation primitive. It also has a class method called mutate, which determines on which modules the replacement is triggered.
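A minimal sketch of mutation-hook replacement, assuming a toy module tree made of dicts of children rather than torch.nn.Module. Module, LayerChoice and replace_modules are hypothetical stand-ins, and the mutate signature shown here (take a module, return a replacement or None) is an assumption, not necessarily NNI's. Each hook class is tried in order, and the first non-None result replaces the child; otherwise the traversal recurses into it.

```python
from typing import Dict, List, Optional, Type


class Module:
    """Hypothetical module: just a named collection of children."""

    def __init__(self, **children: "Module"):
        self.children: Dict[str, Module] = dict(children)


class LayerChoice(Module):
    """Hypothetical mutation primitive: pick one of several candidates."""

    def __init__(self, candidates: List[Module], label: str):
        super().__init__()
        self.candidates = candidates
        self.label = label


class BaseSuperNetModule(Module):
    @classmethod
    def mutate(cls, module: Module) -> Optional["BaseSuperNetModule"]:
        """Return a replacement for `module`, or None to leave it alone."""
        raise NotImplementedError


class PathSamplingLayerChoice(BaseSuperNetModule):
    def __init__(self, candidates: List[Module], label: str):
        super().__init__()
        self.candidates = candidates
        self.label = label

    @classmethod
    def mutate(cls, module: Module) -> Optional["PathSamplingLayerChoice"]:
        # Trigger only on LayerChoice; keep its candidates and label.
        if isinstance(module, LayerChoice):
            return cls(module.candidates, module.label)
        return None


def replace_modules(root: Module, hooks: List[Type[BaseSuperNetModule]]) -> None:
    """Depth-first traversal; the first hook returning a module wins."""
    for name, child in list(root.children.items()):
        for hook in hooks:
            new = hook.mutate(child)
            if new is not None:
                root.children[name] = new
                break
        else:
            # No hook matched this child, so look inside it.
            replace_modules(child, hooks)


net = Module(block=Module(op=LayerChoice([Module(), Module()], label="op1")), head=Module())
replace_modules(net, [PathSamplingLayerChoice])
```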

The underlying assumption of this refactor is that one super-net module corresponds to one or more search dimensions. This assumption is non-trivial, and it is already a relaxation: without mixed operations (described next), the assumption was that one super-net module corresponds to exactly one search dimension, and the previous design coupled the two. Since a mixed operation can have multiple search dimensions (e.g., a convolution can search for channels and kernel sizes at the same time, yet it must be a single super-net module), this refactor is necessary to make mixed operations work.
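To illustrate the relaxed assumption, here is a hypothetical mixed convolution that owns two search dimensions while a layer choice owns one. Dimension, LayerChoiceSketch, MixedConvSketch and collect_search_space are illustrative names, not NNI classes; only the idea that one module's search_space_spec may report several labeled dimensions comes from the text above.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Tuple


@dataclass(frozen=True)
class Dimension:
    """One labeled search dimension (hypothetical stand-in for a spec type)."""
    label: str
    values: Tuple[int, ...]


class LayerChoiceSketch:
    """Without mixed operations: one module, exactly one dimension."""

    def __init__(self, label: str, num_candidates: int):
        self.dim = Dimension(label, tuple(range(num_candidates)))

    def search_space_spec(self) -> Dict[str, Dimension]:
        return {self.dim.label: self.dim}


class MixedConvSketch:
    """With mixed operations: one module, several dimensions."""

    def __init__(self, out_channels: Dimension, kernel_size: Dimension):
        self.dims = (out_channels, kernel_size)

    def search_space_spec(self) -> Dict[str, Dimension]:
        return {d.label: d for d in self.dims}


def collect_search_space(modules: Iterable) -> Dict[str, Dimension]:
    """Merge per-module specs; a repeated label must describe the same dimension."""
    space: Dict[str, Dimension] = {}
    for module in modules:
        for label, dim in module.search_space_spec().items():
            if label in space and space[label] != dim:
                raise ValueError(f"conflicting definitions for label {label!r}")
            space[label] = dim
    return space


conv = MixedConvSketch(Dimension("conv_out", (16, 32, 64)), Dimension("conv_k", (3, 5, 7)))
space = collect_search_space([conv, LayerChoiceSketch("op", 3)])
```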

To make things more straightforward, I also designed resample, export and search_space_spec for BaseSuperNetModule. These methods make it easier for an algorithm to handle different types of super-net modules that share the same kinds of underlying search dimensions (e.g., discrete, categorical, choices). All these methods take a memo, which lets search dimensions with the same label share sampling results.
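A minimal sketch of how a shared memo lets modules with the same label agree on a sample. The method names resample and export come from the text; their signatures here (a memo dict in, a {label: value} dict out) are assumptions for illustration, as are ChoiceModule, resample_all and export_all.

```python
import random
from typing import Any, Dict, List


class ChoiceModule:
    """Hypothetical super-net module with a single labeled choice."""

    def __init__(self, label: str, candidates: List[Any]):
        self.label = label
        self.candidates = candidates
        self.sampled: Any = None

    def resample(self, memo: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
        # Reuse the decision if a module with the same label already made it.
        self.sampled = memo[self.label] if self.label in memo else rng.choice(self.candidates)
        return {self.label: self.sampled}

    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
        # Report each label once; later modules with the same label add nothing.
        return {} if self.label in memo else {self.label: self.sampled}


def resample_all(modules: List[ChoiceModule], seed: int = 0) -> Dict[str, Any]:
    rng = random.Random(seed)
    memo: Dict[str, Any] = {}
    for module in modules:
        memo.update(module.resample(memo, rng))
    return memo


def export_all(modules: List[ChoiceModule]) -> Dict[str, Any]:
    memo: Dict[str, Any] = {}
    for module in modules:
        memo.update(module.export(memo))
    return memo


# Two modules share the label "width", so they always agree on it.
modules = [
    ChoiceModule("width", [16, 32, 64]),
    ChoiceModule("width", [16, 32, 64]),
    ChoiceModule("depth", [1, 2]),
]
sample = resample_all(modules)
```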

After introducing this change, I moved the original super-net module implementations (e.g., DartsLayerChoice) to the new implementation. This is basically a cut-and-paste of the code, with the base class changed.

Mixed operation

Second, this PR introduces mixed operations, which implement a series of weight-sharing tricks commonly known as super-kernels, channel search, or weight entanglement.

One design principle of mixed operations is that they are agnostic to search algorithms. They can be used by path-sampling algorithms such as SPOS and ENAS, as well as by differentiable ones such as DARTS. One benefit of this design is that we can eventually have implementations of many different mixed operations (e.g., Conv2d, Conv3d, BatchNorm, LayerNorm, MultiheadAttention...), but each of them only needs to be written once. In other words, we don't need to rewrite a mixed operation for every newly supported algorithm.

To this end, I designed MixedOperation and MixedOperationSamplingStrategy, which handle the operation-related and the algorithm-related parts respectively. In forward propagation, the mixed operation (which is a BaseSuperNetModule) first asks the sampling strategy for a sample (which can be algorithm-specific), then interprets the sample and executes it according to the characteristics of the operation itself. Similarly, export also delegates control to the sampling strategy. Note that the sampling strategy here is NOT the strategy in Retiarii. It is only a local concept that serves the needs of mixed operations.
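The split between operation-related and algorithm-related logic can be sketched as follows. MixedLinearSketch, RandomStrategy and FixedStrategy are hypothetical, and plain lists stand in for tensors so the sketch runs without PyTorch. The operation only knows how to execute a concrete out_features; the strategy decides which value to use in forward and in export, so swapping the strategy does not touch the operation.

```python
import random
from typing import Dict, List, Protocol


class SamplingStrategy(Protocol):
    """Algorithm-related part: decides which value to use."""

    def sample(self, label: str, candidates: List[int]) -> int: ...
    def export(self, label: str, candidates: List[int]) -> int: ...


class RandomStrategy:
    """Path-sampling flavour (SPOS-like): a uniform random pick per forward."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)
        self.last: Dict[str, int] = {}

    def sample(self, label: str, candidates: List[int]) -> int:
        self.last[label] = self.rng.choice(candidates)
        return self.last[label]

    def export(self, label: str, candidates: List[int]) -> int:
        # Export the most recent sample (first candidate if never sampled).
        return self.last.get(label, candidates[0])


class FixedStrategy:
    """Uses a given architecture, e.g. to evaluate one sub-net."""

    def __init__(self, arch: Dict[str, int]):
        self.arch = arch

    def sample(self, label: str, candidates: List[int]) -> int:
        return self.arch[label]

    def export(self, label: str, candidates: List[int]) -> int:
        return self.arch[label]


class MixedLinearSketch:
    """Operation-related part: knows how to run a given out_features."""

    def __init__(self, in_features: int, out_choices: List[int], label: str, strategy: SamplingStrategy):
        self.label, self.choices, self.strategy = label, list(out_choices), strategy
        # Super-kernel sized for the largest candidate: weight[o][i] = o * in_features + i.
        self.weight = [[float(o * in_features + i) for i in range(in_features)]
                       for o in range(max(self.choices))]

    def forward(self, x: List[float]) -> List[float]:
        out_features = self.strategy.sample(self.label, self.choices)
        # Prefix slice: only the first `out_features` rows of the super-kernel are used.
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight[:out_features]]

    def export(self) -> Dict[str, int]:
        return {self.label: self.strategy.export(self.label, self.choices)}


fixed = MixedLinearSketch(2, [1, 3], "out", FixedStrategy({"out": 3}))
y_fixed = fixed.forward([1.0, 1.0])
rand = MixedLinearSketch(2, [1, 3], "out", RandomStrategy(seed=0))
y_rand = rand.forward([1.0, 1.0])
```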

To actually implement a mixed operation, one needs to implement super_init_argument, which controls how the super-kernel is initialized. Usually the maximum number of channels, the maximum number of features, or the maximum kernel size works, but there are other cases as well. One also needs to implement forward_with_args, which controls how forward propagation runs when the arguments are resampled. This method is often a copy-and-modify of PyTorch source code, with prefix slices taken on the operation's parameters, which can be painful to maintain. Therefore, I suggest keeping the set of supported operations limited for now (maybe until we figure out a better way).
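A sketch of the two methods for a hypothetical single-input-channel 1-D convolution, using plain lists. MixedConv1dSketch is illustrative only. super_init_argument takes the max of each choice, and forward_with_args takes a prefix slice on output channels. For the kernel axis this sketch swaps in a center slice, which is the common choice for odd kernel sizes in OFA-style super-kernels; that part is an assumption, not necessarily what NNI does.

```python
from typing import Dict, List


class MixedConv1dSketch:
    """Hypothetical mixed 1-D convolution: one input channel, no padding, stride 1."""

    def __init__(self, out_channels_choices: List[int], kernel_size_choices: List[int]):
        self.choices = {"out_channels": out_channels_choices, "kernel_size": kernel_size_choices}
        init = self.super_init_argument()
        # Super-kernel weight[o][k] = 10 * o + k, allocated once for the largest config.
        self.weight = [[float(10 * o + k) for k in range(init["kernel_size"])]
                       for o in range(init["out_channels"])]

    def super_init_argument(self) -> Dict[str, int]:
        # The usual case: size every argument with the max of its candidates.
        return {name: max(values) for name, values in self.choices.items()}

    def forward_with_args(self, x: List[float], out_channels: int, kernel_size: int) -> List[List[float]]:
        max_k = len(self.weight[0])
        start = (max_k - kernel_size) // 2  # center slice of the kernel (OFA-style assumption)
        # Prefix slice on output channels, center slice on the kernel axis.
        kernels = [row[start:start + kernel_size] for row in self.weight[:out_channels]]
        # Cross-correlation without padding, as in torch.nn.functional.conv1d.
        return [[sum(w * x[t + j] for j, w in enumerate(kernel))
                 for t in range(len(x) - kernel_size + 1)]
                for kernel in kernels]


conv = MixedConv1dSketch(out_channels_choices=[1, 2], kernel_size_choices=[1, 3])
```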

Some useful tools for implementing mixed operations are in _valuechoice_utils and _operation_utils. They handle complex ValueChoice expressions (e.g., ValueChoice([3, 5, 7]) // 2 + ValueChoice([2, 4, 6])) and complex slices. In some cases (e.g., differentiable mixed operations), the prefix slices can be weighted.
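Both problems can be illustrated with two hypothetical helpers; they are not the functions in _valuechoice_utils or _operation_utils. possible_values enumerates every value of an expression over independently labeled leaves (which is what a max-based super_init_argument needs), and weighted_prefix_mask blends prefix slices by architecture weights, one plausible way to realize weighted slicing in a differentiable mixed operation.

```python
import itertools
from typing import Callable, Dict, List


def possible_values(leaves: Dict[str, List[int]], expr: Callable[[Dict[str, int]], int]) -> List[int]:
    """Enumerate every value a ValueChoice-like expression can take (hypothetical helper)."""
    labels = list(leaves)
    results = set()
    for combo in itertools.product(*(leaves[label] for label in labels)):
        results.add(expr(dict(zip(labels, combo))))
    return sorted(results)


def weighted_prefix_mask(widths: List[int], probs: List[float], max_width: int) -> List[float]:
    """mask[c] = total probability of the candidate widths that keep channel c."""
    return [sum(p for w, p in zip(widths, probs) if c < w) for c in range(max_width)]


# ValueChoice([3, 5, 7]) // 2 + ValueChoice([2, 4, 6]) with two independent labels.
values = possible_values({"k": [3, 5, 7], "c": [2, 4, 6]}, lambda s: s["k"] // 2 + s["c"])
super_width = max(values)  # what a max-based super_init_argument would use

# Differentiable case: blend prefix slices of the full output by architecture weights.
mask = weighted_prefix_mask([2, 4], [0.25, 0.75], max_width=4)
full_output = [1.0, 2.0, 3.0, 4.0]
blended = [m * y for m, y in zip(mask, full_output)]
```

Multiplying by the mask equals summing each candidate's zero-padded prefix slice weighted by its probability. If both leaves shared one label, they would have to be enumerated as a single leaf; that is exactly where label-based memo sharing matters.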

UML diagram

[Figure: ValueChoiceClass UML diagram, image not reproduced]