From 3cdb93ef847fc41bd60e82581583d12bf43b85ef Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Wed, 30 Oct 2024 16:09:49 -0400 Subject: [PATCH] Remove excess register stuff Signed-off-by: Mustafa Eyceoz --- .../dolomite/hf_models/register_hf.py | 27 ------------------- 1 file changed, 27 deletions(-) diff --git a/src/instructlab/dolomite/hf_models/register_hf.py b/src/instructlab/dolomite/hf_models/register_hf.py index 9d930fb..e92e456 100644 --- a/src/instructlab/dolomite/hf_models/register_hf.py +++ b/src/instructlab/dolomite/hf_models/register_hf.py @@ -1,29 +1,15 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM from .models import ( - GPTCrossLayerConfig, - GPTCrossLayerForCausalLM, - GPTCrossLayerModel, GPTDolomiteConfig, GPTDolomiteForCausalLM, - GPTDolomiteForCausalLM_TP, GPTDolomiteModel, - MoEDolomiteConfig, - MoEDolomiteForCausalLM, - MoEDolomiteForCausalLM_TP, - MoEDolomiteModel, - RNNDolomiteConfig, - RNNDolomiteForCausalLM, - RNNDolomiteModel, ) # (AutoConfig, AutoModel, AutoModelForCausalLM) _CUSTOM_MODEL_REGISTRY = [ (GPTDolomiteConfig, GPTDolomiteModel, GPTDolomiteForCausalLM), - (MoEDolomiteConfig, MoEDolomiteModel, MoEDolomiteForCausalLM), - (GPTCrossLayerConfig, GPTCrossLayerModel, GPTCrossLayerForCausalLM), - (RNNDolomiteConfig, RNNDolomiteModel, RNNDolomiteForCausalLM), ] _CUSTOM_MODEL_TYPES = [] _CUSTOM_MODEL_CLASSES = [] @@ -43,16 +29,3 @@ def register_model_classes() -> None: def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool: return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES - - -_TENSOR_PARALLEL_CLASS_MAPPING = { - GPTDolomiteConfig.model_type: GPTDolomiteForCausalLM_TP, - MoEDolomiteConfig.model_type: MoEDolomiteForCausalLM_TP, -} - - -def get_tensor_parallel_class(model_type: str) -> AutoModelForCausalLM: - if model_type in _TENSOR_PARALLEL_CLASS_MAPPING: - return _TENSOR_PARALLEL_CLASS_MAPPING[model_type] - - raise ValueError(f"tensor parallel is not supported with `model_type` ({model_type})")