From d422ad4e38786f13810458b6e529380a0376b9f0 Mon Sep 17 00:00:00 2001 From: Abe Estrada Date: Fri, 2 Feb 2024 09:48:52 -0700 Subject: [PATCH] Add MPS backend support --- web_demo_mm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web_demo_mm.py b/web_demo_mm.py index 753df09..e1a978c 100644 --- a/web_demo_mm.py +++ b/web_demo_mm.py @@ -14,6 +14,7 @@ import re import secrets import tempfile +import torch from modelscope import ( snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig ) @@ -50,7 +51,7 @@ def _load_model_tokenizer(args): if args.cpu_only: device_map = "cpu" else: - device_map = "cuda" + device_map = torch.device("mps" if torch.backends.mps.is_available() else "cuda") model = AutoModelForCausalLM.from_pretrained( args.checkpoint_path,