Skip to content

Commit

Permalink
Remove excess register stuff
Browse files Browse the repository at this point in the history
Signed-off-by: Mustafa Eyceoz <[email protected]>
  • Loading branch information
Maxusmusti committed Oct 30, 2024
1 parent b4b7b4b commit 3cdb93e
Showing 1 changed file with 0 additions and 27 deletions.
27 changes: 0 additions & 27 deletions src/instructlab/dolomite/hf_models/register_hf.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,15 @@
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM

from .models import (
GPTCrossLayerConfig,
GPTCrossLayerForCausalLM,
GPTCrossLayerModel,
GPTDolomiteConfig,
GPTDolomiteForCausalLM,
GPTDolomiteForCausalLM_TP,
GPTDolomiteModel,
MoEDolomiteConfig,
MoEDolomiteForCausalLM,
MoEDolomiteForCausalLM_TP,
MoEDolomiteModel,
RNNDolomiteConfig,
RNNDolomiteForCausalLM,
RNNDolomiteModel,
)


# (AutoConfig, AutoModel, AutoModelForCausalLM)
_CUSTOM_MODEL_REGISTRY = [
(GPTDolomiteConfig, GPTDolomiteModel, GPTDolomiteForCausalLM),
(MoEDolomiteConfig, MoEDolomiteModel, MoEDolomiteForCausalLM),
(GPTCrossLayerConfig, GPTCrossLayerModel, GPTCrossLayerForCausalLM),
(RNNDolomiteConfig, RNNDolomiteModel, RNNDolomiteForCausalLM),
]
_CUSTOM_MODEL_TYPES = []
_CUSTOM_MODEL_CLASSES = []
Expand All @@ -43,16 +29,3 @@ def register_model_classes() -> None:

def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool:
return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES


_TENSOR_PARALLEL_CLASS_MAPPING = {
GPTDolomiteConfig.model_type: GPTDolomiteForCausalLM_TP,
MoEDolomiteConfig.model_type: MoEDolomiteForCausalLM_TP,
}


def get_tensor_parallel_class(model_type: str) -> AutoModelForCausalLM:
if model_type in _TENSOR_PARALLEL_CLASS_MAPPING:
return _TENSOR_PARALLEL_CLASS_MAPPING[model_type]

raise ValueError(f"tensor parallel is not supported with `model_type` ({model_type})")

0 comments on commit 3cdb93e

Please sign in to comment.