From b45880c53732aff1aa14a077e3116b9fbd0a5346 Mon Sep 17 00:00:00 2001 From: Shuming Hu Date: Thu, 23 Mar 2023 01:19:08 +0000 Subject: [PATCH] Optionally ignore utf-8 decoding error when converting std::string to python str. (#97282) Summary: When language models use a C++ tokenizer, the outputs are C++ strings that are not necessarily valid UTF-8 encodings. The default pybind11 casting uses strict UTF-8 decoding. We relax the decoding using the 'ignore' argument. Test Plan: https://www.internalfb.com/intern/testinfra/testrun/6473924609918070 Reviewed By: Nayef211 Differential Revision: D43970697 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97282 Approved by: https://github.com/davidberard98 --- build_variables.bzl | 1 + torch/csrc/jit/python/init.cpp | 48 ++++++++++--------- torch/csrc/jit/python/pybind_utils.cpp | 9 +++- .../csrc/jit/python/utf8_decoding_ignore.cpp | 16 +++++++ torch/csrc/jit/python/utf8_decoding_ignore.h | 8 ++++ 5 files changed, 59 insertions(+), 23 deletions(-) create mode 100644 torch/csrc/jit/python/utf8_decoding_ignore.cpp create mode 100644 torch/csrc/jit/python/utf8_decoding_ignore.h diff --git a/build_variables.bzl b/build_variables.bzl index 20afef71720171..0f2ee809b58696 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -337,6 +337,7 @@ core_sources_full_mobile_no_backend_interface_xplat = [ "torch/csrc/jit/passes/quantization/fusion_passes.cpp", "torch/csrc/jit/passes/quantization/register_packed_params.cpp", "torch/csrc/jit/python/update_graph_executor_opt.cpp", + "torch/csrc/jit/python/utf8_decoding_ignore.cpp", "torch/csrc/jit/runtime/argument_spec.cpp", "torch/csrc/jit/runtime/autodiff.cpp", "torch/csrc/jit/runtime/graph_executor.cpp", diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 1ec6a444e8c0e8..15c3e8e3ad387b 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -85,6 +85,7 @@ #include #include #include +#include #include #include #include @@ 
-993,7 +994,7 @@ void initJITBindings(PyObject* module) { #ifdef TORCH_ENABLE_LLVM return true; #else - return false; + return false; #endif }) .def( @@ -1124,27 +1125,30 @@ void initJITBindings(PyObject* module) { return retval; }) .def("_jit_pass_batch_mm", BatchMM) - .def("_jit_decay_packed_param_input_types", [](Graph& g) { - for (Value* i : g.inputs()) { - if (i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") || - i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") || - i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.LinearPackedParamsBase")) { - // Dummy CompleteTensorType to appease ONNX validator. - i->setType(TensorType::create( - at::kQInt8, - c10::kCPU, - std::vector{1}, - std::vector{1}, - c10::nullopt)); - } - } - }); + .def( + "_jit_decay_packed_param_input_types", + [](Graph& g) { + for (Value* i : g.inputs()) { + if (i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") || + i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") || + i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.LinearPackedParamsBase")) { + // Dummy CompleteTensorType to appease ONNX validator. 
+ i->setType(TensorType::create( + at::kQInt8, + c10::kCPU, + std::vector{1}, + std::vector{1}, + c10::nullopt)); + } + } + }) + .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore); // NB: This isn't actually used for regular PyTorch symbolic tracing; // XLA is what needs this diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 221753ddc3f8f0..8a83030e291b68 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -555,7 +556,13 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isBool()) { return py::cast(std::move(ivalue).toBool()); } else if (ivalue.isString()) { - return py::cast(std::move(ivalue).toStringRef()); + if (getUTF8DecodingIgnore()) { + std::string s = std::move(ivalue).toStringRef(); + PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore"); + return py::reinterpret_steal(pyObj); + } else { + return py::cast(std::move(ivalue).toStringRef()); + } } else if (ivalue.isList()) { auto list = std::move(ivalue).toList(); py::list t{list.size()}; diff --git a/torch/csrc/jit/python/utf8_decoding_ignore.cpp b/torch/csrc/jit/python/utf8_decoding_ignore.cpp new file mode 100644 index 00000000000000..406731c425f9fe --- /dev/null +++ b/torch/csrc/jit/python/utf8_decoding_ignore.cpp @@ -0,0 +1,16 @@ +#include + +namespace torch::jit { + +namespace { +thread_local bool kIgnore = false; +} + +void setUTF8DecodingIgnore(bool o) { + kIgnore = o; +} +bool getUTF8DecodingIgnore() { + return kIgnore; +} + +} // namespace torch::jit diff --git a/torch/csrc/jit/python/utf8_decoding_ignore.h b/torch/csrc/jit/python/utf8_decoding_ignore.h new file mode 100644 index 00000000000000..43ecfbcf0e208d --- /dev/null +++ b/torch/csrc/jit/python/utf8_decoding_ignore.h @@ -0,0 +1,8 @@ +#pragma once +#include +namespace torch { +namespace jit { +TORCH_API void setUTF8DecodingIgnore(bool o); +TORCH_API 
bool getUTF8DecodingIgnore(); +} // namespace jit +} // namespace torch