Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Commit

Permalink
Fix typo in VocabEmbedding argument.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 317381351
  • Loading branch information
sharannarang authored and Mesh TensorFlow Team committed Jun 19, 2020
1 parent e49af68 commit d353aa9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions mesh_tensorflow/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1980,7 +1980,7 @@ class VocabEmbedding(object):
"""A class to go from vocab ids to model states and model states to logits."""

def __init__(self, mesh, vocab_dim, output_dim, variable_dtype, name,
ensemble_dim, scale_variable_like_classifer_weights=False):
ensemble_dim, scale_variable_like_classifier_weights=False):
"""Embedding for the vocabulary.
Most of the arguments get passed to `mtf.layers.embedding_weights`.
Expand All @@ -1992,12 +1992,12 @@ def __init__(self, mesh, vocab_dim, output_dim, variable_dtype, name,
variable_dtype: a mtf.VariableDType
name: a string
ensemble_dim: a mtf.Dimension
scale_variable_like_classifer_weights: a boolean
scale_variable_like_classifier_weights: a boolean
"""
self._vocab_dim = vocab_dim
self._output_dim = output_dim
self._scale_variable_like_classifier_weights = (
scale_variable_like_classifer_weights)
scale_variable_like_classifier_weights)
if self._scale_variable_like_classifier_weights:
initializer = tf.random_normal_initializer(
stddev=self._output_dim.size ** -0.5)
Expand Down

0 comments on commit d353aa9

Please sign in to comment.