From 5677128cb892d17fe2281beae9e394fd6f89e455 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Thu, 18 Apr 2024 15:21:01 +0000
Subject: [PATCH] [MPS] Fix crash when binary_cross_entropy is invoked for half dtypes (#124258)

By creating constants using the input tensor's dtype.

One-line reproducer:
```
python -c "import torch; x=torch.arange(3, dtype=torch.float16,device='mps');print(torch.nn.functional.binary_cross_entropy(x, x))"
```

Before the change:
```
loc("mps_subtract"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<f32>' and 'tensor<3xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```

After:
```
tensor(-33.7812, device='mps:0', dtype=torch.float16)
```

Fixes https://github.com/pytorch/pytorch/issues/124252

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124258
Approved by: https://github.com/kulinseth
---
 aten/src/ATen/native/mps/operations/LossOps.mm | 14 ++++++++------
 test/test_mps.py                               | 12 ++++--------
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm
index 7b10476106f6ae..77727cb197fa3a 100644
--- a/aten/src/ATen/native/mps/operations/LossOps.mm
+++ b/aten/src/ATen/native/mps/operations/LossOps.mm
@@ -76,7 +76,7 @@ static string reductionToString(int64_t reduction) {
       newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
       newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
 
-      MPSGraphTensor* normTensor = [mpsGraph constantWithScalar:norm dataType:MPSDataTypeFloat32];
+      MPSGraphTensor* normTensor = [mpsGraph constantWithScalar:norm dataType:[newCachedGraph->inputTensor dataType]];
       MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:newCachedGraph->inputTensor
                                                           secondaryTensor:newCachedGraph->targetTensor
                                                                      name:nil];
@@ -116,11 +116,12 @@ static string reductionToString(int64_t reduction) {
 
 static MPSGraphTensor* bce_forward_mps(CachedGraph* bceGraph) {
   MPSGraph* mpsGraph = bceGraph->graph();
+  const auto inputType = [bceGraph->inputTensor dataType];
 
   // Forward BCE: L = -w (y ln(x) + (1-y) ln(1-x))
-  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:inputType];
   // -100 is the hard limit value defined in BCELoss Spec. to clamp the log
-  MPSGraphTensor* neg100 = [mpsGraph constantWithScalar:-100.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* neg100 = [mpsGraph constantWithScalar:-100.0 dataType:inputType];
   // 1 - x
   MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one
                                                      secondaryTensor:bceGraph->inputTensor
@@ -154,11 +155,12 @@ static string reductionToString(int64_t reduction) {
 
 static MPSGraphTensor* bce_backward_mps(CachedGraph* bceGraph) {
   MPSGraph* mpsGraph = bceGraph->graph();
+  const auto inputType = [bceGraph->inputTensor dataType];
 
   // Backward BCE: d(L)/d(x) = -w (y - x) / (x - x^2)
-  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:inputType];
   // epsilon used to clamp the grad input denominator
-  MPSGraphTensor* epsilon = [mpsGraph constantWithScalar:1e-12 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* epsilon = [mpsGraph constantWithScalar:1e-12 dataType:inputType];
   // 1 - x
   MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one
                                                      secondaryTensor:bceGraph->inputTensor
@@ -238,7 +240,7 @@ static string reductionToString(int64_t reduction) {
     if (grad_output.defined()) {
       if (reduction == at::Reduction::Mean) {
         MPSGraphTensor* inputNumel = [mpsGraph constantWithScalar:static_cast<double>(input.numel())
-                                                         dataType:MPSDataTypeFloat32];
+                                                         dataType:[bceLoss dataType]];
         newCachedGraph->gradInputTensor = [mpsGraph divisionWithPrimaryTensor:bceLoss
                                                               secondaryTensor:inputNumel
                                                                          name:nil];
diff --git a/test/test_mps.py b/test/test_mps.py
index 3597ec8d124282..862bda96c729a4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -67,6 +67,7 @@ def mps_ops_grad_modifier(ops):
         'digamma': [torch.float32],
         'special.polygammaspecial_polygamma_n_0': [torch.float16],
         'polygammapolygamma_n_0': [torch.float16],
+        'nn.functional.binary_cross_entropy': [torch.float16],
 
         # Unimplemented ops
         '__getitem__': [torch.float16],
@@ -171,8 +172,6 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
-        'nn.functional.nll_loss': [torch.float16],
-        'nn.functional.cross_entropy': [torch.float16],
     }
 
     MACOS_13_3_XFAILLIST_GRAD = {
@@ -987,12 +986,6 @@ def mps_ops_modifier(ops):
         # Unsupported
         # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
         'nn.functional.avg_pool2d': [torch.float16],
-        # input types 'tensor<f32>' and 'tensor<1xf16>' are not broadcast compatible
-        # Refer to the issue please: https://github.com/pytorch/pytorch/issues/124252
-        'nn.functional.binary_cross_entropy': [torch.float16],
-
-        'nn.functional.nll_loss': [torch.float16],
-        'nn.functional.cross_entropy': [torch.float16],
     }
 
     def addDecorator(op, d) -> None:
@@ -11419,6 +11412,9 @@ class TestConsistency(TestCaseMPS):
         'nn.functional.batch_norm',
         'nn.functional.instance_norm',
         'round', 'xlogy', 'addcmul',
+        'nn.functional.cross_entropy',
+        'nn.functional.binary_cross_entropy',
+        'nn.functional.nll_loss',
         'nn.functional.max_pool2d',
         'nn.functional.gelu',
         'nn.functional.glu',
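
For a quick local sanity check of the fix, here is a minimal sketch (not part of the patch) that exercises both the forward path fixed in `bce_forward_mps` and the backward path fixed in `bce_backward_mps`. It assumes a macOS machine where `torch.backends.mps.is_available()` returns true. Since the patch itself adds the fp16 gradient to the grad xfail list, the sketch only checks that backward runs without the broadcast error, and compares the forward value against a float32 CPU reference with an illustrative, loosely chosen tolerance:
```
import torch

def check_bce_half() -> None:
    if not torch.backends.mps.is_available():
        print("MPS not available; skipping")
        return
    # Inputs to binary_cross_entropy must lie in [0, 1]; torch.rand does.
    x = torch.rand(8, dtype=torch.float16, device="mps", requires_grad=True)
    y = torch.rand(8, dtype=torch.float16, device="mps")

    # Forward pass: before the patch this crashed with
    # "input types ... are not broadcast compatible".
    loss = torch.nn.functional.binary_cross_entropy(x, y)

    # Backward pass exercises bce_backward_mps and the Mean-reduction
    # inputNumel constant the patch also retypes; we only check that it
    # runs, since fp16 grad accuracy is an expected failure in the tests.
    loss.backward()
    assert x.grad is not None and x.grad.dtype == torch.float16

    # Compare the forward value against a float32 CPU reference; the
    # tolerance is an arbitrary choice to absorb fp16 rounding.
    ref = torch.nn.functional.binary_cross_entropy(x.detach().cpu().float(), y.cpu().float())
    torch.testing.assert_close(loss.cpu().float(), ref, atol=1e-3, rtol=1e-2)
    print("ok:", loss.item())

check_bce_half()
```
The OpInfo-based `TestConsistency` entries added in the patch cover the same ground more systematically across dtypes.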