diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 0cca6d129f..833e3e2461 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -16,6 +16,7 @@
 import copy
 import gc
+import importlib
 import multiprocessing as mp
 import os
 import traceback
@@ -532,7 +533,11 @@ def export_pytorch(
         # Check that inputs match, and order them properly
         dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
 
-        device = torch.device(device)
+        if device == "dml" and importlib.util.find_spec("torch_directml"):
+            torch_directml = importlib.import_module("torch_directml")
+            device = torch_directml.device()
+        else:
+            device = torch.device(device)
 
         def remap(value):
             if isinstance(value, torch.Tensor):
@@ -540,7 +545,7 @@
 
             return value
 
-        if device.type == "cuda" and torch.cuda.is_available():
+        if device.type == "cuda" and torch.cuda.is_available() or device.type == "privateuseone":
             model.to(device)
             dummy_inputs = tree_map(remap, dummy_inputs)
 
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 47a6ae08ca..d639b9f06e 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -2271,12 +2271,16 @@ def get_model_from_task(
                 kwargs["torch_dtype"] = torch_dtype
 
             if isinstance(device, str):
-                device = torch.device(device)
+                if device == "dml" and importlib.util.find_spec("torch_directml"):
+                    torch_directml = importlib.import_module("torch_directml")
+                    device = torch_directml.device()
+                else:
+                    device = torch.device(device)
             elif device is None:
                 device = torch.device("cpu")
 
             # TODO : fix EulerDiscreteScheduler loading to enable for SD models
-            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
+            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and device.type != "privateuseone":
                 with device:
                     # Initialize directly in the requested device, to save allocation time. Especially useful for large
                     # models to initialize on cuda device.
diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py
index f32ecc56e6..30fe264dbc 100644
--- a/optimum/onnxruntime/io_binding/io_binding_helper.py
+++ b/optimum/onnxruntime/io_binding/io_binding_helper.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import logging
 import traceback
 from typing import TYPE_CHECKING
@@ -145,8 +146,12 @@ def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor:
     @staticmethod
     def get_device_index(device):
         if isinstance(device, str):
-            # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
-            device = torch.device(device)
+            if device == "dml" and importlib.util.find_spec("torch_directml"):
+                torch_directml = importlib.import_module("torch_directml")
+                device = torch_directml.device()
+            else:
+                # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0
+                device = torch.device(device)
         elif isinstance(device, int):
             return device
         return 0 if device.index is None else device.index
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index 8da1df5fac..c3c0cc6aae 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -113,6 +113,8 @@ def _is_package_available(
         "onnxruntime-migraphx",
         "ort-migraphx-nightly",
         "ort-rocm-nightly",
+        # For DirectML
+        "onnxruntime-directml",
     ],
 )
 _tf_available, _tf_version = _is_package_available(
diff --git a/setup.py b/setup.py
index d132975aa4..100d49a499 100644
--- a/setup.py
+++ b/setup.py
@@ -69,6 +69,15 @@
         "protobuf>=3.20.1",
         "transformers>=4.36,<4.49.0",
     ],
+    "onnxruntime-directml": [
+        "onnx",
+        "onnxruntime-directml>=1.11.0",
+        "datasets>=1.2.1",
+        "evaluate",
+        "protobuf>=3.20.1",
+        "accelerate",  # ORTTrainer requires it.
+        "transformers>=4.36,<4.48.0",
+    ],
     "exporters": [
         "onnx",
         "onnxruntime",
@@ -81,6 +90,13 @@
         "timm",
         "transformers>=4.36,<4.49.0",
     ],
+    "exporters-directml": [
+        "torch-directml",
+        "onnx",
+        "onnxruntime-directml",
+        "timm",
+        "transformers>=4.36,<4.48.0",
+    ],
     "exporters-tf": [
         "tensorflow>=2.4,<=2.12.1",
         "tf2onnx",
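
For context, a minimal usage sketch of what this patch enables; this note is not part of the diff. It assumes torch-directml and onnxruntime-directml are installed, e.g. via the new `pip install optimum[exporters-directml]` extra added in setup.py, and the checkpoint and output path below are placeholders. With the patched code paths, passing device="dml" resolves to torch_directml.device(), whose .type is "privateuseone", so the model and dummy inputs are moved onto the DirectML device before export.

    # Hypothetical example; model name and output directory are placeholders.
    from optimum.exporters.onnx import main_export

    main_export(
        "distilbert-base-uncased",  # placeholder checkpoint
        output="distilbert_onnx",
        # "dml" is routed to torch_directml.device() by the patched branches in
        # tasks.py and convert.py; the new "privateuseone" checks then move the
        # model and its dummy inputs to the DirectML device.
        device="dml",
    )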