From 2cd03bd0431c758243a251c624e9aea176a6d5e0 Mon Sep 17 00:00:00 2001
From: Wanchao
Date: Fri, 22 Feb 2019 16:21:37 -0800
Subject: [PATCH] Add cudnn layernorm lowerbound benchmark (#53)

* Add cudnn layernorm lowerbound benchmark

* backward=None
---
 rnns/fastrnns/bench.py   |  3 ++-
 rnns/fastrnns/factory.py | 37 +++++++++++++++++++++++++++++++++++++
 rnns/fastrnns/runner.py  |  1 +
 3 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/rnns/fastrnns/bench.py b/rnns/fastrnns/bench.py
index 6d34943394..8668f49524 100644
--- a/rnns/fastrnns/bench.py
+++ b/rnns/fastrnns/bench.py
@@ -150,7 +150,8 @@ def bench(rnn_runners, group_name, print_json=False, sep=' ', **params):
     rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple',
                          'jit_multilayer', 'py']
     # TODO: Maybe add a separate section for the layernorm/dropout lstms
-    # 'jit_layernorm', 'jit_layernom_decom', 'jit', 'jit_dropout', 'cudnn_dropout'
+    # 'cudnn_layernorm', 'jit_layernorm', 'jit_layernom_decom',
+    # 'jit', 'jit_dropout', 'cudnn_dropout'
     vlrnns = ['vl_cudnn', 'vl_jit', 'vl_py']
     cnns = ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
     if args.print_json:
diff --git a/rnns/fastrnns/factory.py b/rnns/fastrnns/factory.py
index 7c32825d47..90f49bcf14 100644
--- a/rnns/fastrnns/factory.py
+++ b/rnns/fastrnns/factory.py
@@ -267,6 +267,43 @@ def varlen_lstm_creator(script=False, **kwargs):
         backward=simple_backward)
 
 
+# cudnn_layernorm_lstm: since cuDNN has no LayerNorm LSTM, we cannot benchmark the
+# lower bound directly. Instead, we benchmark only the forward pass by mimicking the
+# computation of a cuDNN LSTM plus seq_len * 3 LayerNorm computations. This should
+# serve as a perf lower bound for the LayerNorm LSTM forward pass (given that the
+# LayerNorm computation itself is invariant). The lower bound of the backward pass is
+# hard to get since we lose the intermediate results, but we can still optimize the
+# LayerNorm implementation to make the forward lower bound faster.
+def layernorm_pytorch_lstm_creator(**kwargs):
+    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
+    batch_size = kwargs['miniBatch']
+    hidden_size = kwargs['hiddenSize']
+    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
+    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
+    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
+    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device='cuda')
+
+    def forward(input, hidden):
+        out, new_hidden = module(input, hidden)
+        # plus (seq_len * three layernorm cell computations) to mimic the lower
+        # bound of a layernorm cudnn LSTM in the forward pass
+        seq_len = len(input.unbind(0))
+        hy, cy = new_hidden
+        for i in range(seq_len):
+            ln_i_output = ln_i(ln_input1)
+            ln_h_output = ln_h(ln_input1)
+            cy = ln_c(cy)
+
+        return out, (hy, cy)
+
+    return ModelDef(
+        inputs=[input, hidden],
+        params=flatten_list(module.all_weights),
+        forward=forward,
+        backward_setup=lstm_backward_setup,
+        backward=None)
+
+
 # input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
 # output: packed_weights with format
 # packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
diff --git a/rnns/fastrnns/runner.py b/rnns/fastrnns/runner.py
index 5d5583a1ea..b4c1b16114 100644
--- a/rnns/fastrnns/runner.py
+++ b/rnns/fastrnns/runner.py
@@ -45,6 +45,7 @@ def get_rnn_runners(*names):
 rnn_runners = {
     'cudnn': RNNRunner('cudnn', pytorch_lstm_creator, DummyContext),
     'cudnn_dropout': RNNRunner('cudnn_dropout', partial(pytorch_lstm_creator, dropout=0.4), DummyContext),
+    'cudnn_layernorm': RNNRunner('cudnn_layernorm', layernorm_pytorch_lstm_creator, DummyContext),
     'vl_cudnn': RNNRunner('vl_cudnn', varlen_pytorch_lstm_creator, DummyContext),
     'vl_jit': RNNRunner('vl_jit', partial(varlen_lstm_creator, script=True), DummyContext),
     'vl_py': RNNRunner('vl_py', varlen_lstm_creator, DummyContext),
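
Usage sketch (not part of the patch): a minimal way to drive the new creator directly, assuming the rnns/fastrnns package layout above and the camelCase size kwargs (seqLength, numLayers, inputSize, hiddenSize, miniBatch) that bench.py passes through to lstm_inputs; the sizes below are illustrative only.

    # minimal sketch, assuming this runs from the rnns/ directory on a CUDA machine
    import torch
    from fastrnns.factory import layernorm_pytorch_lstm_creator

    # kwarg names mirror the params bench.py forwards to lstm_inputs; values are assumptions
    modeldef = layernorm_pytorch_lstm_creator(
        seqLength=100, numLayers=1, inputSize=512,
        hiddenSize=512, miniBatch=64)

    # backward=None for this model, so only the forward lower bound is exercised
    out, (hy, cy) = modeldef.forward(*modeldef.inputs)
    torch.cuda.synchronize()

Benchmarking through bench.py itself should also work once 'cudnn_layernorm' is selected via --rnns, since the runner is registered in rnn_runners above.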