This repository has been archived by the owner on Aug 1, 2023. It is now read-only.

more general state init/passing #212

Open · wants to merge 1 commit into base: master
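This PR replaces the LSTM-specific state plumbing in ensemble_export.py (explicit prev_hiddens / prev_cells / prev_input_feed tensors threaded through the exported graph) with a small decoder-side interface, so beam-search export no longer hard-codes the decoder's internal state layout. A minimal sketch of that interface: the four method names are taken from the rnn.py changes below, the bodies here are illustrative only.

class StatefulDecoderInterface:
    """Sketch of the contract the exporter now relies on (see RNNDecoder in rnn.py)."""

    def get_num_states(self):
        # Number of flat state tensors this decoder consumes/produces per step.
        raise NotImplementedError

    def get_init_prev_states(self, encoder_out):
        # Flat list of initial state tensors derived from the encoder output.
        raise NotImplementedError

    def populate_incremental_state(self, incremental_state, states):
        # Unpack the previous step's flat state list into incremental_state.
        raise NotImplementedError

    def serialize_incremental_state(self, incremental_state):
        # Flatten incremental_state back into a list of state output tensors.
        raise NotImplementedError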
93 changes: 25 additions & 68 deletions pytorch_translate/ensemble_export.py
@@ -3,6 +3,7 @@
import logging
import os
import tempfile
from collections import OrderedDict

import numpy as np
import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
encoder_outputs = encoder_out[0]
outputs.append(encoder_outputs)
output_names.append(f"encoder_output_{i}")
if hasattr(model.decoder, "_init_prev_states"):
states.extend(model.decoder._init_prev_states(encoder_out))
if hasattr(model.decoder, "get_init_prev_states"):
states.extend(model.decoder.get_init_prev_states(encoder_out))

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
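The encoder-side export keeps its shape and only renames the hook it probes for: each decoder that exposes get_init_prev_states contributes a flat list of initial state tensors, and those lists must line up with the get_num_states() slices consumed per step further down. An illustrative invariant check, not code from the PR; encoder_out stands for the per-model encoder output computed in the elided loop body above:

# Illustrative only: encoder_out is that model's encoder output from the
# elided loop body, not a variable defined at this point in the diff.
for model in self.models:
    if hasattr(model.decoder, "get_init_prev_states"):
        init_states = model.decoder.get_init_prev_states(encoder_out)
        assert len(init_states) == model.decoder.get_num_states()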
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

next_state_input = len(self.models)

# size of "batch" dimension of input as tensor
batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

for i, model in enumerate(self.models):
encoder_output = inputs[i]
prev_hiddens = []
prev_cells = []

for _ in range(len(model.decoder.layers)):
prev_hiddens.append(inputs[next_state_input])
prev_cells.append(inputs[next_state_input + 1])
next_state_input += 2

# ensure previous attention context has batch dimension
input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
inputs[next_state_input], input_feed_shape
)
next_state_input += 1

# no batching, we only care about "max" length
src_length_int = int(encoder_output.size()[0])
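For contrast with the deletion above: the decoder-step module used to unpack *inputs assuming an LSTM layout, and now just asks the decoder how many flat tensors to slice off. Both fragments below are taken from this diff; the replacement appears later in this forward.

# before: exporter hard-codes the per-layer hidden/cell (+ input feed) layout
for _ in range(len(model.decoder.layers)):
    prev_hiddens.append(inputs[next_state_input])
    prev_cells.append(inputs[next_state_input + 1])
    next_state_input += 2

# after: decoder owns its layout; exporter slices an opaque block of tensors
num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states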
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

encoder_out = (
encoder_output,
prev_hiddens,
prev_cells,
None,
None,
src_length,
src_tokens,
src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)
# pass state inputs via incremental_state
num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states
incremental_state = OrderedDict()
model.decoder.populate_incremental_state(incremental_state, prev_states)

decoder_output = model.decoder(
input_tokens,
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

(next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
state_outputs.append(next_input_feed)
next_states = model.decoder.serialize_incremental_state(incremental_state)
state_outputs.extend(next_states)

average_log_probs = torch.mean(
torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):

for i, model in enumerate(self.models):
encoder_output = inputs[i]
prev_hiddens = []
prev_cells = []

for _ in range(len(model.decoder.layers)):
prev_hiddens.append(inputs[next_state_input])
prev_cells.append(inputs[next_state_input + 1])
next_state_input += 2
prev_input_feed = inputs[next_state_input].view(1, -1)
next_state_input += 1

# no batching, we only care about "max" length
src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

encoder_out = (
encoder_output,
prev_hiddens,
prev_cells,
None,
None,
src_length,
src_tokens,
src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)
# pass state inputs via incremental_state
num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states
incremental_state = OrderedDict()
model.decoder.populate_incremental_state(incremental_state, prev_states)

decoder_output = model.decoder(
input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

log_probs_per_model.append(log_probs)

(next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
state_outputs.append(next_input_feed)
next_states = model.decoder.serialize_incremental_state(incremental_state)
state_outputs.extend(next_states)

average_log_probs = torch.mean(
torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
outputs.append(encoder_outputs)
output_names.append(f"encoder_output_{i}")

if hasattr(model.decoder, "_init_prev_states"):
states.extend(model.decoder._init_prev_states(encoder_out))
if hasattr(model.decoder, "get_init_prev_states"):
states.extend(model.decoder.get_init_prev_states(encoder_out))

# underlying assumption is each model has same vocab_reduction_module
vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
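Both decoder-step modules above (the beam-search step and the known-output step) now share the same per-model state handling. Condensed from the hunks above; the decoder call's argument list beyond input_tokens is truncated in the diff, so the keyword shown here is an assumption:

num_states = model.decoder.get_num_states()
prev_states = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states

incremental_state = OrderedDict()
model.decoder.populate_incremental_state(incremental_state, prev_states)

decoder_output = model.decoder(
    input_tokens,
    encoder_out,  # former prev_hiddens / prev_cells slots are now passed as None
    incremental_state=incremental_state,
)  # remaining arguments elided; see the full call in the source file

next_states = model.decoder.serialize_incremental_state(incremental_state)
state_outputs.extend(next_states)

Because the recurrent state now travels only through incremental_state, the encoder_out tuple no longer carries per-layer hidden and cell tensors in either forward.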
45 changes: 42 additions & 3 deletions pytorch_translate/rnn.py
@@ -1139,7 +1139,7 @@ def forward_unprojected(self, input_tokens, encoder_out, incremental_state=None)
prev_hiddens, prev_cells, input_feed = cached_state
else:
# first time step, initialize previous states
init_prev_states = self._init_prev_states(encoder_out)
init_prev_states = self.get_init_prev_states(encoder_out)
prev_hiddens = []
prev_cells = []

@@ -1247,7 +1247,13 @@ def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number

def _init_prev_states(self, encoder_out):
def get_num_states(self):
num_states = 2 * len(self.layers)
if self.attention.context_dim:
num_states += 1
return num_states

def get_init_prev_states(self, encoder_out):
(
encoder_output,
final_hiddens,
@@ -1274,10 +1280,43 @@ def _init_prev_states(self, encoder_out):
for h, c in zip(prev_hiddens, prev_cells):
prev_states.extend([h, c])
if self.attention.context_dim:
prev_states.append(self.initial_attn_context)
prev_states.append(self.initial_attn_context.view(1, -1))

return prev_states

def populate_incremental_state(self, incremental_state, states):
"""
Populate incremental_state from the previous step's flat state outputs, for ONNX tracing.
"""
prev_hiddens = []
prev_cells = []

for i in range(len(self.layers)):
prev_hiddens.append(states[2 * i])
prev_cells.append(states[2 * i + 1])

input_feed = states[-1]

# cache previous state inputs
utils.set_incremental_state(
self,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, input_feed),
)

def serialize_incremental_state(self, incremental_state):
state_outputs = []
(hiddens, cells, input_feed) = utils.get_incremental_state(
self, incremental_state, "cached_state"
)

for h, c in zip(hiddens, cells):
state_outputs.extend([h, c])
state_outputs.append(input_feed)

return state_outputs


@register_model_architecture("rnn", "rnn")
def base_architecture(args):
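For the RNN decoder, the flat state layout behind these methods is one (hidden, cell) pair per layer, followed by the attention input feed when self.attention.context_dim is set: [h_0, c_0, ..., h_{N-1}, c_{N-1}, input_feed], so get_num_states() == 2 * N + 1 with attention context and 2 * N without. A hedged usage sketch of the resulting round trip; the helper below is illustrative and not part of the PR, and the decoder call's argument order is assumed from the exporter:

from collections import OrderedDict

def run_decoder_step(decoder, input_tokens, encoder_out, prev_states):
    # Illustrative helper: one exported-style step of state handling.
    assert len(prev_states) == decoder.get_num_states()
    incremental_state = OrderedDict()
    decoder.populate_incremental_state(incremental_state, prev_states)
    decoder_output = decoder(
        input_tokens, encoder_out, incremental_state=incremental_state
    )  # other keyword arguments elided
    next_states = decoder.serialize_incremental_state(incremental_state)
    return decoder_output, next_states

# The first step is seeded from the encoder:
#   prev_states = decoder.get_init_prev_states(encoder_out)
# and every later step feeds the previous step's next_states back in.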