Skip to content

Commit

Permalink
Speed up seqio.Vocabulary.decode_tf for empty tensors.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 590854176
  • Loading branch information
SeqIO Team authored and SeqIO committed Dec 14, 2023
1 parent 8012bf6 commit 5ba7194
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
6 changes: 6 additions & 0 deletions seqio/vocabularies.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:

def decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
"""Detokenizes int32 batched Tensor through first EOS."""
# The empty tensor is an important special case that can come up often. The
# call otherwise takes time proportional to the size of the vocabulary, so
# this can be a very significant speedup.
if ids.shape == (0,):
return tf.constant(b"", dtype=tf.string)

clean_ids = ids

if self.unk_id is not None:
Expand Down
30 changes: 29 additions & 1 deletion seqio/vocabularies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for seqio.vocabularies."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from seqio import test_utils
from seqio import vocabularies
Expand Down Expand Up @@ -250,7 +251,7 @@ def test_decode_converts_ints_to_unigrams_correctly(self):



class SentencepieceVocabularyTest(absltest.TestCase):
class SentencepieceVocabularyTest(parameterized.TestCase):
TEST_STRING = "this is a test"
TEST_TOKENS = (11, 8, 6, 3, 8, 6, 3, 5, 10)
UNK_STRING = " ⁇ "
Expand All @@ -275,6 +276,33 @@ def test_decode_tf(self):
exp = [expected_str] * 2
self.assertEqual(exp, res)

@parameterized.named_parameters(
    ("1-dim empty", [], b""),
    ("1-dim nonempty", [11], b"th"),
    ("2-dim 1 empty", [[]], [b""]),
    ("2-dim 1 nonempty", [[11]], [b"th"]),
    ("2-dim 2 empty", [[], []], [b"", b""]),
    ("2-dim 1 empty 1 nonempty", [[], [11]], [b"", b"th"]),
    ("2-dim 1 nonempty 1 empty", [[11], []], [b"th", b""]),
    ("2-dim 2 nonempty", [[11], [8]], [b"th", b"i"]),
)
def test_decode_tf_small_examples(self, arg_elems, expected_elems):
    """Checks `decode_tf` on tiny inputs, including empty ones.

    Args:
      arg_elems: (possibly nested, possibly empty) list of token ids; it is
        converted to an int32 constant via `tf.ragged.constant` before being
        passed to `decode_tf`.
      expected_elems: the bytes value(s) the decoded result is compared
        against, converted to a string constant via `tf.ragged.constant`.
    """
    sp_vocab = test_utils.sentencepiece_vocab()
    ids = tf.ragged.constant(arg_elems, dtype=tf.int32)

    # NOTE: the names `expected` and `actual` are part of the failure message
    # below (the `=` f-string specifier prints the expression text), so they
    # must not be renamed.
    actual = sp_vocab.decode_tf(ids)
    expected = tf.ragged.constant(expected_elems, dtype=tf.string)

    # Shape and dtype are checked first so that the element-wise comparison
    # below is only reached for structurally compatible tensors.
    self.assertEqual(expected.shape, actual.shape)
    self.assertIs(expected.dtype, actual.dtype)

    elementwise_match = tf.equal(actual, expected)
    if elementwise_match.numpy().all():
        return

    # Build a multi-line report showing both tensors and the comparison mask.
    self.fail(
        "\n".join([
            "Mismatched tensors:",
            f" {expected=}",
            f" {actual=}",
            f" tf.equal result: {elementwise_match}",
        ])
    )

def test_vocab(self):
vocab = test_utils.sentencepiece_vocab()
self.assertEqual(26, vocab.vocab_size)
Expand Down

0 comments on commit 5ba7194

Please sign in to comment.