[MPS] Fix crash when binary_cross_entropy is invoked for half dtypes (pytorch#124258)

Fixed by creating the graph constants with the input tensor's dtype instead of hard-coded float32.

One-line reproducer:
```
python -c "import torch; x=torch.arange(3, dtype=torch.float16,device='mps');print(torch.nn.functional.binary_cross_entropy(x, x))"
```

Before the change:
```
loc("mps_subtract"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<f32>' and 'tensor<3xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```
After the change:
```
tensor(-33.7812, device='mps:0', dtype=torch.float16)
```
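A minimal verification sketch (not part of this PR): it compares the half-precision forward result on MPS against a float32 CPU reference. It assumes an MPS-capable machine, and the tolerances are deliberately loose because fp16 BCE is treated as low precision in the test changes below.

```python
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.rand(16, device="mps", dtype=torch.float16)
    y = torch.rand(16, device="mps", dtype=torch.float16)

    # Half-precision BCE on MPS: crashed with an MPSGraph broadcast error before this change.
    loss = F.binary_cross_entropy(x, y)

    # Float32 reference on CPU.
    ref = F.binary_cross_entropy(x.cpu().float(), y.cpu().float())

    torch.testing.assert_close(loss.cpu().float(), ref, atol=1e-2, rtol=1e-2)
    print(f"mps fp16: {loss.item():.4f}  cpu fp32 reference: {ref.item():.4f}")
```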

Fixes pytorch#124252

Pull Request resolved: pytorch#124258
Approved by: https://github.com/kulinseth
malfet authored and pytorchmergebot committed Apr 18, 2024
1 parent ef93402 commit 5677128
Showing 2 changed files with 12 additions and 14 deletions.
14 changes: 8 additions & 6 deletions aten/src/ATen/native/mps/operations/LossOps.mm
```diff
@@ -76,7 +76,7 @@ static string reductionToString(int64_t reduction) {
       newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
       newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);

-      MPSGraphTensor* normTensor = [mpsGraph constantWithScalar:norm dataType:MPSDataTypeFloat32];
+      MPSGraphTensor* normTensor = [mpsGraph constantWithScalar:norm dataType:[newCachedGraph->inputTensor dataType]];
       MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:newCachedGraph->inputTensor
                                                            secondaryTensor:newCachedGraph->targetTensor
                                                                       name:nil];
@@ -116,11 +116,12 @@ static string reductionToString(int64_t reduction) {

 static MPSGraphTensor* bce_forward_mps(CachedGraph* bceGraph) {
   MPSGraph* mpsGraph = bceGraph->graph();
+  const auto inputType = [bceGraph->inputTensor dataType];

   // Forward BCE: L = -w (y ln(x) + (1-y) ln(1-x))
-  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:inputType];
   // -100 is the hard limit value defined in BCELoss Spec. to clamp the log
-  MPSGraphTensor* neg100 = [mpsGraph constantWithScalar:-100.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* neg100 = [mpsGraph constantWithScalar:-100.0 dataType:inputType];
   // 1 - x
   MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one
                                                      secondaryTensor:bceGraph->inputTensor
@@ -154,11 +155,12 @@ static string reductionToString(int64_t reduction) {

 static MPSGraphTensor* bce_backward_mps(CachedGraph* bceGraph) {
   MPSGraph* mpsGraph = bceGraph->graph();
+  const auto inputType = [bceGraph->inputTensor dataType];

   // Backward BCE: d(L)/d(x) = -w (y - x) / (x - x^2)
-  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* one = [mpsGraph constantWithScalar:1.0 dataType:inputType];
   // epsilon used to clamp the grad input denominator
-  MPSGraphTensor* epsilon = [mpsGraph constantWithScalar:1e-12 dataType:MPSDataTypeFloat32];
+  MPSGraphTensor* epsilon = [mpsGraph constantWithScalar:1e-12 dataType:inputType];
   // 1 - x
   MPSGraphTensor* one_Input = [mpsGraph subtractionWithPrimaryTensor:one
                                                      secondaryTensor:bceGraph->inputTensor
@@ -238,7 +240,7 @@ static string reductionToString(int64_t reduction) {
     if (grad_output.defined()) {
       if (reduction == at::Reduction::Mean) {
        MPSGraphTensor* inputNumel = [mpsGraph constantWithScalar:static_cast<double>(input.numel())
-                                                         dataType:MPSDataTypeFloat32];
+                                                         dataType:[bceLoss dataType]];
        newCachedGraph->gradInputTensor = [mpsGraph divisionWithPrimaryTensor:bceLoss
                                                              secondaryTensor:inputNumel
                                                                         name:nil];
```
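The kernel comments above reference the BCE forward and backward formulas. Below is a small numeric sanity check of those formulas against PyTorch's reference implementation (float64 on CPU, weight `w` omitted, `reduction='none'`); it is illustrative only and not part of this PR.

```python
import torch
import torch.nn.functional as F

# Keep x away from 0 and 1 so the logs and the gradient denominator stay well-behaved.
x = (torch.rand(8, dtype=torch.float64) * 0.98 + 0.01).requires_grad_(True)
y = torch.rand(8, dtype=torch.float64)

# Forward BCE: L = -(y ln(x) + (1 - y) ln(1 - x))
manual_loss = -(y * x.log() + (1 - y) * (1 - x).log())
torch.testing.assert_close(manual_loss, F.binary_cross_entropy(x, y, reduction="none"))

# Backward BCE: dL/dx = -(y - x) / (x - x^2)
manual_grad = -(y - x.detach()) / (x.detach() - x.detach() ** 2)
F.binary_cross_entropy(x, y, reduction="none").sum().backward()
torch.testing.assert_close(manual_grad, x.grad)
```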
12 changes: 4 additions & 8 deletions test/test_mps.py
```diff
@@ -67,6 +67,7 @@ def mps_ops_grad_modifier(ops):
         'digamma': [torch.float32],
         'special.polygammaspecial_polygamma_n_0': [torch.float16],
         'polygammapolygamma_n_0': [torch.float16],
+        'nn.functional.binary_cross_entropy': [torch.float16],

         # Unimplemented ops
         '__getitem__': [torch.float16],
@@ -171,8 +172,6 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
-        'nn.functional.nll_loss': [torch.float16],
-        'nn.functional.cross_entropy': [torch.float16],
     }

     MACOS_13_3_XFAILLIST_GRAD = {
@@ -987,12 +986,6 @@ def mps_ops_modifier(ops):
         # Unsupported
         # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
         'nn.functional.avg_pool2d': [torch.float16],
-        # input types 'tensor<f32>' and 'tensor<1xf16>' are not broadcast compatible
-        # Refer to the issue please: https://github.com/pytorch/pytorch/issues/124252
-        'nn.functional.binary_cross_entropy': [torch.float16],
-
-        'nn.functional.nll_loss': [torch.float16],
-        'nn.functional.cross_entropy': [torch.float16],
     }

     def addDecorator(op, d) -> None:
@@ -11419,6 +11412,9 @@ class TestConsistency(TestCaseMPS):
             'nn.functional.batch_norm',
             'nn.functional.instance_norm',
             'round', 'xlogy', 'addcmul',
+            'nn.functional.cross_entropy',
+            'nn.functional.binary_cross_entropy',
+            'nn.functional.nll_loss',
             'nn.functional.max_pool2d',
             'nn.functional.gelu',
             'nn.functional.glu',
```
