From de7494e32c0c76a928916c244d1cb985dc041685 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Sat, 30 Nov 2024 20:08:33 +0800 Subject: [PATCH 01/17] add --- .../transforms/tensorrt/trt_op_marker_pass.cc | 32 ++++++ python/paddle/tensorrt/impls/others.py | 104 ++++++++++++++++++ test/tensorrt/test_converter_others.py | 85 ++++++++++++++ 3 files changed, 221 insertions(+) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index e6887de9618de5..32efb704b525e8 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2094,6 +2094,37 @@ class AssignValueOpPattern } }; +class TemporalShiftOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op, + pir::PatternRewriter &rewriter) const override { + if (op->HasAttribute(kCanRunTrtAttr) && + op.attribute(kCanRunTrtAttr).data()) { + return false; + } +#if IS_TRT_VERSION_LT(8200) + VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2"; + return false; +#endif + if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) { + VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; + return false; + } + auto x = op.operand_source(0); + auto x_shape = pir::GetShapeFromValue(x); + if (x_shape.size() != 4) { + VLOG(3) << "The input and grid tensors must be shape tensors of rank 4 " + "when using TRT TemporalShift layer."; + return false; + } + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } +}; + class TrtOpMarkerPass : public pir::PatternRewritePass { public: TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {} @@ -2207,6 +2238,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass { ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); + ps.Add(std::make_unique(context)); return ps; } }; diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index 490709a6f06fa4..f4808d28b921cc 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -299,3 +299,107 @@ def share_data_converter(network, paddle_op, inputs): identity_layer = network.add_identity(x) return identity_layer.get_output(0) + + +@converter_registry.register("pd_op.temporal_shift", trt_version="8.x") +def temporal_shift_converter(network, paddle_op, inputs): + input_tensor = inputs[0] + shift_ratio = paddle_op.attrs().get("shift_ratio") + T = paddle_op.attrs().get("seg_num") + data_format = paddle_op.attrs().get("data_format", "NCHW") + + if data_format == "NHWC": + # Transpose input to [N, C, H, W] + transpose_layer = network.add_shuffle(input_tensor) + transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2]) + input_tensor = transpose_layer.get_output(0) + + input_dims = input_tensor.shape + C, H, W = input_dims[1], input_dims[2], input_dims[3] + + # Reshape input to [N, T, C, H, W] + reshape_layer = network.add_shuffle(input_tensor) + reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W]) + input_tensor = reshape_layer.get_output(0) + + # Pad input to [N, T + 2, C, H, W] + pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0]) + post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0]) + dims = 5 + zeros = add_1D_constant_layer(network, [0] * dims) + start = network.add_elementwise( + zeros, pre_pad, trt.ElementWiseOperation.SUB + ).get_output(0) + total_padding = network.add_elementwise( + pre_pad, post_pad, trt.ElementWiseOperation.SUM + ).get_output(0) + input_shape = trt_shape(network, input_tensor) + size = network.add_elementwise( + input_shape, total_padding, trt.ElementWiseOperation.SUM + ).get_output(0) + stride = [1] * dims + dummy = stride + + slice_layer = network.add_slice(input_tensor, dummy, dummy, stride) + slice_layer.set_input(1, start) + slice_layer.set_input(2, size) + + trt_version = trt.__version__.split('.') + if int(trt_version[0]) > 8 or (int(trt_version[0]) == 8 and int(trt_version[1]) >= 5): + slice_layer.mode = trt.SampleMode.FILL + else: + slice_layer.mode = trt.SliceMode.FILL + + slice_c = int(C * shift_ratio) + slice_c2 = int(C * shift_ratio * 2) + + slice_start1 = zeros + slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0]) + slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0]) + + slice_size_base = trt_shape(network, input_tensor) + sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0]) + sub_size2 = add_1D_constant_layer(network, [0, 0, C + slice_c - slice_c2, 0, 0]) + sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0]) + + slice_size1 = network.add_elementwise( + slice_size_base, sub_size1, trt.ElementWiseOperation.SUB + ).get_output(0) + slice_size2 = network.add_elementwise( + slice_size_base, sub_size2, trt.ElementWiseOperation.SUB + ).get_output(0) + slice_size3 = network.add_elementwise( + slice_size_base, sub_size3, trt.ElementWiseOperation.SUB + ).get_output(0) + + slice1_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice1_layer.set_input(1, slice_start1) + slice1_layer.set_input(2, slice_size1) + slice2_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice2_layer.set_input(1, slice_start2) + slice2_layer.set_input(2, slice_size2) + slice3_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice3_layer.set_input(1, slice_start3) + slice3_layer.set_input(2, slice_size3) + + if slice_c == 0: + concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)] + concat_layer = network.add_concatenation(concat_inputs) + concat_layer.axis = 2 + else: + concat_inputs = [slice1_layer.get_output(0), slice2_layer.get_output(0), slice3_layer.get_output(0)] + concat_layer = network.add_concatenation(concat_inputs) + concat_layer.axis = 2 + + # Reshape output to [N*T,C,H,W] + reshape_layer = network.add_shuffle(concat_layer.get_output(0)) + reshape_layer.reshape_dims = trt.Dims(inputs[0].shape) + + if data_format == "NHWC": + transpose_layer = network.add_shuffle(reshape_layer.get_output(0)) + transpose_layer.first_transpose = trt.Permutation([0, 2, 3, 1]) + output_tensor = transpose_layer.get_output(0) + else: + output_tensor = reshape_layer.get_output(0) + + return output_tensor diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index f9f7171e647609..b3fceb20158ca9 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -394,5 +394,90 @@ def test_trt_result(self): self.check_trt_result() +class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.2, + "data_format": "NCHW" + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [8, 9, 7, 7]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 4, + "shift_ratio": 0.2, + "data_format": "NCHW" + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [4, 9, 7, 7]} + self.max_shape = {"x": [8, 9, 7, 7]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.4, + "data_format": "NCHW" + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [8, 9, 7, 7]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.2, + "data_format": "NHWC" + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [8, 9, 7, 7]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.2, + "data_format": "NCHW", + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [10, 9, 7, 7]} + + def test_trt_result(self): + self.check_trt_result() + + if __name__ == '__main__': unittest.main() From f57496c51aed29f73130615bb679a08ae6e6eabd Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Sat, 30 Nov 2024 22:21:40 +0800 Subject: [PATCH 02/17] fix codestyle --- .../transforms/tensorrt/trt_op_marker_pass.cc | 7 ++--- python/paddle/tensorrt/impls/others.py | 26 ++++++++++++++----- test/tensorrt/test_converter_others.py | 8 +++--- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 32efb704b525e8..f170549ba00a47 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2097,14 +2097,15 @@ class AssignValueOpPattern class TemporalShiftOpPattern : public pir::OpRewritePattern { public: - using pir::OpRewritePattern::OpRewritePattern; + using pir::OpRewritePattern< + paddle::dialect::TemporalShiftOp>::OpRewritePattern; bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op, pir::PatternRewriter &rewriter) const override { if (op->HasAttribute(kCanRunTrtAttr) && op.attribute(kCanRunTrtAttr).data()) { - return false; - } + return false; + } #if IS_TRT_VERSION_LT(8200) VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2"; return false; diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index f4808d28b921cc..2682a1c6bd978b 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -345,7 +345,9 @@ def temporal_shift_converter(network, paddle_op, inputs): slice_layer.set_input(2, size) trt_version = trt.__version__.split('.') - if int(trt_version[0]) > 8 or (int(trt_version[0]) == 8 and int(trt_version[1]) >= 5): + if int(trt_version[0]) > 8 or ( + int(trt_version[0]) == 8 and int(trt_version[1]) >= 5 + ): slice_layer.mode = trt.SampleMode.FILL else: slice_layer.mode = trt.SliceMode.FILL @@ -359,7 +361,9 @@ def temporal_shift_converter(network, paddle_op, inputs): slice_size_base = trt_shape(network, input_tensor) sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0]) - sub_size2 = add_1D_constant_layer(network, [0, 0, C + slice_c - slice_c2, 0, 0]) + sub_size2 = add_1D_constant_layer( + network, [0, 0, C + slice_c - slice_c2, 0, 0] + ) sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0]) slice_size1 = network.add_elementwise( @@ -372,13 +376,19 @@ def temporal_shift_converter(network, paddle_op, inputs): slice_size_base, sub_size3, trt.ElementWiseOperation.SUB ).get_output(0) - slice1_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice1_layer = network.add_slice( + slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride + ) slice1_layer.set_input(1, slice_start1) slice1_layer.set_input(2, slice_size1) - slice2_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice2_layer = network.add_slice( + slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride + ) slice2_layer.set_input(1, slice_start2) slice2_layer.set_input(2, slice_size2) - slice3_layer = network.add_slice(slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride) + slice3_layer = network.add_slice( + slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride + ) slice3_layer.set_input(1, slice_start3) slice3_layer.set_input(2, slice_size3) @@ -387,7 +397,11 @@ def temporal_shift_converter(network, paddle_op, inputs): concat_layer = network.add_concatenation(concat_inputs) concat_layer.axis = 2 else: - concat_inputs = [slice1_layer.get_output(0), slice2_layer.get_output(0), slice3_layer.get_output(0)] + concat_inputs = [ + slice1_layer.get_output(0), + slice2_layer.get_output(0), + slice3_layer.get_output(0), + ] concat_layer = network.add_concatenation(concat_inputs) concat_layer.axis = 2 diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index b3fceb20158ca9..b5f756e258aa9f 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -401,7 +401,7 @@ def setUp(self): "x": np.random.random([4, 9, 7, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.2, - "data_format": "NCHW" + "data_format": "NCHW", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} @@ -418,7 +418,7 @@ def setUp(self): "x": np.random.random([4, 9, 7, 7]).astype(np.float32), "seg_num": 4, "shift_ratio": 0.2, - "data_format": "NCHW" + "data_format": "NCHW", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [4, 9, 7, 7]} @@ -435,7 +435,7 @@ def setUp(self): "x": np.random.random([4, 9, 7, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.4, - "data_format": "NCHW" + "data_format": "NCHW", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} @@ -452,7 +452,7 @@ def setUp(self): "x": np.random.random([4, 9, 7, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.2, - "data_format": "NHWC" + "data_format": "NHWC", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} From a70de787ccbfa80a2a971c1c86bc4fecbc2b55c2 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Thu, 5 Dec 2024 18:22:23 +0800 Subject: [PATCH 03/17] update --- python/paddle/tensorrt/impls/others.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index 2682a1c6bd978b..e81bb6f04f0bae 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -26,6 +26,7 @@ trt_concat, trt_prod, trt_shape, + trt_sub, trt_sum, ) from paddle.tensorrt.register import converter_registry @@ -327,16 +328,10 @@ def temporal_shift_converter(network, paddle_op, inputs): post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0]) dims = 5 zeros = add_1D_constant_layer(network, [0] * dims) - start = network.add_elementwise( - zeros, pre_pad, trt.ElementWiseOperation.SUB - ).get_output(0) - total_padding = network.add_elementwise( - pre_pad, post_pad, trt.ElementWiseOperation.SUM - ).get_output(0) + start = trt_sum(network, zeros, pre_pad) + total_padding = trt_sum(network, pre_pad, post_pad) input_shape = trt_shape(network, input_tensor) - size = network.add_elementwise( - input_shape, total_padding, trt.ElementWiseOperation.SUM - ).get_output(0) + size = trt_sum(network, input_shape, total_padding) stride = [1] * dims dummy = stride @@ -366,15 +361,9 @@ def temporal_shift_converter(network, paddle_op, inputs): ) sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0]) - slice_size1 = network.add_elementwise( - slice_size_base, sub_size1, trt.ElementWiseOperation.SUB - ).get_output(0) - slice_size2 = network.add_elementwise( - slice_size_base, sub_size2, trt.ElementWiseOperation.SUB - ).get_output(0) - slice_size3 = network.add_elementwise( - slice_size_base, sub_size3, trt.ElementWiseOperation.SUB - ).get_output(0) + slice_size1 = trt_sub(network, slice_size_base, sub_size1) + slice_size2 = trt_sub(network, slice_size_base, sub_size2) + slice_size3 = trt_sub(network, slice_size_base, sub_size3) slice1_layer = network.add_slice( slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride From 1187a2f4757d80f9128ae445af9a893c7e31c39f Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Thu, 12 Dec 2024 10:31:45 +0800 Subject: [PATCH 04/17] Update trt_op_marker_pass.cc --- paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index f170549ba00a47..4ab30a2f374c20 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2106,10 +2106,6 @@ class TemporalShiftOpPattern op.attribute(kCanRunTrtAttr).data()) { return false; } -#if IS_TRT_VERSION_LT(8200) - VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2"; - return false; -#endif if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) { VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; return false; From cf49db556bdf30ae71f198db22a7a621ca91c5ab Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Mon, 16 Dec 2024 15:12:00 +0800 Subject: [PATCH 05/17] add_fp16 --- python/paddle/tensorrt/impls/others.py | 16 ++++++++-------- test/tensorrt/test_converter_others.py | 25 ++++++++++++++++++++----- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index e81bb6f04f0bae..e9fcc64e1881a8 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -328,7 +328,7 @@ def temporal_shift_converter(network, paddle_op, inputs): post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0]) dims = 5 zeros = add_1D_constant_layer(network, [0] * dims) - start = trt_sum(network, zeros, pre_pad) + start = trt_sub(network, zeros, pre_pad) total_padding = trt_sum(network, pre_pad, post_pad) input_shape = trt_shape(network, input_tensor) size = trt_sum(network, input_shape, total_padding) @@ -381,8 +381,8 @@ def temporal_shift_converter(network, paddle_op, inputs): slice3_layer.set_input(1, slice_start3) slice3_layer.set_input(2, slice_size3) + concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)] if slice_c == 0: - concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)] concat_layer = network.add_concatenation(concat_inputs) concat_layer.axis = 2 else: @@ -395,14 +395,14 @@ def temporal_shift_converter(network, paddle_op, inputs): concat_layer.axis = 2 # Reshape output to [N*T,C,H,W] - reshape_layer = network.add_shuffle(concat_layer.get_output(0)) - reshape_layer.reshape_dims = trt.Dims(inputs[0].shape) + reshape_layer3 = network.add_shuffle(concat_layer.get_output(0)) + reshape_layer3.reshape_dims = trt.Dims(inputs[0].shape) if data_format == "NHWC": - transpose_layer = network.add_shuffle(reshape_layer.get_output(0)) - transpose_layer.first_transpose = trt.Permutation([0, 2, 3, 1]) - output_tensor = transpose_layer.get_output(0) + transpose_layer2 = network.add_shuffle(reshape_layer3.get_output(0)) + transpose_layer2.first_transpose = trt.Permutation([0, 2, 3, 1]) + output_tensor = transpose_layer2.get_output(0) else: - output_tensor = reshape_layer.get_output(0) + output_tensor = reshape_layer3.get_output(0) return output_tensor diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index b5f756e258aa9f..86d06a216cfaff 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -407,7 +407,10 @@ def setUp(self): self.min_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} - def test_trt_result(self): + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): self.check_trt_result() @@ -424,7 +427,10 @@ def setUp(self): self.min_shape = {"x": [4, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} - def test_trt_result(self): + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): self.check_trt_result() @@ -441,7 +447,10 @@ def setUp(self): self.min_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} - def test_trt_result(self): + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): self.check_trt_result() @@ -458,7 +467,10 @@ def setUp(self): self.min_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} - def test_trt_result(self): + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): self.check_trt_result() @@ -475,7 +487,10 @@ def setUp(self): self.min_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7]} - def test_trt_result(self): + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): self.check_trt_result() From 1f5b7ea92c14f6fb715eba4717a1a8bf7c2059ad Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 18 Dec 2024 22:13:57 +0800 Subject: [PATCH 06/17] Update trt_op_marker_pass.cc --- .../transforms/tensorrt/trt_op_marker_pass.cc | 65 ------------------- 1 file changed, 65 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 06d02a19da2dd0..c65e6e2f721f3d 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2078,69 +2078,6 @@ class OneHotOpPattern } }; -bool CheckStaticShape(const pir::Operation *op) { - std::vector vec_shape; - auto shape_attr = op->attribute("shape").dyn_cast(); - for (const auto &attr : shape_attr.AsVector()) { - vec_shape.push_back(attr.dyn_cast().data()); - } - for (int32_t dim : vec_shape) { - if (dim == -1) { - VLOG(3) << "pd_op.assign_value_ or pd_op.assign_value cannot support " - "dynamic shape"; - return false; - } - } - int shape_size = vec_shape.size(); - int values_count = - op->attribute("values").dyn_cast().size(); - if (shape_size != values_count) { - VLOG(3) << "pd_op.assign_value or pd_op.assign_value shape size is not " - "equal to the values size"; - return false; - } - return true; -} - -class AssignValue_OpPattern - : public pir::OpRewritePattern { - public: - using pir::OpRewritePattern< - paddle::dialect::AssignValue_Op>::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::AssignValue_Op op, - pir::PatternRewriter &rewriter) const override { - if (op->HasAttribute(kCanRunTrtAttr) && - op->attribute(kCanRunTrtAttr).data()) { - return false; - } - if (!CheckStaticShape(op)) { - return false; - } - - op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); - return true; - } -}; - -class AssignValueOpPattern - : public pir::OpRewritePattern { - public: - using pir::OpRewritePattern::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::AssignValueOp op, - pir::PatternRewriter &rewriter) const override { - if (op->HasAttribute(kCanRunTrtAttr) && - op->attribute(kCanRunTrtAttr).data()) { - return false; - } - if (!CheckStaticShape(op)) { - return false; - } - - op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); - return true; - } -}; - class TemporalShiftOpPattern : public pir::OpRewritePattern { public: @@ -2290,8 +2227,6 @@ class TrtOpMarkerPass : public pir::PatternRewritePass { ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); - ps.Add(std::make_unique(context)); - ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); return ps; } From b10c6a58f34ef052e639ba14e22c611bf2881574 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Tue, 24 Dec 2024 23:43:16 +0800 Subject: [PATCH 07/17] update --- .../transforms/tensorrt/trt_op_marker_pass.cc | 11 ++-- test/tensorrt/test_converter_others.py | 54 +++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 269fd082f2eaf0..746160d315488a 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2084,15 +2084,18 @@ class TemporalShiftOpPattern op.attribute(kCanRunTrtAttr).data()) { return false; } - if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) { - VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; + if (!op->HasAttribute("shift_ratio")) { + VLOG(3) << "temporal shift need attributes : shift_ratio"; + return false; + } + if (!op->HasAttribute("seg_num")) { + VLOG(3) << "temporal shift need attributes : seg_num"; return false; } auto x = op.operand_source(0); auto x_shape = pir::GetShapeFromValue(x); if (x_shape.size() != 4) { - VLOG(3) << "The input and grid tensors must be shape tensors of rank 4 " - "when using TRT TemporalShift layer."; + VLOG(3) << "The input and grid tensors must be shape tensors of rank 4."; return false; } diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index f3138c053ca9d7..2ad77883980749 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -494,5 +494,59 @@ def test_trt_result_fp32(self): self.check_trt_result() +class TestTemporalShiftTRTPatternDifferentDataFormat1(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.2, + "data_format": "NHWC", + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [10, 9, 7, 7]} + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): + self.check_trt_result() + + +class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "seg_num": 2, + "#": 0.2, + "data_format": "NHWC", + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [10, 9, 7, 7]} + + def test_trt_result(self): + self.check_marker(expected_result=False) + + +class TestTemporalShiftTRTPatternError2(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 9, 7, 7]).astype(np.float32), + "#": 2, + "shift_ratio": 0.2, + "data_format": "NHWC", + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 9, 7, 7]} + self.max_shape = {"x": [10, 9, 7, 7]} + + def test_trt_result(self): + self.check_marker(expected_result=False) + + if __name__ == '__main__': unittest.main() From 57a1289440c8faa5ee558a18cda1af5b3892f525 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 25 Dec 2024 09:52:58 +0800 Subject: [PATCH 08/17] Update trt_op_marker_pass.cc --- paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 746160d315488a..afa904178da472 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2099,6 +2099,11 @@ class TemporalShiftOpPattern return false; } + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } +}; + class InstanceNormOpPattern : public pir::OpRewritePattern { public: From f9b97506d3774a0a35edbc1ee32d6b4fca012cb1 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 25 Dec 2024 15:38:03 +0800 Subject: [PATCH 09/17] fix --- python/paddle/tensorrt/impls/others.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index e57a348f46fd7f..b9f101a8a0eaff 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -307,9 +307,9 @@ def share_data_converter(network, paddle_op, inputs): @converter_registry.register("pd_op.temporal_shift", trt_version="8.x") def temporal_shift_converter(network, paddle_op, inputs): input_tensor = inputs[0] - shift_ratio = paddle_op.attrs().get("shift_ratio") - T = paddle_op.attrs().get("seg_num") - data_format = paddle_op.attrs().get("data_format", "NCHW") + shift_ratio = paddle_op.attrs()["shift_ratio"] + T = paddle_op.attrs()["seg_num"] + data_format = paddle_op.attrs()["data_format"] if data_format == "NHWC": # Transpose input to [N, C, H, W] From e2183120b8e812d5a7a3e0403abe170552d12443 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Fri, 3 Jan 2025 10:59:03 +0800 Subject: [PATCH 10/17] fix --- .../transforms/tensorrt/trt_op_marker_pass.cc | 8 +--- python/paddle/tensorrt/impls/others.py | 4 +- test/tensorrt/test_converter_others.py | 41 +++++-------------- 3 files changed, 15 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index afa904178da472..e623441d70701d 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2084,12 +2084,8 @@ class TemporalShiftOpPattern op.attribute(kCanRunTrtAttr).data()) { return false; } - if (!op->HasAttribute("shift_ratio")) { - VLOG(3) << "temporal shift need attributes : shift_ratio"; - return false; - } - if (!op->HasAttribute("seg_num")) { - VLOG(3) << "temporal shift need attributes : seg_num"; + if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) { + VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; return false; } auto x = op.operand_source(0); diff --git a/python/paddle/tensorrt/impls/others.py b/python/paddle/tensorrt/impls/others.py index b9f101a8a0eaff..74f4980735353d 100644 --- a/python/paddle/tensorrt/impls/others.py +++ b/python/paddle/tensorrt/impls/others.py @@ -309,7 +309,7 @@ def temporal_shift_converter(network, paddle_op, inputs): input_tensor = inputs[0] shift_ratio = paddle_op.attrs()["shift_ratio"] T = paddle_op.attrs()["seg_num"] - data_format = paddle_op.attrs()["data_format"] + data_format = paddle_op.attrs().get("data_format", "NCHW") if data_format == "NHWC": # Transpose input to [N, C, H, W] @@ -398,7 +398,7 @@ def temporal_shift_converter(network, paddle_op, inputs): # Reshape output to [N*T,C,H,W] reshape_layer3 = network.add_shuffle(concat_layer.get_output(0)) - reshape_layer3.reshape_dims = trt.Dims(inputs[0].shape) + reshape_layer3.reshape_dims = trt.Dims([-1, C, H, W]) if data_format == "NHWC": transpose_layer2 = network.add_shuffle(reshape_layer3.get_output(0)) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index 2ad77883980749..77e4148c33e222 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -461,6 +461,7 @@ def setUp(self): "x": np.random.random([4, 9, 7, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.2, + "name": None, "data_format": "NHWC", } self.program_config = {"feed_list": ["x"]} @@ -494,58 +495,38 @@ def test_trt_result_fp32(self): self.check_trt_result() -class TestTemporalShiftTRTPatternDifferentDataFormat1(TensorRTBaseTest): - def setUp(self): - self.python_api = paddle.nn.functional.temporal_shift - self.api_args = { - "x": np.random.random([4, 9, 7, 7]).astype(np.float32), - "seg_num": 2, - "shift_ratio": 0.2, - "data_format": "NHWC", - } - self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [2, 9, 7, 7]} - self.max_shape = {"x": [10, 9, 7, 7]} - - def test_trt_result_fp16(self): - self.check_trt_result(precision_mode="fp16") - - def test_trt_result_fp32(self): - self.check_trt_result() - - class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): def setUp(self): self.python_api = paddle.nn.functional.temporal_shift self.api_args = { "x": np.random.random([4, 9, 7, 7]).astype(np.float32), - "seg_num": 2, - "#": 0.2, - "data_format": "NHWC", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7]} def test_trt_result(self): - self.check_marker(expected_result=False) + with self.assertRaises(TypeError) as context: + self.check_marker(expected_result=False) class TestTemporalShiftTRTPatternError2(TensorRTBaseTest): def setUp(self): self.python_api = paddle.nn.functional.temporal_shift self.api_args = { - "x": np.random.random([4, 9, 7, 7]).astype(np.float32), - "#": 2, + "x": np.random.random([4, 9, 7, 7, 7]).astype(np.float32), + "seg_num": 2, "shift_ratio": 0.2, - "data_format": "NHWC", + "name": None, + "data_format": "NCHW", } self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [2, 9, 7, 7]} - self.max_shape = {"x": [10, 9, 7, 7]} + self.min_shape = {"x": [2, 9, 7, 7, 7]} + self.max_shape = {"x": [10, 9, 7, 7, 7]} def test_trt_result(self): - self.check_marker(expected_result=False) + with self.assertRaises(ValueError) as context: + self.check_marker(expected_result=False) if __name__ == '__main__': From 9591aa91415954ec16e3b7defc844b8b1c403212 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Fri, 3 Jan 2025 19:24:48 +0800 Subject: [PATCH 11/17] fix --- test/tensorrt/test_converter_others.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index 77e4148c33e222..a48148a4e31202 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -495,9 +495,17 @@ def test_trt_result_fp32(self): self.check_trt_result() +def wrapper_temporal_shift(x): + return paddle.nn.functional.temporal_shift(x=x, seg_num=2, shift_ratio=0.2) + + +def wrapper_temporal_shift_2(x, seg_num, shift_ratio): + return paddle.nn.functional.temporal_shift(x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio) + + class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): def setUp(self): - self.python_api = paddle.nn.functional.temporal_shift + self.python_api = wrapper_temporal_shift self.api_args = { "x": np.random.random([4, 9, 7, 7]).astype(np.float32), } @@ -506,27 +514,23 @@ def setUp(self): self.max_shape = {"x": [10, 9, 7, 7]} def test_trt_result(self): - with self.assertRaises(TypeError) as context: - self.check_marker(expected_result=False) + self.check_marker(expected_result=False) class TestTemporalShiftTRTPatternError2(TensorRTBaseTest): def setUp(self): - self.python_api = paddle.nn.functional.temporal_shift + self.python_api = wrapper_temporal_shift_2 self.api_args = { "x": np.random.random([4, 9, 7, 7, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.2, - "name": None, - "data_format": "NCHW", } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7, 7]} def test_trt_result(self): - with self.assertRaises(ValueError) as context: - self.check_marker(expected_result=False) + self.check_marker(expected_result=False) if __name__ == '__main__': From 227d6d74096df3e0ea68f757322fd263bc529496 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Fri, 3 Jan 2025 19:28:07 +0800 Subject: [PATCH 12/17] fix codestyle --- test/tensorrt/test_converter_others.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index a48148a4e31202..cd908626843442 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -500,7 +500,9 @@ def wrapper_temporal_shift(x): def wrapper_temporal_shift_2(x, seg_num, shift_ratio): - return paddle.nn.functional.temporal_shift(x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio) + return paddle.nn.functional.temporal_shift( + x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio + ) class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): From ee27104be6294cf293eeab9d03eee7a654fce5cc Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Tue, 7 Jan 2025 11:46:20 +0800 Subject: [PATCH 13/17] add_test --- test/tensorrt/test_converter_others.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index cd908626843442..ae40723d1b6afb 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -414,6 +414,26 @@ def test_trt_result_fp32(self): self.check_trt_result() +class TestTemporalShiftTRTPatternZeroSlice(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.nn.functional.temporal_shift + self.api_args = { + "x": np.random.random([4, 2, 7, 7]).astype(np.float32), + "seg_num": 2, + "shift_ratio": 0.2, + "data_format": "NCHW", + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [2, 2, 7, 7]} + self.max_shape = {"x": [8, 2, 7, 7]} + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + def test_trt_result_fp32(self): + self.check_trt_result() + + class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest): def setUp(self): self.python_api = paddle.nn.functional.temporal_shift From 6a90a0b96ab7f238f6d2c01d9fd1a6092447c969 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Tue, 7 Jan 2025 17:42:37 +0800 Subject: [PATCH 14/17] add_optshape --- test/tensorrt/test_converter_others.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index ae40723d1b6afb..75b620e5bff2cc 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -405,6 +405,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} def test_trt_result_fp16(self): @@ -425,6 +426,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 2, 7, 7]} + self.opt_shape = {"x": [2, 2, 7, 7]} self.max_shape = {"x": [8, 2, 7, 7]} def test_trt_result_fp16(self): @@ -445,6 +447,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [4, 9, 7, 7]} + self.opt_shape = {"x": [4, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} def test_trt_result_fp16(self): @@ -465,6 +468,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} def test_trt_result_fp16(self): @@ -486,6 +490,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [8, 9, 7, 7]} def test_trt_result_fp16(self): @@ -506,6 +511,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7]} def test_trt_result_fp16(self): @@ -533,6 +539,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7]} def test_trt_result(self): @@ -549,6 +556,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [2, 9, 7, 7, 7]} + self.opt_shape = {"x": [2, 9, 7, 7, 7]} self.max_shape = {"x": [10, 9, 7, 7, 7]} def test_trt_result(self): From 9087b4c526f40e344842458187db15aa5d4cdc7a Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 8 Jan 2025 10:42:17 +0800 Subject: [PATCH 15/17] update --- paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 5 +++-- test/tensorrt/test_converter_others.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index c7593ddf5ff553..17102fb9b47a70 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2174,8 +2174,9 @@ class TemporalShiftOpPattern VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; return false; } - auto x = op.operand_source(0); - auto x_shape = pir::GetShapeFromValue(x); + pir::Value x = op.operand_source(0); + auto x_type = x.type().dyn_cast(); + auto x_shape = x_type.dims(); if (x_shape.size() != 4) { VLOG(3) << "The input and grid tensors must be shape tensors of rank 4."; return false; diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index 75b620e5bff2cc..125e0de9bad0e5 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -526,9 +526,7 @@ def wrapper_temporal_shift(x): def wrapper_temporal_shift_2(x, seg_num, shift_ratio): - return paddle.nn.functional.temporal_shift( - x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio - ) + return paddle.nn.functional.temporal_shift(x=x, seg_num=seg_num, shift_ratio=shift_ratio) class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): From f05046a1e557014ff01164d57a79c63b19d40dfe Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 8 Jan 2025 10:50:30 +0800 Subject: [PATCH 16/17] Update test_converter_others.py --- test/tensorrt/test_converter_others.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index 125e0de9bad0e5..cdb5b20e5f53c2 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -526,7 +526,9 @@ def wrapper_temporal_shift(x): def wrapper_temporal_shift_2(x, seg_num, shift_ratio): - return paddle.nn.functional.temporal_shift(x=x, seg_num=seg_num, shift_ratio=shift_ratio) + return paddle.nn.functional.temporal_shift( + x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio + ) class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): @@ -548,14 +550,14 @@ class TestTemporalShiftTRTPatternError2(TensorRTBaseTest): def setUp(self): self.python_api = wrapper_temporal_shift_2 self.api_args = { - "x": np.random.random([4, 9, 7, 7, 7]).astype(np.float32), + "x": np.random.random([4, 9, 7]).astype(np.float32), "seg_num": 2, "shift_ratio": 0.2, } self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [2, 9, 7, 7, 7]} - self.opt_shape = {"x": [2, 9, 7, 7, 7]} - self.max_shape = {"x": [10, 9, 7, 7, 7]} + self.min_shape = {"x": [2, 9, 7]} + self.opt_shape = {"x": [2, 9, 7]} + self.max_shape = {"x": [10, 9, 7]} def test_trt_result(self): self.check_marker(expected_result=False) From 280e7454dcd37c891d849c19bff0afae5382f0b1 Mon Sep 17 00:00:00 2001 From: Junjie Zhang <1356732652@qq.com> Date: Wed, 8 Jan 2025 11:41:35 +0800 Subject: [PATCH 17/17] delete size --- .../transforms/tensorrt/trt_op_marker_pass.cc | 8 ------- test/tensorrt/test_converter_others.py | 23 ------------------- 2 files changed, 31 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 2968d24534107b..e36a1d8c0d9649 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -2196,14 +2196,6 @@ class TemporalShiftOpPattern VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num"; return false; } - pir::Value x = op.operand_source(0); - auto x_type = x.type().dyn_cast(); - auto x_shape = x_type.dims(); - if (x_shape.size() != 4) { - VLOG(3) << "The input and grid tensors must be shape tensors of rank 4."; - return false; - } - op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); return true; } diff --git a/test/tensorrt/test_converter_others.py b/test/tensorrt/test_converter_others.py index d0910c0f0a5b98..caa4aa3fa2521f 100644 --- a/test/tensorrt/test_converter_others.py +++ b/test/tensorrt/test_converter_others.py @@ -537,12 +537,6 @@ def wrapper_temporal_shift(x): return paddle.nn.functional.temporal_shift(x=x, seg_num=2, shift_ratio=0.2) -def wrapper_temporal_shift_2(x, seg_num, shift_ratio): - return paddle.nn.functional.temporal_shift( - x=paddle.randn([4, 9, 7, 7]), seg_num=seg_num, shift_ratio=shift_ratio - ) - - class TestTemporalShiftTRTPatternError1(TensorRTBaseTest): def setUp(self): self.python_api = wrapper_temporal_shift @@ -558,23 +552,6 @@ def test_trt_result(self): self.check_marker(expected_result=False) -class TestTemporalShiftTRTPatternError2(TensorRTBaseTest): - def setUp(self): - self.python_api = wrapper_temporal_shift_2 - self.api_args = { - "x": np.random.random([4, 9, 7]).astype(np.float32), - "seg_num": 2, - "shift_ratio": 0.2, - } - self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [2, 9, 7]} - self.opt_shape = {"x": [2, 9, 7]} - self.max_shape = {"x": [10, 9, 7]} - - def test_trt_result(self): - self.check_marker(expected_result=False) - - def affine_channel(x, scale_shape, bias_shape, layout): scale = paddle.static.create_parameter( shape=scale_shape, dtype='float32', name="scale"