From 28254487791d3e81516d64e30984c7193850febb Mon Sep 17 00:00:00 2001 From: KalaivaniMCW Date: Thu, 30 Jan 2025 22:55:06 +0000 Subject: [PATCH] #13856: update gelu_bw ops on bf8b limitations --- .../binary_backward/binary_backward_pybind.hpp | 11 ++++++++--- .../eltwise/unary_backward/unary_backward_pybind.hpp | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp index 5a5aa125598..9dac2ad46fb 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp @@ -312,7 +312,8 @@ void bind_binary_backward_bias_gelu( const std::string& parameter_b_doc, string parameter_b_value, const std::string& description, - const std::string& supported_dtype = "BFLOAT16") { + const std::string& supported_dtype = "BFLOAT16", + const std::string_view note = "") { auto doc = fmt::format( R"doc( @@ -345,6 +346,8 @@ void bind_binary_backward_bias_gelu( bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT + {9} + Example: >>> grad_tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device) @@ -361,7 +364,8 @@ void bind_binary_backward_bias_gelu( parameter_b_doc, parameter_b_value, description, - supported_dtype); + supported_dtype, + note); bind_registered_operation( module, @@ -1317,7 +1321,8 @@ void py_module(py::module& module) { "none", R"doc(Performs backward operations for bias_gelu on :attr:`input_tensor_a` and :attr:`input_tensor_b` or :attr:`input_tensor` and :attr:`bias`, with given :attr:`grad_tensor` using given :attr:`approximate` mode. :attr:`approximate` mode can be 'none', 'tanh'.)doc", - R"doc(BFLOAT16, BFLOAT8_B)doc"); + R"doc(BFLOAT16)doc", + R"doc(For more details about BFLOAT8_B, refer to the `BFLOAT8_B limitations <../tensor.html#limitation-of-bfloat8-b>`_.)doc"); } } // namespace binary_backward diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index d8206cf4eb9..d488b5b8f28 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -1376,7 +1376,9 @@ void py_module(py::module& module) { "Approximation type", "none", R"doc(Performs backward operations for gelu on :attr:`input_tensor`, with given :attr:`grad_tensor` using given :attr:`approximate` mode. - :attr:`approximate` mode can be 'none', 'tanh'.)doc"); + :attr:`approximate` mode can be 'none', 'tanh'.)doc", + R"doc(BFLOAT16)doc", + R"doc(For more details about BFLOAT8_B, refer to the `BFLOAT8_B limitations <../tensor.html#limitation-of-bfloat8-b>`_.)doc"); detail::bind_unary_backward_unary_optional_float( module,