From 4be284bd79d2c4ffab378b93d7282b54f96647e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 5 Jul 2024 23:48:42 -0700 Subject: [PATCH] Gemma 2 bug fixes + All RoPE Scaling Support (#736) * Update gemma2.py * Update llama.py * Update llama.py * Update gemma2.py * init * Update gemma2.py * Update gemma2.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * All RoPE Scaling support * cleanup * Update llama.py * Update llama.py * Update _utils.py * Update _utils.py * exec * exec * Attention_Module * attention_module * imports * exec * Update llama.py * Update llama.py * boolean mask * revert masking * Update llama.py * Update save.py * Update llama.py * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update utils.py * retry * Update gemma2.py * Update gemma2.py * Update gemma2.py * Update _utils.py * Update _utils.py * Update gemma2.py * Update chat_templates.py * Gemma 2 Ollama support * Update llama.py * Update llama.py --- unsloth/chat_templates.py | 27 ++++- unsloth/models/_utils.py | 211 ++++++++++++++++++++++++++++++++------ unsloth/models/gemma.py | 43 ++++++++ unsloth/models/gemma2.py | 18 +++- unsloth/models/llama.py | 15 ++- unsloth/models/mistral.py | 13 ++- unsloth/models/qwen2.py | 13 ++- unsloth/save.py | 7 +- 8 files changed, 306 insertions(+), 41 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 5f5b4e16c..596548df3 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -416,6 +416,21 @@ CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,) pass +# =========================================== Gemma 2 +# Same as Gemma 1, but with sliding window attention! +# https://ollama.com/library/gemma2/blobs/6522ca797f47 +gemma2_template = gemma_template +gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n" +gemma2_eos_token = "" +CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,) + +# =========================================== Gemma 2 with ChatML instead +gemma2_chatml_template = gemma_chatml_template +gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n" +gemma2_chatml_eos_token = gemma_chatml_eos_token +CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,) +pass + # =========================================== Llama-3 # Weirdly \n\n is needed? llama3_template = \ @@ -1014,7 +1029,17 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []): pass final_eos_tokens += extra_eos_tokens final_eos_tokens += repeatted_tokens - return final_eos_tokens + + # Remove new lines, spaces and HTML tags + filtered_eos_tokens = [] + for token in final_eos_tokens: + if token.count("\n") == len(token): continue + elif token.count("▁") == len(token): continue + elif token.startswith("<") and len(token) <= 2: continue + elif token.startswith("