skip unimplemented error; update workflow #10

Open · wants to merge 1 commit into gma/cpu_support
5 changes: 4 additions & 1 deletion .github/workflows/cpu-inference.yml
@@ -63,4 +63,7 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
pytest -v -s unit/autotuning/ unit/checkpoint/ unit/comm/ unit/compression/ unit/elasticity/ unit/launcher/ unit/profiling/ unit/ops
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' unit/
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 -m 'inference' unit/
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -156,6 +156,10 @@ def is_bf16_supported(self):
def is_fp16_supported(self):
...

@abc.abstractmethod
def supported_dtypes(self):
...

# Misc
@abc.abstractmethod
def amp(self):
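For context, a concrete backend satisfies the new supported_dtypes() contract along these lines. This is only a sketch: the class name and the dtype list shown for a CUDA-like device are assumptions, not part of this diff (the CPU implementation added below returns [torch.float, torch.bfloat16]).

    import torch

    class ExampleCudaAccelerator:
        # Hypothetical backend, for illustration only.
        def supported_dtypes(self):
            # A CUDA-like device typically exposes fp32, fp16, and bf16.
            return [torch.float, torch.half, torch.bfloat16]
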
5 changes: 4 additions & 1 deletion accelerator/cpu_accelerator.py
@@ -183,7 +183,10 @@ def is_bf16_supported(self):
return True

def is_fp16_supported(self):
return True
return False

def supported_dtypes(self):
return [torch.float, torch.bfloat16]

# Tensor operations

1 change: 1 addition & 0 deletions setup.py
@@ -153,6 +153,7 @@ def op_enabled(op_name):
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()
compatible_ops[op_name] = op_compatible
compatible_ops["deepspeed_not_implemented"] = False

# If op is requested but not available, throw an error.
if op_enabled(op_name) and not op_compatible:
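The extra "deepspeed_not_implemented" key is hard-wired to False, which presumably gives tests a builder name that is always reported as unsupported. Assuming the compatible_ops table built here is what backs deepspeed.ops.__compatible_ops__ at runtime, the module-level gate used throughout the test changes below follows this pattern (minimal sketch; FusedAdamBuilder stands in for whichever builder a test needs):

    import pytest
    import deepspeed
    from deepspeed.ops.op_builder import FusedAdamBuilder

    # Skip every test in this module when the required op was not
    # reported as compatible for this system.
    if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
        pytest.skip("op not compatible on this system", allow_module_level=True)
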
5 changes: 5 additions & 0 deletions tests/unit/checkpoint/test_latest_checkpoint.py
@@ -5,10 +5,15 @@

import deepspeed

import pytest
from unit.common import DistributedTest
from unit.simple_model import *

from unit.checkpoint.common import checkpoint_correctness_verification
from deepspeed.ops.op_builder import FusedAdamBuilder

if not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


class TestLatestCheckpoint(DistributedTest):
4 changes: 4 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -20,6 +20,10 @@
from deepspeed.model_implementations import DeepSpeedTransformerInference
from torch import nn
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
4 changes: 4 additions & 0 deletions tests/unit/inference/test_model_profiling.py
@@ -11,6 +11,10 @@
from transformers import pipeline
from unit.common import DistributedTest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("This op had not been implemented on this system.", allow_module_level=True)


@pytest.fixture
2 changes: 2 additions & 0 deletions tests/unit/ops/accelerators/test_accelerator_backward.py
@@ -19,6 +19,8 @@
#pytest.skip(
# "transformer kernels are temporarily disabled because of unexplained failures",
# allow_module_level=True)
if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
3 changes: 3 additions & 0 deletions tests/unit/ops/accelerators/test_accelerator_forward.py
@@ -15,6 +15,9 @@
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)


def check_equal(first, second, atol=1e-2, verbose=False):
if verbose:
3 changes: 3 additions & 0 deletions tests/unit/ops/adam/test_adamw.py
@@ -11,7 +11,10 @@
from deepspeed.ops.adam import DeepSpeedCPUAdam
from unit.common import DistributedTest
from unit.simple_model import SimpleModel
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
# yapf: disable
#'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
8 changes: 6 additions & 2 deletions tests/unit/ops/quantizer/test_fake_quantization.py
@@ -5,8 +5,12 @@

import torch
import pytest
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.ops import op_builder
from deepspeed.ops.op_builder import QuantizerBuilder

if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

quantizer_cuda_module = None

@@ -36,7 +40,7 @@ def run_quant_dequant(inputs, groups, bits):
global quantizer_cuda_module

if quantizer_cuda_module is None:
quantizer_cuda_module = op_builder.QuantizerBuilder().load()
quantizer_cuda_module = QuantizerBuilder().load()
return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits)


10 changes: 7 additions & 3 deletions tests/unit/ops/quantizer/test_quantize.py
@@ -5,16 +5,20 @@

import pytest
import torch
from deepspeed.ops import op_builder
import deepspeed
from deepspeed.ops.op_builder import QuantizerBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
inference_module = QuantizerBuilder().load()

return inference_module.quantize(activations, num_groups, q_bits,
inference_module.Symmetric if is_symmetric_quant else inference_module.Asymmetric)
@@ -23,7 +27,7 @@ def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
def run_dequantize_ds(activations, params, num_groups, q_bits, is_symmetric_quant):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
inference_module = QuantizerBuilder().load()
return inference_module.dequantize(
activations,
params,
5 changes: 5 additions & 0 deletions tests/unit/ops/spatial/test_nhwc_bias_add.py
@@ -5,9 +5,14 @@

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import SpatialInferenceBuilder
from deepspeed.ops.transformer.inference.bias_add import nhwc_bias_add
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[SpatialInferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)


def allclose(x, y):
assert x.dtype == y.dtype
4 changes: 4 additions & 0 deletions tests/unit/profiling/flops_profiler/test_flops_profiler.py
@@ -10,6 +10,10 @@
from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
from unit.util import required_minimum_torch_version
from deepspeed.accelerator import get_accelerator

if torch.half not in get_accelerator().supported_dtypes():
pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)

pytestmark = pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=3),
reason='requires Pytorch version 1.3 or above')