diff --git a/build_variables.bzl b/build_variables.bzl index 20afef7172017..0f2ee809b5869 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -337,6 +337,7 @@ core_sources_full_mobile_no_backend_interface_xplat = [ "torch/csrc/jit/passes/quantization/fusion_passes.cpp", "torch/csrc/jit/passes/quantization/register_packed_params.cpp", "torch/csrc/jit/python/update_graph_executor_opt.cpp", + "torch/csrc/jit/python/utf8_decoding_ignore.cpp", "torch/csrc/jit/runtime/argument_spec.cpp", "torch/csrc/jit/runtime/autodiff.cpp", "torch/csrc/jit/runtime/graph_executor.cpp", diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 1ec6a444e8c0e..15c3e8e3ad387 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -85,6 +85,7 @@ #include #include #include +#include #include #include #include @@ -993,7 +994,7 @@ void initJITBindings(PyObject* module) { #ifdef TORCH_ENABLE_LLVM return true; #else - return false; + return false; #endif }) .def( @@ -1124,27 +1125,30 @@ void initJITBindings(PyObject* module) { return retval; }) .def("_jit_pass_batch_mm", BatchMM) - .def("_jit_decay_packed_param_input_types", [](Graph& g) { - for (Value* i : g.inputs()) { - if (i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") || - i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") || - i->type() == - getCustomClass( - "__torch__.torch.classes.quantized.LinearPackedParamsBase")) { - // Dummy CompleteTensorType to appease ONNX validator. 
- i->setType(TensorType::create( - at::kQInt8, - c10::kCPU, - std::vector{1}, - std::vector{1}, - c10::nullopt)); - } - } - }); + .def( + "_jit_decay_packed_param_input_types", + [](Graph& g) { + for (Value* i : g.inputs()) { + if (i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") || + i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") || + i->type() == + getCustomClass( + "__torch__.torch.classes.quantized.LinearPackedParamsBase")) { + // Dummy CompleteTensorType to appease ONNX validator. + i->setType(TensorType::create( + at::kQInt8, + c10::kCPU, + std::vector{1}, + std::vector{1}, + c10::nullopt)); + } + } + }) + .def("_jit_set_utf8_decoding_ignore", &setUTF8DecodingIgnore); // NB: This isn't actually used for regular PyTorch symbolic tracing; // XLA is what needs this diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 221753ddc3f8f..8a83030e291b6 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -555,7 +556,13 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isBool()) { return py::cast(std::move(ivalue).toBool()); } else if (ivalue.isString()) { - return py::cast(std::move(ivalue).toStringRef()); + if (getUTF8DecodingIgnore()) { + std::string s = std::move(ivalue).toStringRef(); + PyObject* pyObj = PyUnicode_DecodeUTF8(s.data(), s.length(), "ignore"); + return py::reinterpret_steal(pyObj); + } else { + return py::cast(std::move(ivalue).toStringRef()); + } } else if (ivalue.isList()) { auto list = std::move(ivalue).toList(); py::list t{list.size()}; diff --git a/torch/csrc/jit/python/utf8_decoding_ignore.cpp b/torch/csrc/jit/python/utf8_decoding_ignore.cpp new file mode 100644 index 0000000000000..406731c425f9f --- /dev/null +++ b/torch/csrc/jit/python/utf8_decoding_ignore.cpp @@ -0,0 +1,16 @@ +#include + 
+namespace torch::jit {
+
+namespace {
+thread_local bool ignore_decoding_errors = false; // one flag per thread
+} // namespace
+
+void setUTF8DecodingIgnore(bool o) {
+  ignore_decoding_errors = o;
+}
+bool getUTF8DecodingIgnore() {
+  return ignore_decoding_errors;
+}
+
+} // namespace torch::jit
diff --git a/torch/csrc/jit/python/utf8_decoding_ignore.h b/torch/csrc/jit/python/utf8_decoding_ignore.h
new file mode 100644
index 0000000000000..43ecfbcf0e208
--- /dev/null
+++ b/torch/csrc/jit/python/utf8_decoding_ignore.h
@@ -0,0 +1,17 @@
+#pragma once
+#include <torch/csrc/Export.h>
+
+namespace torch {
+namespace jit {
+
+// Sets the calling thread's UTF-8 decoding mode for IValue -> Python string
+// conversion. The flag is thread_local and starts out false on every thread;
+// while it is true, toPyObject decodes strings with
+// PyUnicode_DecodeUTF8(..., "ignore") instead of py::cast.
+TORCH_API void setUTF8DecodingIgnore(bool o);
+
+// Returns the calling thread's current setting (false unless set).
+TORCH_API bool getUTF8DecodingIgnore();
+
+} // namespace jit
+} // namespace torch