Commit: fix_tokenizer
danielhanchen committed Jan 9, 2024
1 parent b522781 commit 82e6fec
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions unsloth/models/llama.py
@@ -619,7 +619,7 @@ def from_pretrained(
     token = None,
     device_map = "sequential",
     rope_scaling = None,
-    check_tokenizer = True,
+    fix_tokenizer = True,
 ):
     SUPPORTS_BFLOAT16 = torch.cuda.is_bf16_supported()
     gpu_stats = torch.cuda.get_device_properties(0)
@@ -704,7 +704,7 @@ def from_pretrained(
     internal_model.max_seq_length = max_position_embeddings

     # We check the tokenizer first for errors
-    if check_tokenizer:
+    if fix_tokenizer:
         tokenizer = check_tokenizer(
             model = model,
             tokenizer = tokenizer,
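
The hunks above also show why the rename was needed: the body of from_pretrained calls the check_tokenizer helper, so a boolean parameter with the same name shadows that helper inside the function and the call would hit a bool instead of a callable. Below is a minimal, self-contained sketch of that shadowing pattern; the stub helper and the trimmed-down signatures are illustrative, not the actual unsloth code.

# Minimal sketch of the name-shadowing problem addressed by this commit.
# The helper below is a stand-in, not the real unsloth check_tokenizer.
def check_tokenizer(tokenizer):
    # Pretend validation/repair logic.
    return tokenizer

def from_pretrained_old(tokenizer, check_tokenizer = True):
    # The boolean parameter shadows the module-level helper, so this
    # call raises: TypeError: 'bool' object is not callable.
    if check_tokenizer:
        tokenizer = check_tokenizer(tokenizer = tokenizer)
    return tokenizer

def from_pretrained_new(tokenizer, fix_tokenizer = True):
    # With the renamed flag, the module-level helper is reachable again.
    if fix_tokenizer:
        tokenizer = check_tokenizer(tokenizer = tokenizer)
    return tokenizer
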
4 changes: 2 additions & 2 deletions unsloth/models/loader.py
@@ -44,7 +44,7 @@ def from_pretrained(
     token = None,
     device_map = "sequential",
     rope_scaling = None,
-    check_tokenizer = True,
+    fix_tokenizer = True,
     *args, **kwargs,
 ):
     if not SUPPORTS_FOURBIT and model_name in FOURBIT_MAPPER:
@@ -84,7 +84,7 @@ def from_pretrained(
         token = token,
         device_map = device_map,
         rope_scaling = rope_scaling,
-        check_tokenizer = check_tokenizer,
+        fix_tokenizer = fix_tokenizer,
         *args, **kwargs,
     )
     pass
4 changes: 2 additions & 2 deletions unsloth/models/mistral.py
@@ -263,7 +263,7 @@ def from_pretrained(
     token = None,
     device_map = "sequential",
     rope_scaling = None, # Mistral does not support RoPE scaling
-    check_tokenizer = True,
+    fix_tokenizer = True,
 ):
     if rope_scaling is not None:
         logger.warning_once("Unsloth: Mistral models do not support RoPE scaling.")
@@ -333,7 +333,7 @@ def from_pretrained(
     internal_model.max_seq_length = max_position_embeddings

     # We check the tokenizer first for errors
-    if check_tokenizer:
+    if fix_tokenizer:
         tokenizer = check_tokenizer(
             model = model,
             tokenizer = tokenizer,
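
For callers, the change is just the keyword name: anything that passed check_tokenizer= to the loader now passes fix_tokenizer=. A hedged usage sketch follows; the FastLanguageModel entry point, the model name, and the other arguments are assumptions about the surrounding public API at this commit, not part of the diff.

# Illustrative caller; everything except the fix_tokenizer keyword is an
# assumption about the surrounding unsloth API, not shown in this commit.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-2-7b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit   = True,
    fix_tokenizer  = True,  # renamed from check_tokenizer in this commit
)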
