Add initial Python API decoder support (#869)
* add initial python decoder api support

* fix subscripting error

* test xlmroberta decoding with python api

* fix attribute error

* update spm decoder graph

* update spm decoder test

---------

Co-authored-by: Sayan Shaw <[email protected]>
sayanshaw24 and Sayan Shaw authored Jan 14, 2025
1 parent c8bb35d commit e8bf5a9
Showing 4 changed files with 70 additions and 23 deletions.
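In short: gen_processing_models can now emit a decoder (detokenizer) graph as well, and WITH_DEFAULT_INPUTS folds its optional inputs into initializers. A minimal round-trip sketch mirroring the updated test below (the checkpoint name is illustrative, not part of this commit):

from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models, ort_inference

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
text = "best hotel in bay area."

# WITH_DEFAULT_INPUTS bakes optional inputs (e.g. 'fairseq') into each graph,
# so the decoder model only needs the token ids at inference time.
tok, detok = gen_processing_models(
    tokenizer,
    pre_kwargs={"WITH_DEFAULT_INPUTS": True},
    post_kwargs={"WITH_DEFAULT_INPUTS": True})

ids, *_ = ort_inference(tok, [text])
round_trip = ort_inference(detok, ids)  # should reproduce the input text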
9 changes: 8 additions & 1 deletion onnxruntime_extensions/_cuops.py
@@ -364,8 +364,15 @@ class SentencepieceDecoder(CustomOp):
     @classmethod
     def get_inputs(cls):
         return [
-            cls.io_def("ids", onnx.TensorProto.INT64, [None])
+            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
+            cls.io_def('fairseq', onnx_proto.TensorProto.BOOL, [None])
         ]
 
+    @classmethod
+    def input_default_values(cls):
+        return {
+            'fairseq': [False]
+        }
+
     @classmethod
     def get_outputs(cls):
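The new input_default_values hook is what the converter queries when folding defaults; a None return means the op has nothing to fold. A quick check of the hook added above (assuming the package layout shown in the file path):

from onnxruntime_extensions._cuops import SentencepieceDecoder

# 'fairseq' defaults to [False], i.e. plain SentencePiece decoding.
print(SentencepieceDecoder.input_default_values())  # -> {'fairseq': [False]}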
74 changes: 56 additions & 18 deletions onnxruntime_extensions/_hf_cvt.py
@@ -168,43 +168,45 @@ def spm_decoder(self, **kwargs):
 TokenOpParam = namedtuple("TokenOpParam",
                           ["pre_op", "pre_attribute_cvt",
                            "post_op", "post_attribute_cvt",
-                           "default_inputs"],
+                           "default_encoder_inputs",
+                           "default_decoder_inputs"],
                           defaults=(None, None, None, None, None))
 
 # Some tokenizers can be added by this table
 # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1252
 # @formatter:off
 _PROCESSOR_DICT = {
     "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "RobertaTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "BartTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LayoutLMv3Tokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LongformerTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LEDTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "MvpTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "T5Tokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
                                 'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
-                                default_inputs={'add_eos': [True]}),
+                                default_encoder_inputs={'add_eos': [True]}, default_decoder_inputs=None),
     "LlamaTokenizer": TokenOpParam('SpmTokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                   'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                   'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "XLMRobertaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
                                         'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
-                                        default_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]}),
+                                        default_encoder_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]},
+                                        default_decoder_inputs={'fairseq': [True]}),
 }
 # @formatter:on

@@ -246,8 +248,8 @@ def pre_processing(self, **kwargs):
 
         # add default_inputs into initializers to simplify the model input
         n_inputs = len(default_inputs)
-        if self.cvt_quadruple.default_inputs is not None:
-            default_inputs.update(self.cvt_quadruple.default_inputs)
+        if self.cvt_quadruple.default_encoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_encoder_inputs)
         if len(default_inputs) != n_inputs:
             raise ValueError(
                 "Op: {} does not have the inputs from its TokenOpParam.".format(_cvt_op))
@@ -287,7 +289,43 @@ def pre_processing(self, **kwargs):
         return g
 
     def post_processing(self, **kwargs):
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+
         _cvt_op = self.cvt_quadruple.post_op
         _cvt_func = self.cvt_quadruple.post_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
-        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+
+        default_inputs = {}
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                encoder_inputs = self.cvt_quadruple.default_encoder_inputs
+                if encoder_inputs is not None and encoder_inputs["fairseq"]:
+                    default_inputs = {}  # need to set to empty dict to call .update later
+                else:
+                    return g
+
+        # add default_inputs into initializers to simplify the model input
+        if self.cvt_quadruple.default_decoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_decoder_inputs)
+
+        new_initializers = []
+
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError(
+                    "The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+                input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
+
+        return g
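The effect of the folding in post_processing can be checked on the returned decoder model: the defaulted input disappears from graph.input and reappears as an initializer. A sketch under the same assumptions as the XLM-Roberta test below:

from transformers import AutoTokenizer
from onnxruntime_extensions import gen_processing_models

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
_, detok = gen_processing_models(
    tokenizer,
    pre_kwargs={"WITH_DEFAULT_INPUTS": True},
    post_kwargs={"WITH_DEFAULT_INPUTS": True})

graph = detok.graph
assert "fairseq" not in [i.name for i in graph.input]    # no longer a model input
assert "fairseq" in [t.name for t in graph.initializer]  # folded in as a constant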
8 changes: 5 additions & 3 deletions test/test_autotokenizer.py
@@ -94,10 +94,12 @@ def test_xlm_roberta_tokenizer(self):
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(
tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(ort_tok, [text])
tok, detok = gen_processing_models(
tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True}, post_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(tok, [text])
det_text = ort_inference(detok, actual_ids)
np.testing.assert_array_equal(ids[0], actual_ids)
self.assertEqual(text, det_text)

def test_trie_tokenizer(self):
vocab_file = util.get_test_data_file(
2 changes: 1 addition & 1 deletion test/test_sentencepiece_ops.py
@@ -567,7 +567,7 @@ def test_spm_decoder(self):
         fullname = util.get_test_data_file('data', 'en.wiki.bpe.vs100000.model')
         ofunc = OrtPyFunction.from_customop('SentencepieceDecoder', model=open(fullname, 'rb').read())
 
-        result = ofunc(np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64))
+        result = ofunc(np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64), np.array([False], dtype=np.bool_))
         self.assertEqual(' '.join(result), 'best hotel in bay area.')
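With the op's new second input, fairseq-style id remapping (the XLM-Roberta case handled above) is opted into explicitly; exported models bake the flag in via WITH_DEFAULT_INPUTS, so it only needs spelling out when driving the raw op as this test does. A hypothetical fairseq-mode call against the same ofunc (the remapping behaviour is inferred from the converter defaults above, and the decoded text would differ from the plain decode):

result = ofunc(np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64),
               np.array([True], dtype=np.bool_))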


