From 2aa3aee26871281c63e6c72468802689e7d9059e Mon Sep 17 00:00:00 2001
From: Khaled Sulayman
Date: Fri, 8 Nov 2024 16:27:36 -0500
Subject: [PATCH] Check for tokenizer in downloaded models directory

Signed-off-by: Khaled Sulayman
---
 src/instructlab/sdg/utils/chunkers.py | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/instructlab/sdg/utils/chunkers.py b/src/instructlab/sdg/utils/chunkers.py
index 59a9b570..393221bf 100644
--- a/src/instructlab/sdg/utils/chunkers.py
+++ b/src/instructlab/sdg/utils/chunkers.py
@@ -19,6 +19,7 @@
     PdfFormatOption,
 )
 from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
+from instructlab.model.backends.backends import is_model_gguf, is_model_safetensors
 from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
 from tabulate import tabulate
 from transformers import AutoTokenizer
@@ -186,18 +187,12 @@ def __init__(
         filepaths,
         output_dir: Path,
         chunk_word_count: int,
-        tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
+        tokenizer_model_name: str,
     ):
         self.document_paths = document_paths
         self.filepaths = filepaths
         self.output_dir = self._path_validator(output_dir)
         self.chunk_word_count = chunk_word_count
-        self.tokenizer_model_name = (
-            tokenizer_model_name
-            if tokenizer_model_name is not None
-            else "mistralai/Mixtral-8x7B-Instruct-v0.1"
-        )
-
         self.tokenizer = self.create_tokenizer(tokenizer_model_name)
 
     def chunk_documents(self) -> List:
@@ -305,12 +300,26 @@ def create_tokenizer(self, model_name: str):
         Returns:
             AutoTokenizer: The tokenizer instance.
         """
+        model_path = Path(model_name)
         try:
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-            logger.info(f"Successfully loaded tokenizer from: {model_name}")
+            if is_model_safetensors(model_path):
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+            elif is_model_gguf(model_path):
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_path.parent, gguf_file=model_path.name
+                )
+            else:
+                raise ValueError(
+                    f"Model at {model_path} is not a valid safetensors or GGUF model"
+                )
+            logger.info(f"Successfully loaded tokenizer from: {model_path}")
             return tokenizer
         except Exception as e:
-            logger.error(f"Failed to load tokenizer from {model_name}: {str(e)}")
+            logger.error(
+                f"Failed to load tokenizer as no model was found at {model_path}. "
+                f"Please run `ilab model download {model_name}` and try again.\n"
+                f"{str(e)}"
+            )
             raise
 
     def get_token_count(self, text, tokenizer):
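
A minimal sketch of the two loading paths the patch distinguishes. The cache
paths and model names below are hypothetical examples, and the GGUF branch
assumes a transformers version with GGUF support (the `gguf` package)
installed:

    # Sketch only: the paths below are illustrative, not part of the patch.
    from pathlib import Path

    from transformers import AutoTokenizer

    # A safetensors checkpoint directory typically ships its tokenizer files
    # (tokenizer.json, tokenizer_config.json) next to the weights, so the
    # directory itself is a valid from_pretrained() target.
    safetensors_dir = Path("~/.cache/instructlab/models/granite-7b-lab").expanduser()
    tokenizer = AutoTokenizer.from_pretrained(safetensors_dir)

    # A GGUF file embeds its own tokenizer; transformers extracts it when
    # given the containing directory plus the file name via gguf_file=.
    gguf_path = Path("~/.cache/instructlab/models/merlinite-7b-Q4_K_M.gguf").expanduser()
    tokenizer = AutoTokenizer.from_pretrained(gguf_path.parent, gguf_file=gguf_path.name)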