Skip to content

Commit

Permalink
fix for local files
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 7, 2025
1 parent b8c0dc2 commit f52e576
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
15 changes: 14 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Classes handling causal-lm related architectures in ONNX Runtime."""

import logging
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -418,15 +419,27 @@ def _from_pretrained(
use_merged = False

if file_name is None:
if local_files_only:
object_id = str(model_id).replace("/", "--")
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
with open(refs_file) as f:
revision = f.read()
model_dir = os.path.join(cached_model_dir, "snapshots", revision)
else:
model_dir = str(model_id)

onnx_files = find_files_matching_pattern(
model_id,
model_dir,
ONNX_FILE_PATTERN,
glob_pattern="**/*.onnx",
subfolder=subfolder,
token=token,
revision=revision,
)

model_path = Path(model_dir)

if len(onnx_files) == 0:
raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}")

Expand Down
19 changes: 15 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,28 @@ def _from_pretrained(
model_path = Path(model_id)

if file_name is None:
if local_files_only:
object_id = str(model_id).replace("/", "--")
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
with open(refs_file) as f:
_revision = f.read()
model_dir = os.path.join(cached_model_dir, "snapshots", _revision)
else:
model_dir = str(model_id)

onnx_files = find_files_matching_pattern(
model_id,
model_dir,
ONNX_FILE_PATTERN,
glob_pattern="**/*.onnx",
subfolder=subfolder,
token=token,
revision=revision,
)

model_path = Path(model_dir)
if len(onnx_files) == 0:
raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}")
raise FileNotFoundError(f"Could not find any ONNX model file in {model_dir}")

file_name = onnx_files[0].name
subfolder = onnx_files[0].parent
Expand Down Expand Up @@ -706,8 +717,8 @@ def from_pretrained(
cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
with open(refs_file) as f:
revision = f.read()
model_dir = os.path.join(cached_model_dir, "snapshots", revision)
_revision = f.read()
model_dir = os.path.join(cached_model_dir, "snapshots", _revision)
else:
model_dir = model_id

Expand Down
1 change: 0 additions & 1 deletion tests/onnxruntime/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,6 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_diffusers
def test_num_images_per_prompt(self, model_arch: str):

pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

for batch_size in [1, 3]:
Expand Down
3 changes: 2 additions & 1 deletion tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def test_load_model_from_hub(self):
def test_load_model_from_hub_subfolder(self):
# does not pass with ORTModel as it does not have export_feature attribute
model = ORTModelForSequenceClassification.from_pretrained(
"fxmarty/tiny-bert-sst2-distilled-subfolder", subfolder="my_subfolder",
"fxmarty/tiny-bert-sst2-distilled-subfolder",
subfolder="my_subfolder",
)
self.assertIsInstance(model.model, onnxruntime.InferenceSession)
self.assertIsInstance(model.config, PretrainedConfig)
Expand Down

0 comments on commit f52e576

Please sign in to comment.