Skip to content

Commit

Permalink
Always fallback to the DEFAULT_MODEL_FAMILY
Browse files Browse the repository at this point in the history
The code attempted to fallback to this, but if no model_path was given
then we were not actually falling back to
DEFAULT_MODEL_FAMILY (merlinite). This ensures we always fallback to
it, even if no model_path is given.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Jan 6, 2025
1 parent b17b08d commit 8703b9f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/instructlab/sdg/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@ def get_model_family(model_family, model_path):
return model_family

# Try to guess the model family based on the model's filename
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
return guess if guess in registry else DEFAULT_MODEL_FAMILY
if model_path:
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
if guess in registry:
return guess

# Nothing was found, so just return the default
return DEFAULT_MODEL_FAMILY
5 changes: 5 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,8 @@ def test_unknown_model_family(self):
"foobar", "./models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"
)
assert "Unknown model family: foobar" in str(exc.value)

def test_none_args(self):
assert (
models.get_model_family(None, None) == "merlinite"
)

0 comments on commit 8703b9f

Please sign in to comment.