diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index a32f30d6..52d4ff7f 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import tempfile
+from collections import OrderedDict

 import numpy as np
 import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
         next_state_input = len(self.models)

-        # size of "batch" dimension of input as tensor
-        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]
-
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
         if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
         for i, model in enumerate(self.models):
             encoder_output = inputs[i]

-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-
-            # ensure previous attention context has batch dimension
-            input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
-            prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
-                inputs[next_state_input], input_feed_shape
-            )
-            next_state_input += 1

             # no batching, we only care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_tokens,
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             log_probs_per_model.append(log_probs)
             attn_weights_per_model.append(attn_scores)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):
         for i, model in enumerate(self.models):
             encoder_output = inputs[i]

-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-            prev_input_feed = inputs[next_state_input].view(1, -1)
-            next_state_input += 1

             # no batching, we only care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             log_probs_per_model.append(log_probs)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
diff --git a/pytorch_translate/rnn.py b/pytorch_translate/rnn.py
index dd5529ae..d7eb0c74 100644
--- a/pytorch_translate/rnn.py
+++ b/pytorch_translate/rnn.py
@@ -1139,7 +1139,7 @@ def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None)
             prev_hiddens, prev_cells, input_feed = cached_state
         else:
             # first time step, initialize previous states
-            init_prev_states = self._init_prev_states(encoder_out)
+            init_prev_states = self.get_init_prev_states(encoder_out)
             prev_hiddens = []
             prev_cells = []
@@ -1247,7 +1247,13 @@ def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number

-    def _init_prev_states(self, encoder_out):
+    def get_num_states(self):
+        num_states = 2 * len(self.layers)
+        if self.attention.context_dim:
+            num_states += 1
+        return num_states
+
+    def get_init_prev_states(self, encoder_out):
         (
             encoder_output,
             final_hiddens,
@@ -1274,10 +1280,43 @@ def _init_prev_states(self, encoder_out):
         for h, c in zip(prev_hiddens, prev_cells):
             prev_states.extend([h, c])
         if self.attention.context_dim:
-            prev_states.append(self.initial_attn_context)
+            prev_states.append(self.initial_attn_context.view(1, -1))
         return prev_states

+    def populate_incremental_state(self, incremental_state, states):
+        """
+        Populate incremental_state from previous step outputs, for ONNX tracing.
+        """
+        prev_hiddens = []
+        prev_cells = []
+
+        for i in range(len(self.layers)):
+            prev_hiddens.append(states[2 * i])
+            prev_cells.append(states[2 * i + 1])
+
+        input_feed = states[-1]
+
+        # cache previous state inputs
+        utils.set_incremental_state(
+            self,
+            incremental_state,
+            "cached_state",
+            (prev_hiddens, prev_cells, input_feed),
+        )
+
+    def serialize_incremental_state(self, incremental_state):
+        state_outputs = []
+        (hiddens, cells, input_feed) = utils.get_incremental_state(
+            self, incremental_state, "cached_state"
+        )
+
+        for h, c in zip(hiddens, cells):
+            state_outputs.extend([h, c])
+        state_outputs.append(input_feed)
+
+        return state_outputs
+

 @register_model_architecture("rnn", "rnn")
 def base_architecture(args):
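
The net effect of this diff is a small state protocol: the exporter no longer unpacks per-layer hidden/cell tensors and the attention input feed itself. Instead it asks the decoder how many flat state tensors it owns (`get_num_states`), slices them off the input list, and round-trips them through `incremental_state` via `populate_incremental_state` / `serialize_incremental_state`, so only the decoder knows the layout. A minimal, self-contained sketch of that round trip; the `ToyDecoder` below is a hypothetical stand-in, not pytorch_translate code, and it stores state in a plain dict rather than through `utils.set_incremental_state`:

```python
from collections import OrderedDict

import torch


class ToyDecoder:
    """Hypothetical decoder implementing the state protocol from the diff:
    two tensors (hidden, cell) per layer plus one input-feed tensor."""

    def __init__(self, num_layers=2, hidden_dim=4):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def get_num_states(self):
        # hidden + cell per layer, plus the attention input feed
        return 2 * self.num_layers + 1

    def get_init_prev_states(self, encoder_out=None):
        states = []
        for _ in range(self.num_layers):
            states.extend([torch.zeros(1, self.hidden_dim),   # hidden
                           torch.zeros(1, self.hidden_dim)])  # cell
        states.append(torch.zeros(1, self.hidden_dim))        # input feed
        return states

    def populate_incremental_state(self, incremental_state, states):
        # Unflatten: even indices are hiddens, odd are cells, last is input feed.
        hiddens = [states[2 * i] for i in range(self.num_layers)]
        cells = [states[2 * i + 1] for i in range(self.num_layers)]
        incremental_state["cached_state"] = (hiddens, cells, states[-1])

    def serialize_incremental_state(self, incremental_state):
        # Reflatten in the same order populate_incremental_state expects.
        hiddens, cells, input_feed = incremental_state["cached_state"]
        out = []
        for h, c in zip(hiddens, cells):
            out.extend([h, c])
        out.append(input_feed)
        return out


decoder = ToyDecoder()

# Exporter side, mirroring the rewritten forward(): slice this decoder's
# flat states off the input list and hand them over wholesale.
inputs = decoder.get_init_prev_states()
next_state_input = 0
num_states = decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states

incremental_state = OrderedDict()
decoder.populate_incremental_state(incremental_state, prev_states)
# ... a real decoder step would run here and mutate incremental_state ...
state_outputs = decoder.serialize_incremental_state(incremental_state)
assert len(state_outputs) == num_states  # flat layout is preserved
```

Because the exporter only ever sees an opaque, fixed-length list of tensors, a decoder can change its internal state layout (or a non-LSTM decoder can join the ensemble) without touching `ensemble_export.py`, which is presumably why the `hasattr(model.decoder, "get_init_prev_states")` guard suffices on the encoder side.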