Skip to content
This repository has been archived by the owner on Aug 1, 2023. It is now read-only.

Commit

Permalink
Change torch.jit.trace to no longer be a decorator (#11069)
Browse files Browse the repository at this point in the history
Summary:
This was done because it was surprising for a decorator to run a function
rather than wrap it, and the decorator form did not simplify the syntax for tracing modules.
Pull Request resolved: pytorch/pytorch#11069

Reviewed By: jamesr66a

Differential Revision: D9583192

Pulled By: zdevito

fbshipit-source-id: b914b7ab4c73c255086465a6576eef3a22de1e13
  • Loading branch information
zdevito authored and facebook-github-bot committed Aug 30, 2018
1 parent 9055520 commit adb281f
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions pytorch_translate/ensemble_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def __init__(

encoder_ens = EncoderEnsemble(self.models)
example_encoder_outs = encoder_ens(src_tokens, src_lengths)
self.encoder_ens = torch.jit.trace(src_tokens, src_lengths)(encoder_ens)
self.encoder_ens = torch.jit.trace(encoder_ens, (src_tokens, src_lengths))
decoder_ens = DecoderBatchedStepEnsemble(
self.models,
tgt_dict,
Expand All @@ -536,14 +536,17 @@ def __init__(
prev_token, prev_scores, ts, *example_encoder_outs
)
self.decoder_ens_tile = torch.jit.trace(
prev_token, prev_scores, ts, *example_encoder_outs
)(decoder_ens_tile)
decoder_ens_tile, (prev_token, prev_scores, ts, *example_encoder_outs)
)
self.decoder_ens = torch.jit.trace(
prev_token.repeat(self.beam_size),
prev_scores.repeat(self.beam_size),
ts,
*tiled_states,
)(decoder_ens)
decoder_ens,
(
prev_token.repeat(self.beam_size),
prev_scores.repeat(self.beam_size),
ts,
*tiled_states,
),
)

self.input_names = [
"src_tokens",
Expand Down Expand Up @@ -858,7 +861,7 @@ def __init__(self, model_list, tgt_dict, word_reward=0, unk_reward=0):

encoder_ens = EncoderEnsemble(self.models)
example_encoder_outs = encoder_ens(source_tokens, source_length)
self.encoder_ens = torch.jit.trace(source_tokens, source_length)(encoder_ens)
self.encoder_ens = torch.jit.trace(encoder_ens, (source_tokens, source_length))
decoder_ens = KnownOutputDecoderStepEnsemble(
self.models, tgt_dict, word_reward, unk_reward
)
Expand All @@ -867,8 +870,8 @@ def __init__(self, model_list, tgt_dict, word_reward=0, unk_reward=0):
ts = torch.LongTensor([0])
_, *states = decoder_ens(prev_token, target_token, ts, *example_encoder_outs)
self.decoder_ens = torch.jit.trace(
prev_token, target_token, ts, *example_encoder_outs
)(decoder_ens)
decoder_ens, (prev_token, target_token, ts, *example_encoder_outs)
)

self.input_names = [
"source_tokens",
Expand Down

0 comments on commit adb281f

Please sign in to comment.