Commit

Merge branch 'unslothai:main' into main
dsingal0 authored Jul 6, 2024
2 parents 29ee45b + 4be284b commit 18a04c3
Showing 8 changed files with 306 additions and 41 deletions.
27 changes: 26 additions & 1 deletion unsloth/chat_templates.py
@@ -416,6 +416,21 @@
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
pass

# =========================================== Gemma 2
# Same as Gemma 1, but with sliding window attention!
# https://ollama.com/library/gemma2/blobs/6522ca797f47
gemma2_template = gemma_template
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)

# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
pass
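
# Illustrative usage sketch (not part of this commit): each CHAT_TEMPLATES
# entry is a 4-tuple of (template, EOS token, a boolean flag whose meaning is
# not shown in this diff, Ollama Modelfile snippet), so the new Gemma 2 entry
# can be consumed like the existing ones.
_template, _eos_token, _flag, _ollama_modelfile = CHAT_TEMPLATES["gemma2"]
assert _eos_token == "<end_of_turn>"
assert "PARAMETER num_ctx 4096" in _ollama_modelfile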

# =========================================== Llama-3
# Weirdly \n\n is needed?
llama3_template = \
@@ -1014,7 +1029,17 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
pass
final_eos_tokens += extra_eos_tokens
final_eos_tokens += repeatted_tokens
return final_eos_tokens

# Remove new lines, spaces and HTML tags
filtered_eos_tokens = []
for token in final_eos_tokens:
if token.count("\n") == len(token): continue
elif token.count("▁") == len(token): continue
elif token.startswith("<") and len(token) <= 2: continue
elif token.startswith("</") and len(token) == 3: continue
filtered_eos_tokens.append(token)
pass
return filtered_eos_tokens
pass
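
# Illustrative sketch (not part of this commit): applying the same predicates
# as the filter loop above to a hypothetical token list shows what survives
# the cleanup.
_sample_tokens = ["<end_of_turn>", "\n\n", "▁▁▁", "<", "</s", "<|eot_id|>"]
_kept_tokens = [
    t for t in _sample_tokens
    if not (t.count("\n") == len(t)
            or t.count("▁") == len(t)
            or (t.startswith("<") and len(t) <= 2)
            or (t.startswith("</") and len(t) == 3))
]
assert _kept_tokens == ["<end_of_turn>", "<|eot_id|>"]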


211 changes: 178 additions & 33 deletions unsloth/models/_utils.py
@@ -12,9 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2024.7"

__all__ = [
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"platform_system",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
"patch_linear_scaling",
"create_boolean_mask",
]

import torch
from typing import Union, Optional, List, Any, Callable
from typing import Union, Optional, List, Any, Callable, Tuple
import warnings
from platform import system as platform_system
platform_system = platform_system()
import math
import numpy as np
import os
import psutil
import inspect
import re

# =============================================
# Disable some warnings which can get annoying
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
@@ -26,20 +60,42 @@
# Stop "Special tokens have been added in the vocabulary, ..."
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
# =============================================

# =============================================
# Edits all Config files to enable RoPE Scaling for all models
from transformers import PretrainedConfig

model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]

for model_name in model_architectures:
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
config_filename = f"{model_name.title()}Config"
exec(f"from {config_filepath} import {config_filename}", globals())

config = inspect.getsource(eval(config_filename))
if "rope_scaling" in config: continue
config = re.sub(
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
r"rope_scaling=None,"\
r"\n **kwargs):\n"\
r"\n self.rope_scaling = rope_scaling\n",
config,
)
exec(config, globals())

exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
pass
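
# Illustrative sketch (not part of this commit): after the loop above, the
# listed config classes should accept a rope_scaling keyword even when the
# installed transformers release does not define one.  GemmaConfig was pulled
# into this module's globals() by the exec()'d import inside the loop.
if "GemmaConfig" in globals():
    _demo_config = globals()["GemmaConfig"](rope_scaling = {"type": "linear", "factor": 2.0})
    assert getattr(_demo_config, "rope_scaling", None) == {"type": "linear", "factor": 2.0}
pass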
# =============================================

# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
from platform import system as platform_system
platform_system = platform_system()
import math
import numpy as np
import os
import psutil

__version__ = "2024.7"

# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = False

@@ -69,25 +125,10 @@
import xformers.ops.fmha as xformers
xformers_attention = xformers.memory_efficient_attention
from xformers import __version__ as xformers_version
# =============================================

__all__ = [
"prepare_model_for_kbit_training",
"xformers",
"xformers_attention",
"xformers_version",
"__version__",
"HAS_FLASH_ATTENTION",
"platform_system",
"patch_tokenizer",
"get_statistics",
"Unsloth_Offloaded_Gradient_Checkpointer",
"offload_to_disk",
"offload_input_embeddings",
"offload_output_embeddings",
"is_bfloat16_supported",
"unsloth_offloaded_gradient_checkpoint",
"torch_compile_options",
]
# =============================================
# Torch compile settings

# Just remove max_autotune_gemm warning
import functools
@@ -128,7 +169,7 @@ def is_big_gpu(index):
"trace.enabled" : False, # Output Triton kernel outputs!
"triton.cudagraphs" : False,
}
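
# Illustrative sketch (not part of this commit): torch_compile_options is
# shaped for torch.compile's inductor backend, which accepts such a dict via
# its `options` argument.  _compile_demo is a made-up function for illustration.
def _compile_demo(x):
    return torch.nn.functional.silu(x) * x
_compile_demo_fast = torch.compile(_compile_demo, options = torch_compile_options)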

# =============================================

def prepare_model_for_kbit_training(
model : Any,
@@ -266,6 +307,7 @@ def patch_tokenizer(model, tokenizer):
pass


# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft.tuners.lora.layer import LoraLayer
@@ -295,6 +337,7 @@ def patch_tokenizer(model, tokenizer):
"Luckily, your training run will still work in the meantime!"
)
pass
# =============================================


def get_statistics():
@@ -456,9 +499,8 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None,
pass


"""
Remove warnings about missing kwargs and patch stuff
"""
# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
@@ -501,7 +543,7 @@ def _prepare_backend(

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__

# =============================================

# Offloading to disk for modules (lm_head, embed_tokens)
import pickle
@@ -549,3 +591,106 @@ def offload_output_embeddings(model, temporary_location : str = "_unsloth_tempor
def is_bfloat16_supported():
return SUPPORTS_BFLOAT16
pass
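
# Illustrative usage sketch (not part of this commit): callers can select a
# mixed precision dtype from this flag; _train_dtype is a made-up name.
_train_dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16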


# Patches models to add RoPE Scaling
def patch_linear_scaling(
model_name = "gemma2",
rope_module = None,
scaled_rope_module = None,
attention_module = None,
):
assert(rope_module is not None and scaled_rope_module is not None)
assert(attention_module is not None)

rope_name = rope_module.__name__
scaled_rope_name = scaled_rope_module.__name__
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
exec_code = \
f"import torch.nn as nn\n"\
f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
f"from {model_filepath} import logger, "\
f"{model_name.title()}Attention, {model_name.title()}Config"

function = inspect.getsource(attention_module.__init__)
where = function.find("def")
function = function.split("\n")
function = "\n".join(x[where:] for x in function)
init_name = f"{model_name.title()}Attention__init__"
function = function.replace("def __init__", f"def {init_name}")
function = function.replace(
"super().__init__()",
f"super({model_name.title()}Attention, self).__init__()",
)
fix_rope_function = """
if getattr(self.config, "rope_scaling", None) is None:
self.rotary_emb = {rope_function}(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = {scaled_rope_function}(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
pass
"""
fix_rope_function = fix_rope_function.format(
rope_function = rope_module.__name__,
scaled_rope_function = scaled_rope_module.__name__,
)
rotary_emb = re.findall(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
return init_name, function
pass
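
# Illustrative sketch (not part of this commit): the injected __init__ reads
# config.rope_scaling["type"] and config.rope_scaling["factor"], so a config
# that triggers the linear-scaling branch would carry
#     rope_scaling = {"type": "linear", "factor": 2.0}
# and any other "type" hits the ValueError baked into fix_rope_function above.
# A caller is expected to exec() the returned source and bind the new
# __init__, exactly as FastGemmaModel.pre_patch does later in this commit:
#     init_name, function = patch_linear_scaling(
#         model_name = "gemma",
#         rope_module = GemmaFixedRotaryEmbedding,
#         scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
#         attention_module = GemmaAttention,
#     )
#     exec(function, globals())
#     GemmaAttention.__init__ = eval(init_name)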


def create_boolean_mask(n = 4096, sliding_window = 2048):
# Creates a boolean mask for attention
mask = torch.ones(n, n, dtype = torch.bool)
if sliding_window == 0:
return torch.triu(mask, diagonal = 1, out = mask)
pass
torch.triu(mask, diagonal = 0, out = mask)
torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
mask = mask.T
torch.logical_not(mask, out = mask)
return mask
pass
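
# Illustrative sketch (not part of this commit): with sliding_window = 0 the
# helper reduces to a plain causal mask in which True marks a masked (future)
# position; the sliding-window branch is checked against transformers'
# AttentionMaskConverter by test_mask_creation below.
_demo_causal = create_boolean_mask(n = 4, sliding_window = 0)
assert _demo_causal.tolist() == [
    [False, True,  True,  True ],
    [False, False, True,  True ],
    [False, False, False, True ],
    [False, False, False, False],
]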


def test_mask_creation():
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
for n in range(2, 23):
for s in range(1, 23):
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = s,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = s)
assert(torch.all(correct_mask == our_mask))
pass
correct_mask = AttentionMaskConverter(
is_causal = True,
sliding_window = None,
).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
correct_mask = (correct_mask == correct_mask.min())
our_mask = create_boolean_mask(n = n, sliding_window = 0)
assert(torch.all(correct_mask == our_mask))
pass
pass
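
# Illustrative sketch (not part of this commit): the consistency check above
# can be invoked directly and raises AssertionError if the boolean masks ever
# diverge from transformers' AttentionMaskConverter, e.g.
#     from unsloth.models._utils import test_mask_creation
#     test_mask_creation()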
43 changes: 43 additions & 0 deletions unsloth/models/gemma.py
@@ -236,10 +236,53 @@ def forward(self, x, position_ids=None, seq_len=None):
pass


class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
# Fixes https://github.com/huggingface/transformers/pull/28837
# https://github.com/microsoft/DeepSpeed/issues/4932
# The precision of RoPE buffers is not correct, so we cast to int64.
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.max_seq_len_cached = seq_len

# The difference is we do division explicitly instead of t * (1/x), i.e. we do t/x.
freq_exponents = (2.0 / self.dim) * (
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
)
timescale = self.base**freq_exponents
positions = torch.arange(self.max_seq_len_cached, device = "cpu", dtype = torch.int64).float()
positions = positions / self.scaling_factor
radians_new = positions[..., None] / timescale[None, None, :]
radians_new = radians_new.squeeze(0)

emb = torch.cat((radians_new, radians_new), dim = -1)
# We must do RoPE in float32!
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
self.register_buffer("cos_cached", cos, persistent = False)
self.register_buffer("sin_cached", sin, persistent = False)
pass
pass
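
# Illustrative sketch (not part of this commit): linear scaling divides the
# position index by scaling_factor before the rotation angles are computed, so
# with scaling_factor = 2.0 position 10 lands on the same angles as position 5
# of the unscaled embedding, which is how the usable context is stretched.
# dim and base below are made-up values for illustration only.
_dim, _base, _factor = 8, 10000.0, 2.0
_timescale = _base ** ((2.0 / _dim) * torch.arange(_dim // 2, dtype = torch.int64).float())
_scaled   = (torch.tensor([10.0]) / _factor)[:, None] / _timescale[None, :]
_unscaled =  torch.tensor([5.0])[:, None] / _timescale[None, :]
assert torch.allclose(_scaled, _unscaled)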


class FastGemmaModel(FastLlamaModel):

@staticmethod
def pre_patch():
init_name, function = patch_linear_scaling(
model_name = "gemma",
rope_module = GemmaFixedRotaryEmbedding,
scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
attention_module = GemmaAttention,
)
exec(function, globals())
GemmaAttention.__init__ = eval(init_name)
GemmaAttention .forward = LlamaAttention_fast_forward
GemmaSdpaAttention .forward = LlamaAttention_fast_forward
GemmaFlashAttention2.forward = LlamaAttention_fast_forward