【SCU】【Paddle TensorRT No.57】Add pd_op.temporal_shift converter #69848

Open · wants to merge 26 commits into develop
22 changes: 22 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Contributor (Author) commented:
The temporal_shift kernel no longer supports inputs with x.size() != 4, and keeping that check made it impossible to pass the coverage unit test, so the check has been removed.

@@ -2180,6 +2180,27 @@ class OneHotOpPattern
}
};

class TemporalShiftOpPattern
: public pir::OpRewritePattern<paddle::dialect::TemporalShiftOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::TemporalShiftOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op,
pir::PatternRewriter &rewriter) const override {
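    // Skip ops that have already been marked as runnable on TensorRT.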
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) {
      VLOG(3) << "temporal_shift requires attributes: shift_ratio and seg_num";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class InstanceNormOpPattern
: public pir::OpRewritePattern<paddle::dialect::InstanceNormOp> {
public:
@@ -2386,6 +2407,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<TanhOpPattern>(context));
ps.Add(std::make_unique<CeluOpPattern>(context));
ps.Add(std::make_unique<OneHotOpPattern>(context));
ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
ps.Add(std::make_unique<InstanceNormOpPattern>(context));
ps.Add(std::make_unique<AffineChannelOpPattern>(context));
return ps;
107 changes: 107 additions & 0 deletions python/paddle/tensorrt/impls/others.py
@@ -26,6 +26,7 @@
trt_concat,
trt_prod,
trt_shape,
trt_sub,
trt_sum,
)
from paddle.tensorrt.register import converter_registry
@@ -303,6 +304,112 @@ def share_data_converter(network, paddle_op, inputs):
return identity_layer.get_output(0)


@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
def temporal_shift_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
shift_ratio = paddle_op.attrs()["shift_ratio"]
T = paddle_op.attrs()["seg_num"]
data_format = paddle_op.attrs().get("data_format", "NCHW")

if data_format == "NHWC":
# Transpose input to [N, C, H, W]
transpose_layer = network.add_shuffle(input_tensor)
transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2])
input_tensor = transpose_layer.get_output(0)

input_dims = input_tensor.shape
C, H, W = input_dims[1], input_dims[2], input_dims[3]
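    # C, H, W must be static here; the (possibly dynamic) batch dimension is
    # re-inferred by the -1 in the reshapes below.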

# Reshape input to [N, T, C, H, W]
reshape_layer = network.add_shuffle(input_tensor)
reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W])
input_tensor = reshape_layer.get_output(0)

# Pad input to [N, T + 2, C, H, W]
pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
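    # One zero frame before and one after along the T axis supply the zeros
    # that the shifted channel groups read at the sequence boundaries.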
dims = 5
zeros = add_1D_constant_layer(network, [0] * dims)
start = trt_sub(network, zeros, pre_pad)
total_padding = trt_sum(network, pre_pad, post_pad)
input_shape = trt_shape(network, input_tensor)
size = trt_sum(network, input_shape, total_padding)
stride = [1] * dims
dummy = stride
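    # "dummy" start/shape values are placeholders; the real dynamic values are
    # supplied below via set_input(1, start) and set_input(2, size).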

slice_layer = network.add_slice(input_tensor, dummy, dummy, stride)
slice_layer.set_input(1, start)
slice_layer.set_input(2, size)

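    # TensorRT 8.5 renamed SliceMode to SampleMode; FILL pads out-of-range
    # elements with zeros, which realizes the zero padding along T.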
trt_version = trt.__version__.split('.')
if int(trt_version[0]) > 8 or (
int(trt_version[0]) == 8 and int(trt_version[1]) >= 5
):
slice_layer.mode = trt.SampleMode.FILL
else:
slice_layer.mode = trt.SliceMode.FILL

slice_c = int(C * shift_ratio)
slice_c2 = int(C * shift_ratio * 2)

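    # Three channel groups: [0, slice_c) reads the previous frame,
    # [slice_c, slice_c2) reads the next frame, and [slice_c2, C) is unshifted.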
slice_start1 = zeros
slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0])
slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0])

slice_size_base = trt_shape(network, input_tensor)
sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0])
sub_size2 = add_1D_constant_layer(
network, [0, 0, C + slice_c - slice_c2, 0, 0]
)
sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0])

slice_size1 = trt_sub(network, slice_size_base, sub_size1)
slice_size2 = trt_sub(network, slice_size_base, sub_size2)
slice_size3 = trt_sub(network, slice_size_base, sub_size3)

slice1_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice1_layer.set_input(1, slice_start1)
slice1_layer.set_input(2, slice_size1)
slice2_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice2_layer.set_input(1, slice_start2)
slice2_layer.set_input(2, slice_size2)
slice3_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice3_layer.set_input(1, slice_start3)
slice3_layer.set_input(2, slice_size3)

    # slice1 is empty when slice_c == 0, so concatenate only the other two.
    if slice_c == 0:
        concat_inputs = [
            slice2_layer.get_output(0),
            slice3_layer.get_output(0),
        ]
    else:
        concat_inputs = [
            slice1_layer.get_output(0),
            slice2_layer.get_output(0),
            slice3_layer.get_output(0),
        ]
    concat_layer = network.add_concatenation(concat_inputs)
    concat_layer.axis = 2  # channel axis of [N, T, C, H, W]

# Reshape output to [N*T,C,H,W]
reshape_layer3 = network.add_shuffle(concat_layer.get_output(0))
reshape_layer3.reshape_dims = trt.Dims([-1, C, H, W])

if data_format == "NHWC":
transpose_layer2 = network.add_shuffle(reshape_layer3.get_output(0))
transpose_layer2.first_transpose = trt.Permutation([0, 2, 3, 1])
output_tensor = transpose_layer2.get_output(0)
else:
output_tensor = reshape_layer3.get_output(0)

return output_tensor
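
For a quick sanity check of the slice layout above, the same shift can be written directly in NumPy. This is a minimal sketch and not part of the PR; temporal_shift_ref is a hypothetical helper that mirrors Paddle's documented NCHW semantics and the converter's channel grouping:

    import numpy as np

    def temporal_shift_ref(x, seg_num, shift_ratio):
        # x: [N*T, C, H, W] in NCHW layout.
        nt, c, h, w = x.shape
        x5 = x.reshape(nt // seg_num, seg_num, c, h, w)
        slice_c = int(c * shift_ratio)
        slice_c2 = int(c * shift_ratio * 2)
        out = np.zeros_like(x5)
        # Channels [0, slice_c): read from the previous frame (zeros at t=0).
        out[:, 1:, :slice_c] = x5[:, :-1, :slice_c]
        # Channels [slice_c, slice_c2): read from the next frame (zeros at t=T-1).
        out[:, :-1, slice_c:slice_c2] = x5[:, 1:, slice_c:slice_c2]
        # Remaining channels pass through unshifted.
        out[:, :, slice_c2:] = x5[:, :, slice_c2:]
        return out.reshape(nt, c, h, w)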


@converter_registry.register("pd_op.affine_channel", trt_version="8.x")
def affine_channel_converter(network, paddle_op, inputs):
x, scale_weights, bias_weights = inputs
146 changes: 146 additions & 0 deletions test/tensorrt/test_converter_others.py
@@ -406,6 +406,152 @@ def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.opt_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternZeroSlice(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 2, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 2, 7, 7]}
self.opt_shape = {"x": [2, 2, 7, 7]}
self.max_shape = {"x": [8, 2, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 4,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [4, 9, 7, 7]}
self.opt_shape = {"x": [4, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.4,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.opt_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"name": None,
"data_format": "NHWC",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.opt_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.opt_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [10, 9, 7, 7]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


def wrapper_temporal_shift(x):
return paddle.nn.functional.temporal_shift(x=x, seg_num=2, shift_ratio=0.2)


class TestTemporalShiftTRTPatternError1(TensorRTBaseTest):
def setUp(self):
self.python_api = wrapper_temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.opt_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [10, 9, 7, 7]}

def test_trt_result(self):
self.check_marker(expected_result=False)


def affine_channel(x, scale_shape, bias_shape, layout):
scale = paddle.static.create_parameter(
shape=scale_shape, dtype='float32', name="scale"
Expand Down
Loading