Commit
[ATen] Add CPU fp16 support for nll_loss and cross_entropy_loss (pytorch#123256)

Add CPU FP16 support for nll_loss and cross_entropy_loss.
Resolve issue pytorch#123328.
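
For orientation, a minimal sketch of what this enables (illustrative only; shapes and values are arbitrary): with this change, both ops accept torch.float16 input on CPU for the forward and backward passes.

```python
import torch
import torch.nn.functional as F

# Half-precision logits with integer class targets, on CPU.
logits = torch.randn(8, 5, dtype=torch.float16, requires_grad=True)
targets = torch.randint(0, 5, (8,))

loss = F.cross_entropy(logits, targets)  # previously raised an unimplemented-dtype error on CPU (pytorch#123328)
loss.backward()

print(loss.dtype, logits.grad.dtype)  # torch.float16 torch.float16
```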

Pull Request resolved: pytorch#123256
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/malfet
etaf authored and pytorchmergebot committed Apr 18, 2024
1 parent d59f1da commit 6fcbeb3
Showing 6 changed files with 29 additions and 20 deletions.
11 changes: 8 additions & 3 deletions aten/src/ATen/native/LossNLL.cpp
@@ -304,8 +304,12 @@ void nll_loss_forward_out_cpu_template(
     const Tensor& weight,
     int64_t reduction,
     int64_t ignore_index) {
-  AT_DISPATCH_FLOATING_TYPES_AND(
-      ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::BFloat16,
+      ScalarType::Half,
+      input.scalar_type(),
+      "nll_loss_out_frame",
+      [&] {
         if (target.scalar_type() == kByte) {
           nll_loss_out_frame<scalar_t, uint8_t>(
               output,
@@ -415,8 +419,9 @@ void nll_loss_backward_out_cpu_template(
     const Tensor& total_weight) {
   grad_input.zero_();
 
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss_backward_out_frame",
       [&] {
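
As a side note on why a single kernel family covers both ops: for class-index targets, cross_entropy is computed as log_softmax followed by nll_loss, so the Half dispatch added above serves both entry points. A small illustrative sketch (shapes arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 10, dtype=torch.float16)
t = torch.randint(0, 10, (4,))

# cross_entropy(x, t) matches nll_loss(log_softmax(x), t) for class-index targets,
# so the fp16 CPU kernels above back both functional APIs.
a = F.cross_entropy(x, t)
b = F.nll_loss(F.log_softmax(x, dim=1), t)
torch.testing.assert_close(a, b)
```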
6 changes: 4 additions & 2 deletions aten/src/ATen/native/LossNLL2d.cpp
@@ -262,8 +262,9 @@ void nll_loss2d_forward_out_cpu_template(
   check_inputs_nll_loss2d(input, target, weight);
   total_weight.resize_({});
 
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_forward_out_frame",
       [&] {
@@ -383,8 +384,9 @@ void nll_loss2d_backward_out_cpu_template(
       total_weight.numel(),
       " elements)");
 
-  AT_DISPATCH_FLOATING_TYPES_AND(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       ScalarType::BFloat16,
+      ScalarType::Half,
       input.scalar_type(),
       "nll_loss2d_backward_out_frame",
       [&] {
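
LossNLL2d.cpp handles the spatial variant used for dense prediction, where each pixel carries a class index. A hedged sketch of the case these hunks enable in half precision on CPU:

```python
import torch
import torch.nn.functional as F

# Segmentation-style input: (N, C, H, W) logits with a (N, H, W) class map.
logits = torch.randn(2, 3, 8, 8, dtype=torch.float16, requires_grad=True)
target = torch.randint(0, 3, (2, 8, 8))

loss = F.nll_loss(F.log_softmax(logits, dim=1), target)
loss.backward()  # exercises the 2d backward kernel touched above
print(logits.grad.shape)  # torch.Size([2, 3, 8, 8])
```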
1 change: 1 addition & 0 deletions test/onnx/test_fx_op_consistency.py
@@ -2001,6 +2001,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
         "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
         "nn.functional.local_response_norm": [1e-2, 5e-3],
         "nn.functional.poisson_nll_loss": [3e-2, 1e-3],
+        "nn.functional.nll_loss": [3e-2, 1e-3],
         "native_batch_norm": [3e-2, 1e-3],
         "dot": [3e-2, 1e-3],
         "logit": [3e-2, 1e-3],
7 changes: 6 additions & 1 deletion test/test_mps.py
@@ -171,6 +171,8 @@ def mps_ops_grad_modifier(ops):
         'nn.functional.conv_transpose1d': [torch.float16],
         'nn.functional.conv_transpose2d': [torch.float16],
         'nn.functional.conv_transpose3d': [torch.float16],
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }
 
     MACOS_13_3_XFAILLIST_GRAD = {
@@ -987,7 +989,10 @@ def mps_ops_modifier(ops):
         'nn.functional.avg_pool2d': [torch.float16],
         # input types 'tensor<f32>' and 'tensor<1xf16>' are not broadcast compatible
         # Refer to the issue please: https://github.com/pytorch/pytorch/issues/124252
-        'nn.functional.binary_cross_entropy': [torch.float16]
+        'nn.functional.binary_cross_entropy': [torch.float16],
+
+        'nn.functional.nll_loss': [torch.float16],
+        'nn.functional.cross_entropy': [torch.float16],
     }
 
     def addDecorator(op, d) -> None:
12 changes: 8 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -13009,8 +13009,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
         supports_out=False),
     OpInfo(
         "nn.functional.cross_entropy",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_cross_entropy,
         supports_out=False,
         supports_forward_ad=True,
@@ -13033,6 +13032,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                 "test_variant_consistency_jit",
                 device_type="cuda",
             ),
+            DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+
         )
     ),
     OpInfo('nn.functional.normalize',
@@ -19427,8 +19429,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     ),
     OpInfo(
         "nn.functional.nll_loss",
-        dtypes=floating_types_and(torch.bfloat16),
-        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         sample_inputs_func=sample_inputs_nll_loss,
         supports_forward_ad=True,
@@ -19449,6 +19450,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                 "test_cow_input",
                 device_type='cuda',
             ),
+            DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+
         ),
     ),
     OpInfo(
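
Dropping dtypesIfCUDA and listing torch.float16 directly in dtypes means these OpInfo entries now advertise the same floating dtypes on CPU and CUDA. A small sketch of what that list expands to, assuming the floating_types_and helper behaves as in current testing internals:

```python
import torch
from torch.testing._internal.common_dtype import floating_types_and

# The single dtypes= tuple now shared by CPU and CUDA for these entries.
print(floating_types_and(torch.float16, torch.bfloat16))
# e.g. (torch.float32, torch.float64, torch.float16, torch.bfloat16)
```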
12 changes: 2 additions & 10 deletions torch/testing/_internal/common_modules.py
@@ -4013,16 +4013,8 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
         decorators=(
             # No channels_last support for loss functions.
             DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
-            # Expect failures for tests that rely on torch.half implementation on CPU
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_if_train_and_eval_modes_differ",
-                         dtypes=[torch.float16], device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_save_load", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", dtypes=[torch.float16],
-                         device_type='cpu'),
-            DecorateInfo(unittest.expectedFailure, "TestModule", "test_multiple_device_transfer", dtypes=[torch.float16],
-                         device_type='cuda'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                         "test_forward", dtypes=[torch.float16], device_type='cpu'),
             DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
                          device_type='cuda'),),
     ),
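
Rather than expecting CPU fp16 forward tests to fail, the entry above now relaxes test_forward to atol=3e-2 / rtol=1e-3 for torch.float16. A hedged sketch of the kind of comparison that tolerance implies (not the actual test harness):

```python
import torch

m = torch.nn.NLLLoss()
log_probs = torch.randn(16, 10).log_softmax(dim=1)
target = torch.randint(0, 10, (16,))

ref = m(log_probs, target)          # fp32 reference
out = m(log_probs.half(), target)   # CPU fp16 path enabled by this PR
torch.testing.assert_close(out.float(), ref, atol=3e-2, rtol=1e-3)
```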
