more general state init/passing
Summary: A more general mechanism for passing states during incremental decoding, which in particular makes requirements for ONNX export more explicit.

Differential Revision: D9599067

fbshipit-source-id: 806a0d6ba213fb531f8b44bbc9bc2fb089066b4e
jhcross authored and facebook-github-bot committed Aug 31, 2018
1 parent adb281f commit b5fc2a5
Showing 2 changed files with 67 additions and 71 deletions.
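At a glance, the commit retires the private _init_prev_states hook in favor of a small public decoder interface for flat state passing. A hypothetical outline of that interface, for orientation while reading the diff below (the method names come from the diff; the Protocol wrapper and type hints are illustrative, not part of the commit):

from typing import Any, Dict, List, Protocol

import torch


class IncrementalStateDecoder(Protocol):
    def get_num_states(self) -> int: ...
    def get_init_prev_states(self, encoder_out) -> List[torch.Tensor]: ...
    def populate_incremental_state(
        self, incremental_state: Dict[str, Any], states: List[torch.Tensor]
    ) -> None: ...
    def serialize_incremental_state(
        self, incremental_state: Dict[str, Any]
    ) -> List[torch.Tensor]: ...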
93 changes: 25 additions & 68 deletions pytorch_translate/ensemble_export.py
@@ -3,6 +3,7 @@
import logging
import os
import tempfile
from collections import OrderedDict

import numpy as np
import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
encoder_outputs = encoder_out[0]
outputs.append(encoder_outputs)
output_names.append(f"encoder_output_{i}")
if hasattr(model.decoder, "_init_prev_states"):
states.extend(model.decoder._init_prev_states(encoder_out))
if hasattr(model.decoder, "get_init_prev_states"):
states.extend(model.decoder.get_init_prev_states(encoder_out))

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

next_state_input = len(self.models)

# size of "batch" dimension of input as tensor
batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

for i, model in enumerate(self.models):
encoder_output = inputs[i]
prev_hiddens = []
prev_cells = []

for _ in range(len(model.decoder.layers)):
prev_hiddens.append(inputs[next_state_input])
prev_cells.append(inputs[next_state_input + 1])
next_state_input += 2

# ensure previous attention context has batch dimension
input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
inputs[next_state_input], input_feed_shape
)
next_state_input += 1

# no batching, we only care about "max" length
src_length_int = int(encoder_output.size()[0])
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

encoder_out = (
encoder_output,
prev_hiddens,
prev_cells,
None,
None,
src_length,
src_tokens,
src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)
# pass state inputs via incremental_state
num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states
incremental_state = OrderedDict()
model.decoder.populate_incremental_state(incremental_state, prev_states)

decoder_output = model.decoder(
input_tokens,
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

(next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
state_outputs.append(next_input_feed)
next_states = model.decoder.serialize_incremental_state(incremental_state)
state_outputs.extend(next_states)

average_log_probs = torch.mean(
torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):

for i, model in enumerate(self.models):
encoder_output = inputs[i]
prev_hiddens = []
prev_cells = []

for _ in range(len(model.decoder.layers)):
prev_hiddens.append(inputs[next_state_input])
prev_cells.append(inputs[next_state_input + 1])
next_state_input += 2
prev_input_feed = inputs[next_state_input].view(1, -1)
next_state_input += 1

# no batching, we only care about "max" length
src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

encoder_out = (
encoder_output,
prev_hiddens,
prev_cells,
None,
None,
src_length,
src_tokens,
src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)
# pass state inputs via incremental_state
num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states
incremental_state = OrderedDict()
model.decoder.populate_incremental_state(incremental_state, prev_states)

decoder_output = model.decoder(
input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

log_probs_per_model.append(log_probs)

(next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
state_outputs.append(next_input_feed)
next_states = model.decoder.serialize_incremental_state(incremental_state)
state_outputs.extend(next_states)

average_log_probs = torch.mean(
torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
outputs.append(encoder_outputs)
output_names.append(f"encoder_output_{i}")

if hasattr(model.decoder, "_init_prev_states"):
states.extend(model.decoder._init_prev_states(encoder_out))
if hasattr(model.decoder, "get_init_prev_states"):
states.extend(model.decoder.get_init_prev_states(encoder_out))

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
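The step decoders above no longer unpack a fixed (prev_hiddens, prev_cells, prev_input_feed) layout; each decoder now reports its own flat state count and handles (de)serialization itself. A minimal runnable sketch of that consumption loop, under stated assumptions: StubDecoder, StubModel, and pass_states are hypothetical stand-ins for illustration, not code from this commit.

from collections import OrderedDict

import torch


class StubDecoder:
    """Hypothetical decoder exposing the new flat-state hooks."""

    def __init__(self, num_layers):
        self.num_layers = num_layers

    def get_num_states(self):
        # (hidden, cell) per layer, plus one attention input-feed tensor
        return 2 * self.num_layers + 1

    def populate_incremental_state(self, incremental_state, states):
        # cache the previous step's states for the upcoming decoder step
        incremental_state["cached_state"] = list(states)

    def serialize_incremental_state(self, incremental_state):
        # flatten the cached state back into a list of tensors
        return list(incremental_state["cached_state"])


class StubModel:
    def __init__(self, num_layers):
        self.decoder = StubDecoder(num_layers)


def pass_states(models, inputs, next_state_input):
    """Slice each model's states from one flat input list and round-trip
    them through the decoder's incremental_state hooks."""
    state_outputs = []
    for model in models:
        num_states = model.decoder.get_num_states()
        prev_states = inputs[next_state_input : next_state_input + num_states]
        next_state_input += num_states

        incremental_state = OrderedDict()
        model.decoder.populate_incremental_state(incremental_state, prev_states)
        # ... a real exporter would run one decoder step here ...
        state_outputs.extend(
            model.decoder.serialize_incremental_state(incremental_state)
        )
    return state_outputs


models = [StubModel(num_layers=2), StubModel(num_layers=1)]
flat_states = [torch.zeros(1, 4) for _ in range(5 + 3)]  # 5 + 3 state tensors
assert len(pass_states(models, flat_states, 0)) == len(flat_states)

In the real forward methods, next_state_input starts at len(self.models) because the first entries of inputs are the per-model encoder outputs.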
45 changes: 42 additions & 3 deletions pytorch_translate/rnn.py
@@ -1139,7 +1139,7 @@ def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None)
prev_hiddens, prev_cells, input_feed = cached_state
else:
# first time step, initialize previous states
init_prev_states = self._init_prev_states(encoder_out)
init_prev_states = self.get_init_prev_states(encoder_out)
prev_hiddens = []
prev_cells = []

@@ -1247,7 +1247,13 @@ def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number

def _init_prev_states(self, encoder_out):
def get_num_states(self):
num_states = 2 * len(self.layers)
if self.attention.context_dim:
num_states += 1
return num_states

def get_init_prev_states(self, encoder_out):
(
encoder_output,
final_hiddens,
@@ -1274,10 +1280,43 @@ def _init_prev_states(self, encoder_out):
for h, c in zip(prev_hiddens, prev_cells):
prev_states.extend([h, c])
if self.attention.context_dim:
prev_states.append(self.initial_attn_context)
prev_states.append(self.initial_attn_context.view(1, -1))

return prev_states

def populate_incremental_state(self, incremental_state, states):
"""
From output of previous step outputs, for ONNX tracing.
"""
prev_hiddens = []
prev_cells = []

for i in range(len(self.layers)):
prev_hiddens.append(states[2 * i])
prev_cells.append(states[2 * i + 1])

input_feed = states[-1]

# cache previous state inputs
utils.set_incremental_state(
self,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, input_feed),
)

def serialize_incremental_state(self, incremental_state):
state_outputs = []
(hiddens, cells, input_feed) = utils.get_incremental_state(
self, incremental_state, "cached_state"
)

for h, c in zip(hiddens, cells):
state_outputs.extend([h, c])
state_outputs.append(input_feed)

return state_outputs


@register_model_architecture("rnn", "rnn")
def base_architecture(args):
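Taken together, the four rnn.py methods define a contract: get_num_states names the flat state width, get_init_prev_states produces the initial list, and populate_incremental_state / serialize_incremental_state convert between that list and incremental_state. A self-contained sketch of that contract under stated assumptions: ToyDecoder, its hidden_dim, and the zero-initialized states are illustrative, not this commit's RNN decoder.

from collections import OrderedDict

import torch


class ToyDecoder:
    """Hypothetical decoder: (hidden, cell) per layer plus an input feed."""

    def __init__(self, num_layers, hidden_dim, use_input_feed=True):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.use_input_feed = use_input_feed

    def get_num_states(self):
        # two tensors (hidden, cell) per layer, plus optional input feed
        return 2 * self.num_layers + (1 if self.use_input_feed else 0)

    def get_init_prev_states(self, encoder_out=None):
        # zero initialization stands in for states derived from encoder_out
        states = []
        for _ in range(self.num_layers):
            states.extend([torch.zeros(1, self.hidden_dim),
                           torch.zeros(1, self.hidden_dim)])
        if self.use_input_feed:
            states.append(torch.zeros(1, self.hidden_dim))
        return states

    def populate_incremental_state(self, incremental_state, states):
        # unflatten: per-layer hiddens/cells, then the trailing input feed
        hiddens = [states[2 * i] for i in range(self.num_layers)]
        cells = [states[2 * i + 1] for i in range(self.num_layers)]
        input_feed = states[-1] if self.use_input_feed else None
        incremental_state["cached_state"] = (hiddens, cells, input_feed)

    def serialize_incremental_state(self, incremental_state):
        # flatten back to the same layout get_init_prev_states produced
        hiddens, cells, input_feed = incremental_state["cached_state"]
        flat = []
        for h, c in zip(hiddens, cells):
            flat.extend([h, c])
        if self.use_input_feed:
            flat.append(input_feed)
        return flat


decoder = ToyDecoder(num_layers=2, hidden_dim=4)
states = decoder.get_init_prev_states()
assert len(states) == decoder.get_num_states()

incremental_state = OrderedDict()
decoder.populate_incremental_state(incremental_state, states)
roundtrip = decoder.serialize_incremental_state(incremental_state)
assert all(a is b for a, b in zip(roundtrip, states))

The diff uses an OrderedDict for incremental_state, which keeps state iteration order deterministic during ONNX tracing.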
