diff --git a/open_lm/attention.py b/open_lm/attention.py index e0e8aba5..7f2e2f4c 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None): if attention_mask is None: bias = None # If we only have one query, assume we don't need to be in causal mode (can attend to all keys). - if queries.shape == 1: + if queries.shape[1] == 1: is_causal = False else: if not is_causal: diff --git a/open_lm/hf/__init__.py b/open_lm/hf/__init__.py new file mode 100644 index 00000000..84931689 --- /dev/null +++ b/open_lm/hf/__init__.py @@ -0,0 +1,3 @@ +from .configuration_openlm import OpenLMConfig +from .modeling_openlm import OpenLMForCausalLM +from .tokenization_openlm import OpenLMTokenizerFast diff --git a/open_lm/hf/configuration_openlm.py b/open_lm/hf/configuration_openlm.py new file mode 100644 index 00000000..75663962 --- /dev/null +++ b/open_lm/hf/configuration_openlm.py @@ -0,0 +1,24 @@ +# Follows OLMo's HF template + +""" +OpenLM configuration +""" + +from transformers import AutoConfig, PretrainedConfig +from transformers.utils import logging + +from open_lm.model import Params + +logger = logging.get_logger(__name__) + + +class OpenLMConfig(PretrainedConfig): + model_type = "openlm" + + def __init__(self, **kwargs): + kwargs["architectures"] = ["OpenLMForCausalLM"] + super().__init__(**kwargs) + + +# Register the config class so that it is available for transformer pipelines, auto-loading etc. +AutoConfig.register("openlm", OpenLMConfig) diff --git a/open_lm/hf/modeling_openlm.py b/open_lm/hf/modeling_openlm.py new file mode 100644 index 00000000..67ee1e4f --- /dev/null +++ b/open_lm/hf/modeling_openlm.py @@ -0,0 +1,194 @@ +# Follows OLMo's HF template + +import logging +from dataclasses import fields +from typing import List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedModel +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.auto import AutoModelForCausalLM + +from open_lm.model import Params, Transformer +from open_lm.norms import get_norm_class +from open_lm.attention import get_attn_func + +from .configuration_openlm import OpenLMConfig + +log = logging.getLogger(__name__) + + +def create_model_config_from_pretrained_config(config: OpenLMConfig): + """ + Utility function + """ + + kwargs = {} + for field in fields(Params): + if hasattr(config, field.name): + kwargs[field.name] = getattr(config, field.name) + + model_config = Params(**kwargs) + + if hasattr(config, "norm_type"): + model_config.norm_type = get_norm_class(config.norm_type) + + if hasattr(config, "attn_name"): + model_config.attn_func = get_attn_func(config.attn_name) + + return model_config + + +class OpenLMForCausalLM(PreTrainedModel): + """ + Extremely barebones HF model wrapper. + """ + + config_class = OpenLMConfig + base_model_prefix = "model" + + def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None): + super().__init__(config) + + if not model: + self.model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + self.model_config.init_device = "cpu" + self.model = Transformer(self.model_config) + + else: + self.model = model + + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ + Cache + ] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426 + ) -> Union[Tuple, CausalLMOutputWithPast]: + if inputs_embeds is not None: + log.warning("inputs_embeds is set but OpenLM does not support it yet") + if attention_bias is not None: + log.warning("attention_bias is et but OpenLM does not support it yet") + if use_cache is None: + use_cache = True + if output_attentions: + raise ValueError("output_attentions is not yet supported in OpenLM") + if output_hidden_states: + raise ValueError("output_hidden_states is not yet supported in OpenLM") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # print("outer past_key_values: ", type(past_key_values)) + # if past_key_values is not None: + # print(len(past_key_values), type(past_key_values[0])) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + logits = outputs[0] + past_key_values = outputs[2] + hidden_states = None + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.model_config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=hidden_states, + ) + + def can_generate(self) -> bool: + return True + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values[0][1], int): + # This assumes that the second item of past key values is the length of the past (this is the case for linear attention) + past_length = past_key_values[0][1] + else: + # This assumes that the first item of past key values is a list of all the past keys, thus the + # shape 1 is the length of the past (this is the case for attention without window) + past_length = past_key_values[0][0].shape[1] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + model_inputs = { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.pop("use_cache", True), + } + return model_inputs + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.tok_embeddings + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + if self.model_config.weight_tying: + return self.model.tok_embeddings + else: + return self.model.output + + def set_output_embeddings(self, value: torch.nn.Module): + if self.model_config.weight_tying: + self.model.tok_embeddings = value + else: + self.model.output = value + + def tie_weights(self): + """ + Copied from OLMo (description below). I removed it and the results just became garbage, so this pass is needed. + This function is intentionally left as a no-op. + Weight tying is handled as follows: + - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration. + See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`. + - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled. + See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method. + Therefore, there is no need to explicitly tie the weights in this function. + """ + pass + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> torch.nn.Embedding: + raise NotImplementedError + + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. +AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM) diff --git a/open_lm/hf/tokenization_openlm.py b/open_lm/hf/tokenization_openlm.py new file mode 100644 index 00000000..e8abdd69 --- /dev/null +++ b/open_lm/hf/tokenization_openlm.py @@ -0,0 +1,18 @@ +# Follows OLMo's HF template + +from transformers import AutoTokenizer, PreTrainedTokenizerFast + +from open_lm.hf.configuration_openlm import OpenLMConfig + + +class OpenLMTokenizerFast(PreTrainedTokenizerFast): + # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary. + pass + + # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + # # This is required to make the implementation complete. + # pass + + +# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc. +AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast)