Add initial Python API decoder support #869

Merged · 10 commits · Jan 14, 2025
9 changes: 8 additions & 1 deletion onnxruntime_extensions/_cuops.py

@@ -364,8 +364,15 @@ class SentencepieceDecoder(CustomOp):
     @classmethod
     def get_inputs(cls):
         return [
-            cls.io_def("ids", onnx.TensorProto.INT64, [None])
+            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
+            cls.io_def('fairseq', onnx_proto.TensorProto.BOOL, [None])
         ]
 
+    @classmethod
+    def input_default_values(cls):
+        return {
+            'fairseq': [False]
+        }
+
     @classmethod
     def get_outputs(cls):
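For reference, a quick check of the updated op declaration. This is a sketch, not part of the PR: _cuops is an internal module, and the access assumes io_def returns ONNX value infos with a .name attribute.

    from onnxruntime_extensions._cuops import SentencepieceDecoder  # internal module (assumption)

    # The decoder now declares an optional boolean 'fairseq' input with a default.
    print([i.name for i in SentencepieceDecoder.get_inputs()])  # expected: ['ids', 'fairseq']
    print(SentencepieceDecoder.input_default_values())          # expected: {'fairseq': [False]}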
74 changes: 56 additions & 18 deletions onnxruntime_extensions/_hf_cvt.py

@@ -168,43 +168,45 @@ def spm_decoder(self, **kwargs):
 TokenOpParam = namedtuple("TokenOpParam",
                           ["pre_op", "pre_attribute_cvt",
                            "post_op", "post_attribute_cvt",
-                           "default_inputs"],
+                           "default_encoder_inputs",
+                           "default_decoder_inputs"],
                           defaults=(None, None, None, None, None))
 
 # Some tokenizers can be added by this table
 # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1252
 # @formatter:off
 _PROCESSOR_DICT = {
     "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
-                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "RobertaTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "BartTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LayoutLMv3Tokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LongformerTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "LEDTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "MvpTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
-                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "T5Tokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
                                 'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
-                                default_inputs={'add_eos': [True]}),
+                                default_encoder_inputs={'add_eos': [True]}, default_decoder_inputs=None),
     "LlamaTokenizer": TokenOpParam('SpmTokenizer', HFTokenizerConverter.bpe_tokenizer,
-                                   'BpeDecoder', HFTokenizerConverter.bpe_decoder, None),
+                                   'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
     "XLMRobertaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
                                         'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
-                                        default_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]}),
+                                        default_encoder_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]},
+                                        default_decoder_inputs={'fairseq': [True]}),
 }
 # @formatter:on
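The table change is mechanical: every entry gains a sixth positional slot. A minimal standalone sketch of the updated namedtuple semantics, mirroring the definition above:

    from collections import namedtuple

    TokenOpParam = namedtuple("TokenOpParam",
                              ["pre_op", "pre_attribute_cvt",
                               "post_op", "post_attribute_cvt",
                               "default_encoder_inputs",
                               "default_decoder_inputs"],
                              defaults=(None, None, None, None, None))

    # Five trailing defaults over six fields: only pre_op must always be supplied.
    p = TokenOpParam('SentencepieceTokenizer', None,
                     'SentencepieceDecoder', None,
                     default_encoder_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]},
                     default_decoder_inputs={'fairseq': [True]})
    assert p.default_decoder_inputs == {'fairseq': [True]}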

@@ -246,8 +248,8 @@ def pre_processing(self, **kwargs):

         # add default_inputs into initializers to simplify the model input
         n_inputs = len(default_inputs)
-        if self.cvt_quadruple.default_inputs is not None:
-            default_inputs.update(self.cvt_quadruple.default_inputs)
+        if self.cvt_quadruple.default_encoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_encoder_inputs)
         if len(default_inputs) != n_inputs:
             raise ValueError(
                 "Op: {} does not have the inputs from its TokenOpParam.".format(_cvt_op))
@@ -287,7 +289,43 @@ def pre_processing(self, **kwargs):
         return g
 
     def post_processing(self, **kwargs):
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+
         _cvt_op = self.cvt_quadruple.post_op
         _cvt_func = self.cvt_quadruple.post_attribute_cvt
         cvt = partial(_cvt_func, self.cvt_obj)
-        return SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+
+        default_inputs = {}
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                encoder_inputs = self.cvt_quadruple.default_encoder_inputs
+                if encoder_inputs is not None and encoder_inputs["fairseq"]:
+                    default_inputs = {}  # need to set to empty dict to call .update later
+                else:
+                    return g
+
+        # add default_inputs into initializers to simplify the model input
+        if self.cvt_quadruple.default_decoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_decoder_inputs)
+
+        new_initializers = []
+
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError(
+                    "The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+                input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
+
+        return g
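With this in place, both halves of the pipeline can fold their defaults away. A usage sketch mirroring the test below; 'xlm-roberta-base' is an assumed checkpoint, not taken from this diff:

    from transformers import AutoTokenizer
    from onnxruntime_extensions import gen_processing_models, ort_inference

    hf_tok = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed checkpoint
    # Decoder defaults such as 'fairseq' become initializers, so only the
    # token ids are passed at inference time.
    tok, detok = gen_processing_models(
        hf_tok,
        pre_kwargs={"WITH_DEFAULT_INPUTS": True},
        post_kwargs={"WITH_DEFAULT_INPUTS": True})

    ids, *_ = ort_inference(tok, ["Hello world"])
    text = ort_inference(detok, ids)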
8 changes: 5 additions & 3 deletions test/test_autotokenizer.py

@@ -94,10 +94,12 @@ def test_xlm_roberta_tokenizer(self):
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(
tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(ort_tok, [text])
tok, detok = gen_processing_models(
tokenizer, pre_kwargs={"WITH_DEFAULT_INPUTS": True}, post_kwargs={"WITH_DEFAULT_INPUTS": True})
actual_ids, *_ = ort_inference(tok, [text])
det_text = ort_inference(detok, actual_ids)
np.testing.assert_array_equal(ids[0], actual_ids)
self.assertEqual(text, det_text)

def test_trie_tokenizer(self):
vocab_file = util.get_test_data_file(
2 changes: 1 addition & 1 deletion test/test_sentencepiece_ops.py

@@ -567,7 +567,7 @@ def test_spm_decoder(self):
         fullname = util.get_test_data_file('data', 'en.wiki.bpe.vs100000.model')
         ofunc = OrtPyFunction.from_customop('SentencepieceDecoder', model=open(fullname, 'rb').read())
 
-        result = ofunc(np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64))
+        result = ofunc(np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64), np.array([False], dtype=np.bool_))
         self.assertEqual(' '.join(result), 'best hotel in bay area.')
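Outside the test helpers, executing a generated decoder graph requires registering the custom-op library with ONNX Runtime. A sketch of that standard setup, assuming detok is the decoder model returned by gen_processing_models with defaults folded in (so 'ids' is its only remaining input):

    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = ort.InferenceSession(detok.SerializeToString(), so)

    ids = np.array([1095, 4054, 26, 2022, 755, 99935], dtype=np.int64)
    print(sess.run(None, {"ids": ids}))  # input name 'ids' per the op definition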

