inductor(CPU): support dynamic shape for onednn fusion path (pytorch#…
XiaobingSuper authored and pytorchmergebot committed Apr 7, 2023
1 parent 77d9742 commit d643a00
Showing 5 changed files with 105 additions and 75 deletions.
14 changes: 4 additions & 10 deletions aten/src/ATen/native/mkldnn/Linear.cpp
@@ -225,23 +225,17 @@ Tensor mkldnn_linear_pointwise(
   }

   if (mkldnn_bias.has_value()) {
-    ideep::inner_product_forward::compute(
+    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         w,
         mkldnn_bias.value(),
         mkldnn_output,
-        ideep::scale_t(),
-        ideep::scale_t(),
-        ideep::scale_t(),
         op_attr);
   } else {
-    ideep::inner_product_forward::compute(
+    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         w,
         mkldnn_output,
-        ideep::scale_t(),
-        ideep::scale_t(),
-        ideep::scale_t(),
         op_attr);
   }

@@ -308,15 +302,15 @@ Tensor mkldnn_linear_pointwise_binary(
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);

   if (mkldnn_bias.has_value()) {
-    ideep::inner_product_forward::compute_binary(
+    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         mkldnn_other,
         w,
         mkldnn_bias.value(),
         mkldnn_output,
         op_attr);
   } else {
-    ideep::inner_product_forward::compute_binary(
+    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input, mkldnn_other, w, mkldnn_output, op_attr);
   }

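Note on the hunks above: the added </*reorder_src=*/false, /*reorder_weight=*/false> template arguments ask ideep to use the source and the pre-packed weight as-is instead of reordering them into a shape-specialized blocked layout, which is what lets the fused op run when the batch dimension is symbolic. A minimal Python sketch of the same idea, assuming an MKLDNN-enabled build (illustrative only, not the code path changed here):

    import torch

    # Keep the weight in a plain, batch-size-agnostic MKLDNN layout rather
    # than one reordered for a particular input shape.
    linear = torch.nn.Linear(128, 64).eval()
    with torch.no_grad():
        plain_weight = linear.weight.to_mkldnn()  # no shape-specific reorder
    print(plain_weight.layout)  # torch._mkldnn; reusable for any batch size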
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -156,6 +156,7 @@ def run(*ex, **kwargs):
     "test_views3_dynamic_shapes": TestFailure(("cpu",)),
     "test_views4_dynamic_shapes": TestFailure(("cpu",)),
     "test_zeros_dynamic_shapes": TestFailure(("cpu",)),
+    "test_upsample_cat_conv_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
     #
     # Failed to find for loop/triton kernel:
     #
2 changes: 2 additions & 0 deletions torch/_inductor/ir.py
@@ -3185,6 +3185,8 @@ def _original_deconv_weight_size(
         dilation = tuple(dilation_)
         assert isinstance(groups, int)
         output_padding = tuple(output_padding_) if output_padding_ else (0, 0)
+        x.realize()
+        weight.realize()
         with V.graph.fake_mode:
             x_fake = ir_node_to_tensor(x, guard_shape=True)
             weight_fake = ir_node_to_tensor(weight, guard_shape=True)
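The realize() calls above materialize both IR nodes before fake tensors are built from them inside V.graph.fake_mode. A small sketch of this style of fake-mode shape inference, assuming torch._subclasses.fake_tensor is available (illustrative, not inductor's internal helper):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Fake tensors carry shape and dtype but no data, so the output size of a
    # transposed convolution can be derived without running the kernel.
    with FakeTensorMode():
        x = torch.empty(1, 8, 16, 16)
        w = torch.empty(8, 4, 3, 3)  # (in_channels, out_channels, kH, kW)
        out = torch.nn.functional.conv_transpose2d(x, w, stride=2)
        print(out.shape)  # torch.Size([1, 4, 33, 33])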
154 changes: 97 additions & 57 deletions torch/_inductor/mkldnn.py
@@ -5,6 +5,7 @@
 from typing import Optional

 import torch
+import torch._dynamo.config as dynamo_config
 import torch.nn as nn
 import torch.nn.functional as F

@@ -13,7 +14,6 @@
     matches_module_pattern,
     replace_node_module,
 )
-from torch.fx.experimental.symbolic_shapes import guard_int
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn.modules.utils import _pair
 from . import config
@@ -98,7 +98,7 @@ def __init__(
         self,
         conv: nn.Module,
         unary: Optional[nn.Module],
-        input_size: list,
+        input_size: Optional[list],
     ):
         super().__init__(
             conv.in_channels,
@@ -131,8 +131,10 @@ def _update_module_params(self, conv, unary, input_size):
                 self.stride,
                 self.dilation,
                 self.groups,
-                tuple(guard_int(x) for x in input_size),
-            ),
+                input_size,
+            )
+            if input_size is not None
+            else self.weight.to_mkldnn(),
             requires_grad=self.weight.requires_grad,
         )

@@ -174,7 +176,7 @@ def __init__(
         self,
         conv: nn.Module,
         binary_op_name: str,
-        input_size: list,
+        input_size: Optional[list],
     ):
         super().__init__(
             conv.in_channels,
@@ -205,8 +207,10 @@ def _update_module_params(self, conv, binary_op_name, input_size):
                 self.stride,
                 self.dilation,
                 self.groups,
-                tuple(guard_int(x) for x in input_size),
-            ),
+                input_size,
+            )
+            if input_size is not None
+            else self.weight.to_mkldnn(),
             requires_grad=self.weight.requires_grad,
         )

@@ -283,7 +287,9 @@ def forward(self, input):


 class LinearUnary(nn.Linear):
-    def __init__(self, linear: nn.Module, unary: Optional[nn.Module], input_size: list):
+    def __init__(
+        self, linear: nn.Module, unary: Optional[nn.Module], input_size: Optional[list]
+    ):
         super().__init__(
             linear.in_features,
             linear.out_features,
@@ -302,10 +308,15 @@ def _update_module_params(self, linear, unary, input_size):
         self.attr, self.scalars, self.algorithm = unary_modules_map[
             unary.__class__
         ](unary)
-        self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
+        self.batch_size = (
+            reduce(lambda x, y: x * y, input_size[:-1])
+            if input_size is not None
+            else None
+        )
         self.packed_weight = torch.nn.Parameter(
             torch.ops.mkldnn._reorder_linear_weight(
-                self.weight.to_mkldnn(), self.batch_size
+                self.weight.to_mkldnn(),
+                self.batch_size,
             ),
             requires_grad=self.weight.requires_grad,
         )
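The batch_size handling above folds every leading dimension into one number when shapes are static, so _reorder_linear_weight can pack the weight for that exact batch; with dynamic shapes it passes None, and the packing presumably stays batch-agnostic. A standalone sketch of the folding (linear_batch_size is a hypothetical helper):

    from functools import reduce
    from typing import Optional

    def linear_batch_size(input_size: Optional[list]) -> Optional[int]:
        # Fold all dimensions except the last (feature) dim into one batch.
        if input_size is None:
            return None  # dynamic shapes: nothing to specialize for
        return reduce(lambda x, y: x * y, input_size[:-1])

    assert linear_batch_size([8, 49, 512]) == 392  # 8 * 49
    assert linear_batch_size(None) is None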
@@ -327,7 +338,7 @@ def __init__(
         self,
         linear: nn.Module,
         binary_op_name: str,
-        input_size: list,
+        input_size: Optional[list],
     ):
         super().__init__(
             linear.in_features,
@@ -341,7 +352,11 @@ def __init__(
     def _update_module_params(self, linear, binary_op_name, input_size):
         self.__dict__ = copy.deepcopy(linear.__dict__)
         self.attr = binary_op_name
-        self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
+        self.batch_size = (
+            reduce(lambda x, y: x * y, input_size[:-1])
+            if input_size is not None
+            else None
+        )
         self.packed_weight = torch.nn.Parameter(
             torch.ops.mkldnn._reorder_linear_weight(
                 self.weight.to_mkldnn(), self.batch_size
@@ -361,7 +376,7 @@ def __init__(
         self,
         conv_transpose: nn.Module,
         unary: Optional[nn.Module],
-        input_size: list,
+        input_size: Optional[list],
     ):
         super().__init__(
             conv_transpose.in_channels,
@@ -384,14 +399,18 @@ def _update_module_params(self, conv_transpose, unary, input_size):
         self.attr, self.scalars, self.algorithm = (
             unary_modules_map[unary.__class__](unary) if unary else ("none", [], "")
         )
-        packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
-            self.weight.to_mkldnn(),
-            self.padding,
-            self.output_padding,
-            self.stride,
-            self.dilation,
-            self.groups,
-            input_size,
+        packed_weight = (
+            torch.ops.mkldnn._reorder_convolution_transpose_weight(
+                self.weight.to_mkldnn(),
+                self.padding,
+                self.output_padding,
+                self.stride,
+                self.dilation,
+                self.groups,
+                input_size,
+            )
+            if input_size is not None
+            else self.weight.transpose(0, 1).to_mkldnn()
         )
         self.weight = torch.nn.Parameter(
             packed_weight,
@@ -433,7 +452,7 @@ def forward(self, input):
         return self._conv_transpose_forward(input, self.weight, self.bias)


-def packed_conv_eval(conv: nn.Module, input_size: list):
+def packed_conv_eval(conv: nn.Module, input_size: Optional[list]):
     assert not (conv.training), "Fusion only for eval!"
     return ConvUnary2d(
         conv,
@@ -442,7 +461,7 @@ def packed_conv_eval(conv: nn.Module, input_size: list):
     )


-def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: list):
+def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: Optional[list]):
     assert not (conv_transpose.training), "Fusion only for eval!"
     return ConvTransposeUnary2d(
         conv_transpose,
@@ -451,7 +470,9 @@ def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: list):
     )


-def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
+def fused_conv_unary_eval(
+    conv: nn.Module, unary: nn.Module, input_size: Optional[list]
+):
     assert not (conv.training), "Fusion only for eval!"
     return ConvUnary2d(
         conv,
@@ -460,7 +481,9 @@ def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
     )


-def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
+def fused_conv_binary_eval(
+    conv: nn.Module, binary_op_name: str, input_size: Optional[list]
+):
     assert not (conv.training), "Fusion only for eval!"
     return ConvBinary2d(
         conv,
@@ -470,34 +493,36 @@ def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
     )


 def fused_conv_binary_unary_eval(
-    conv_binary: nn.Module, unary: nn.Module, input_size: list
+    conv_binary: nn.Module, unary: nn.Module, input_size: Optional[list]
 ):
     assert not (conv_binary.training), "Fusion only for eval!"
     # reuse origin conv module, and just update its' unary attr.
     conv_binary._update_unary_params(unary)
     return conv_binary


-def packed_linear_eval(linear: nn.Module, input_size: list):
+def packed_linear_eval(linear: nn.Module, input_size: Optional[list]):
     assert not (linear.training), "Fusion only for eval!"
     if linear.weight.dtype == torch.bfloat16:
         return LinearUnary(linear, None, input_size)
     return PackedLinear(linear, input_size)


-def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
+def fused_linear_unary_eval(
+    linear: nn.Module, unary: nn.Module, input_size: Optional[list]
+):
     assert not (linear.training), "Fusion only for eval!"
     return LinearUnary(linear, unary, input_size)


-def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
+def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: Optional[list]):
     assert not (linear.training), "Fusion only for eval!"
     linear_binary = LinearBinary(linear, attr, input_size)
     return linear_binary


 def fused_conv_transpose_unary_eval(
-    conv_transpose: nn.Module, unary: nn.Module, input_size: list
+    conv_transpose: nn.Module, unary: nn.Module, input_size: Optional[list]
 ):
     assert not (conv_transpose.training), "Fusion only for eval!"
     return ConvTransposeUnary2d(
@@ -521,16 +546,16 @@ def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
         return gm
     if not is_cpu:
         return gm
     # For binary fusion, we need to check inputs info to make sure
     # the binary inputs have same tensor info(device, dtype, and layout).

     fake_mode = detect_fake_mode(example_inputs)
-    ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
-    gm = fuse_unary(gm)
-    gm = fuse_binary(gm)
-    # why re-run fuse_unary? we want to enable conv+binary+unary fusion,
-    # such as conv+add+relu for vision model.
+    if not dynamo_config.dynamic_shapes:
+        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
+    gm = fuse_unary(gm)
+    if not dynamo_config.dynamic_shapes:
+        gm = fuse_binary(gm)
+        # why re-run fuse_unary? we want to enable conv+binary+unary fusion,
+        # such as conv+add+relu for vision model.
+        gm = fuse_unary(gm)
+    # if config.cpp.weight_prepack and not dynamo_config.dynamic_shapes:
     if config.cpp.weight_prepack:
         gm = pack_module(gm)
     return gm
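The gating above keeps the metadata-dependent passes on the static-shape path: ShapeProp and binary fusion rely on propagated tensor metadata (device, dtype, layout, shape), while unary fusion and weight prepacking can proceed without it. A condensed, runnable sketch of the control flow with the passes injected as callables (a sketch only; the real function operates on an FX GraphModule):

    def fuse_fx_sketch(gm, dynamic_shapes, shape_prop, fuse_unary, fuse_binary, pack_module):
        if not dynamic_shapes:
            shape_prop(gm)  # binary fusion needs propagated tensor metadata
        gm = fuse_unary(gm)
        if not dynamic_shapes:
            gm = fuse_binary(gm)
            gm = fuse_unary(gm)  # catch conv+binary+unary chains, e.g. conv+add+relu
        return pack_module(gm)

    # Under dynamic shapes only unary fusion and packing run:
    trace = []
    fuse_fx_sketch(
        gm="gm",
        dynamic_shapes=True,
        shape_prop=lambda g: trace.append("shape_prop"),
        fuse_unary=lambda g: (trace.append("unary"), g)[1],
        fuse_binary=lambda g: (trace.append("binary"), g)[1],
        pack_module=lambda g: (trace.append("pack"), g)[1],
    )
    assert trace == ["unary", "pack"]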
@@ -625,11 +650,17 @@ def fuse_unary(gm: torch.fx.GraphModule):
                 ):
                     continue
                 # TODO: remove this when group depthwise ConvTranspose is supported
-                if is_group_depthwise_conv_transpose(computation_node):
+                if is_group_depthwise_conv_transpose(computation_node) or (
+                    type(computation_node) in [nn.ConvTranspose2d]
+                    and dynamo_config.dynamic_shapes
+                ):
                     continue
-                computation_node_input_size = (
-                    node.args[0].args[0].meta.get("tensor_meta").shape
-                )
+                if dynamo_config.dynamic_shapes:
+                    computation_node_input_size = None
+                else:
+                    computation_node_input_size = (
+                        node.args[0].args[0].meta.get("tensor_meta").shape
+                    )
                 fused_module = fuse_func(
                     computation_node, unary_node, computation_node_input_size
                 )
@@ -648,6 +679,8 @@ def replace_and_fuse_for_binary(
     computation_node_input_size = (
         node.args[index_node].args[0].meta.get("tensor_meta").shape
     )
+    if dynamo_config.dynamic_shapes:
+        computation_node_input_size = None
     fused_module = fuse_func(computation_node, attr, computation_node_input_size)
     replace_node_module(node.args[index_node], modules, fused_module)
     node.args[index_node].args = node.args[index_node].args + (
@@ -761,26 +794,33 @@ def pack_module(gm: torch.fx.GraphModule):
         if type(cur_module) in computation_op_packed_map:
             if cur_module.training:
                 continue
-            computation_node_input_meta = node.args[0].meta.get("tensor_meta")
-            # for fp32 linear, only packed when has mkl
-            if (
-                computation_node_input_meta.dtype == torch.float32
-                and type(cur_module) in [torch.nn.Linear]
-                and not torch._C.has_mkl
-            ):
-                continue
-            computation_node_input_size = computation_node_input_meta.shape
-            if (
-                type(cur_module) in [torch.nn.Linear]
-                and len(computation_node_input_size) < 2
-            ):
-                continue
+            if dynamo_config.dynamic_shapes:
+                computation_node_input_meta = None
+                computation_node_input_size = None
+                if (
+                    type(cur_module) in [torch.nn.Linear]
+                    and cur_module.weight.dtype == torch.float32
+                ):
+                    continue
+            else:
+                computation_node_input_meta = node.args[0].meta.get("tensor_meta")
+                computation_node_input_size = computation_node_input_meta.shape
+                if type(cur_module) in [torch.nn.Linear]:
+                    # for fp32 linear, only packed when has mkl.
+                    if (
+                        cur_module.weight.dtype == torch.float32
+                        and (not torch._C.has_mkl)
+                    ) or len(computation_node_input_size) < 2:
+                        continue
             if type(cur_module) in [nn.Conv2d] and isinstance(
                 cur_module.padding, str
             ):
                 continue
             # TODO: remove this when group depthwise ConvTranspose is supported
-            if is_group_depthwise_conv_transpose(cur_module):
+            if type(cur_module) in [nn.ConvTranspose2d] and (
+                is_group_depthwise_conv_transpose(cur_module)
+                or dynamo_config.dynamic_shapes
+            ):
                 continue
             new_module = computation_op_packed_map[type(cur_module)](
                 cur_module, computation_node_input_size
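The pack_module policy above, restated as a standalone predicate for readability (can_prepack is a hypothetical helper; the real pass also consults FX node metadata and skips Conv2d with string padding):

    import torch
    import torch.nn as nn

    def can_prepack(module: nn.Module, dynamic_shapes: bool, input_ndim: int = 2) -> bool:
        if module.training:
            return False  # fusion/packing is eval-only
        if dynamic_shapes:
            # fp32 Linear packs via MKL for a fixed batch size, and the
            # ConvTranspose2d weight reorder is shape-dependent, so both are
            # skipped when shapes are symbolic.
            if isinstance(module, nn.Linear) and module.weight.dtype == torch.float32:
                return False
            return not isinstance(module, nn.ConvTranspose2d)
        if isinstance(module, nn.Linear):
            # for fp32 linear, only packed when has mkl; 1-D inputs stay unpacked
            if module.weight.dtype == torch.float32 and not torch._C.has_mkl:
                return False
            if input_ndim < 2:
                return False
        return True

    assert can_prepack(nn.Linear(4, 4).eval(), dynamic_shapes=True) is False
    assert can_prepack(nn.Conv2d(3, 8, 3).eval(), dynamic_shapes=True) is True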