[MPS] Fixes GELU, LeakyRELU and MISH on non-contiguous tensors (pytorch#123049)

Fixes the GELU, LeakyReLU and Mish activation functions on non-contiguous tensors (for instance, when a transpose operation was applied to the tensor before the MPS operator), for both the forward and backward passes.
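
For context, a minimal sketch of the failure mode this change targets, assuming a machine with an MPS device; the shape and the choice of GELU are illustrative and not taken from the PR:

    import torch

    # Create matching CPU and MPS copies, then transpose both so they become
    # non-contiguous before the activation runs.
    cpu_x = torch.randn(5, 4, device="cpu")
    mps_x = cpu_x.detach().clone().to("mps")
    cpu_x = cpu_x.transpose(0, 1)
    mps_x = mps_x.transpose(0, 1)
    assert not mps_x.is_contiguous()

    # Before this fix, the MPS kernel could read the transposed buffer as if it
    # were still contiguous, so the result diverged from the CPU reference.
    cpu_y = torch.nn.functional.gelu(cpu_x)
    mps_y = torch.nn.functional.gelu(mps_x)
    torch.testing.assert_close(cpu_y, mps_y.cpu())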

I also extended the tests on the three activation functions to check full precision and half precision, contiguous and non-contiguous inputs, and several tensor shapes: scalar, 1D, empty, 2D, > 3D.

In some cases I had issues with the Mish and GELU activations when asserting the gradients against the CPU via sum(), so I reverted to the previous setup of passing an explicit gradient to .backward() (see the sketch after this paragraph).
This PR also fixes an issue with LeakyRELU on empty tensors.
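
The sketch below illustrates the two gradient-check styles, assuming an MPS device is available; the shape, the use of Mish and the variable names are mine, not the test suite's:

    import torch

    x_cpu = torch.randn(4, 3, requires_grad=True)
    x_mps = x_cpu.detach().clone().to("mps").requires_grad_()

    y_cpu = torch.nn.functional.mish(x_cpu)
    y_mps = torch.nn.functional.mish(x_mps)

    # Variant A: reduce to a scalar so backward() needs no explicit gradient:
    #   y_cpu.sum().backward(); y_mps.sum().backward()
    # Variant B (the setup kept in this PR's tests): pass the upstream gradient
    # explicitly, which keeps the element-wise gradients directly comparable.
    grad = torch.ones_like(y_cpu)
    y_cpu.backward(gradient=grad)
    y_mps.backward(gradient=grad.to("mps"))

    torch.testing.assert_close(x_cpu.grad, x_mps.grad.cpu())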

Fixes pytorch#98212, huggingface/transformers#22468, huggingface/transformers#19353
Pull Request resolved: pytorch#123049
Approved by: https://github.com/kulinseth
jtang98 authored and pytorchmergebot committed Apr 21, 2024
1 parent 98f3e02 commit a6a3f2e
Showing 2 changed files with 131 additions and 25 deletions.
61 changes: 52 additions & 9 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -131,8 +131,17 @@ Tensor relu_mps(const Tensor& self) {
using CachedGraph = MPSUnaryCachedGraph;
TORCH_CHECK(output.is_mps());

if (self.numel() == 0) {
return;
}

MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -152,13 +161,17 @@ Tensor relu_mps(const Tensor& self) {
newCachedGraph->outputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);

// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
output.copy_(output_);
}
}

TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
@@ -171,8 +184,14 @@ Tensor relu_mps(const Tensor& self) {
using CachedGraph = MPSUnaryGradCachedGraph;
TORCH_CHECK(output.is_mps());

if (self.numel() == 0) {
return;
}

MPSStream* stream = getCurrentMPSStream();

Tensor output_ = at::empty_like(self, self.suggest_memory_format());

@autoreleasepool {
string key =
"leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
@@ -202,12 +221,13 @@ Tensor relu_mps(const Tensor& self) {

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output_);

// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
output.copy_(output_);
}

TORCH_IMPL_FUNC(log_softmax_mps_out)
@@ -656,6 +676,11 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -672,12 +697,17 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
newCachedGraph->outputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? output_ : output, nil, false);

auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}

if (executeGatherOp) {
output.copy_(output_);
}
}

TORCH_IMPL_FUNC(gelu_backward_out_mps)
@@ -686,8 +716,11 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c
using CachedGraph = MPSUnaryGradCachedGraph;

// Empty output
if (grad_input.numel() == 0)
if (self.numel() == 0) {
return;
}

Tensor grad_input_ = at::empty_like(self, self.suggest_memory_format());

auto approximate_type = get_gelutype_enum(approximate);
MPSStream* stream = getCurrentMPSStream();
@@ -761,11 +794,12 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c

Placeholder gradPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_);

auto feeds = dictionaryFromPlaceholders(gradPlaceholder, selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
grad_input.copy_(grad_input_);
}

static void elu_variants_out_mps(const Tensor& self,
@@ -1241,6 +1275,11 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int

MPSStream* stream = getCurrentMPSStream();

bool executeGatherOp =
!(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
self.is_contiguous(MemoryFormat::ChannelsLast3d));
Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

@autoreleasepool {
string key = "mish_out_mps:" + getTensorsStringKey({self});

@@ -1257,12 +1296,16 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);

auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (executeGatherOp) {
result.copy_(result_);
}
}

Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
95 changes: 79 additions & 16 deletions test/test_mps.py
@@ -1470,29 +1470,44 @@ def testNpLeakyRelu(self):
0.9]]),
negative_slope=0.1))

def _testLeakyRelu(self, np_features, negative_slope, device):
cpu_x = torch.from_numpy(np_features).requires_grad_()
mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
mps_x = cpu_x.detach().clone().to('mps')

if not contiguous and not (0 in shape or len(shape) < 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
mps_x = mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()

cpu_x.requires_grad_()
mps_x.requires_grad_()

relu_op = torch.nn.LeakyReLU(negative_slope)

cpu_leaky_relu = relu_op(cpu_x)
mps_leaky_relu = relu_op(mps_x)
torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

# test backward pass

cpu_grad = torch.ones_like(cpu_leaky_relu)
mps_grad = cpu_grad.to('mps')
cpu_leaky_relu.backward(gradient=cpu_grad)

mps_leaky_relu.backward(gradient=mps_grad)
torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
cpu_leaky_relu.backward(gradient=cpu_grad)

def testNumbersCPU(self):
for t in [np.float32]:
self._testLeakyRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
negative_slope=0.2,
device="cpu")
assert cpu_x.grad is not None # Check that the grad is well-populated
self.assertEqual(cpu_x.grad, mps_x.grad)

def testNumbersCPU(self):
for t in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
self._testLeakyRelu(shape,
dtype=t,
negative_slope=0.2,
contiguous=contiguous)

class TestAvgPool(TestCaseMPS):
def _sum_pool2d(self, x, kernel_size):
@@ -6631,9 +6646,18 @@ def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
helper((2, 16, 16), (4, 4), return_indices, dtype)

def test_gelu_simple(self):
def helper(shape, dtype=torch.float):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')

if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()

cpu_x.requires_grad_()
x.requires_grad_()

gelu_result = torch.nn.GELU()(x)
# GELU is not supported on CPU, so cast it to float
@@ -6648,16 +6672,55 @@ def helper(shape, dtype=torch.float):
atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)

assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
helper(shape, dtype)
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)
# Test that gelu would raise an assert for integral types
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))

def test_mish_simple(self):
def helper(shape, dtype=torch.float, contiguous=True):
cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
x = cpu_x.detach().clone().to('mps')

if not contiguous and (0 not in shape and len(shape) >= 2):
# Transposing will make the tensor non-contiguous
cpu_x = cpu_x.transpose(0, 1)
x = x.transpose(0, 1)
assert not x.is_contiguous()

cpu_x.requires_grad_()
x.requires_grad_()

mish_result = torch.nn.Mish()(x)
mish_result_cpu = torch.nn.Mish()(cpu_x)

cpu_grad = torch.ones_like(mish_result_cpu)
grad = cpu_grad.to('mps')

mish_result.backward(gradient=grad)
mish_result_cpu.backward(gradient=cpu_grad)

atol = 1e-5 if dtype == torch.float else 1e-2
rtol = 1e-3 if dtype == torch.float else 1e-2
self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)

assert x.grad is not None # Check that the grad is well-populated
self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

# Test empty shape too
for dtype in [torch.float, torch.half]:
for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
for contiguous in [True, False]:
helper(shape, dtype, contiguous)

def test_gelu(self):
def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
numpy_dtype = {
