Skip to content

Commit

Permalink
Add support for Apples' Metal Performance Shaders (MPS) in pytorch (#…
Browse files Browse the repository at this point in the history
…2037)

* Add support for Apples' Metal Performance Shaders (MPS) in pytorch engine.
* Add Unit test

Co-authored-by: hrayrm <[email protected]>
Co-authored-by: Frank Liu <[email protected]>
  • Loading branch information
3 people authored Sep 28, 2022
1 parent 6f68a11 commit 1d0bc77
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ public static int toDeviceType(Device device) {
return 0;
} else if (Device.Type.GPU.equals(deviceType)) {
return 1;
} else if ("mps".equals(deviceType)) {
return 13;
} else {
throw new IllegalArgumentException("Unsupported device: " + device.toString());
throw new IllegalArgumentException("Unsupported device: " + device);
}
}

Expand All @@ -49,6 +51,8 @@ public static String fromDeviceType(int deviceType) {
return Device.Type.CPU;
case 1:
return Device.Type.GPU;
case 13:
return "mps";
default:
throw new IllegalArgumentException("Unsupported deviceType: " + deviceType);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.pytorch.integration;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

import org.testng.Assert;
import org.testng.SkipException;
import org.testng.annotations.Test;

public class MpsTest {

@Test
public void testMps() {
if (!"aarch64".equals(System.getProperty("os.arch"))
|| !System.getProperty("os.name").startsWith("Mac")) {
throw new SkipException("MPS test requires M1 macOS.");
}

Device device = Device.of("mps", -1);
try (NDManager manager = NDManager.newBaseManager(device)) {
NDArray array = manager.zeros(new Shape(1, 2));
Assert.assertEquals(array.getDevice().getDeviceType(), "mps");
}
}
}

0 comments on commit 1d0bc77

Please sign in to comment.