HF integration (#291)
* add hf

* import fix

* black reformat

* get_attn_func in hf loader

* typo

* black reformat

* black reformat

* black reformat

* reqs

* reqs

* reqs

* reqs

* indexing bug
sedrick-keh-tri authored Jul 17, 2024
1 parent 69c9235 commit c0f1319
Showing 5 changed files with 240 additions and 1 deletion.
2 changes: 1 addition & 1 deletion open_lm/attention.py
@@ -111,7 +111,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
    if attention_mask is None:
        bias = None
        # If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
-       if queries.shape == 1:
+       if queries.shape[1] == 1:
            is_causal = False
    else:
        if not is_causal:
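For context on the indexing fix above: `queries.shape` is a `torch.Size` (a tuple), so comparing it to an integer is always False, which meant the single-query fast path never disabled causal masking. The fixed check indexes the sequence-length dimension instead. A minimal standalone sketch, assuming the `[batch, seq_len, n_heads, head_dim]` layout used by open_lm's attention functions:

import torch

# Hypothetical single-token decode step: batch of 2, one query token, 8 heads, 64-dim.
queries = torch.randn(2, 1, 8, 64)

print(queries.shape == 1)     # False -- a torch.Size is never equal to an int
print(queries.shape[1] == 1)  # True  -- dimension 1 holds the query length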
3 changes: 3 additions & 0 deletions open_lm/hf/__init__.py
@@ -0,0 +1,3 @@
from .configuration_openlm import OpenLMConfig
from .modeling_openlm import OpenLMForCausalLM
from .tokenization_openlm import OpenLMTokenizerFast
24 changes: 24 additions & 0 deletions open_lm/hf/configuration_openlm.py
@@ -0,0 +1,24 @@
# Follows OLMo's HF template

"""
OpenLM configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from open_lm.model import Params

logger = logging.get_logger(__name__)


class OpenLMConfig(PretrainedConfig):
    model_type = "openlm"

    def __init__(self, **kwargs):
        kwargs["architectures"] = ["OpenLMForCausalLM"]
        super().__init__(**kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("openlm", OpenLMConfig)
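With this registration, a checkpoint directory whose config.json sets "model_type": "openlm" should resolve to OpenLMConfig through the auto classes once `open_lm.hf` has been imported. A minimal sketch; the checkpoint path is a placeholder:

from transformers import AutoConfig

import open_lm.hf  # noqa: F401 -- importing the package runs AutoConfig.register("openlm", OpenLMConfig)

# "path/to/openlm-checkpoint" is a hypothetical directory containing a config.json with "model_type": "openlm".
config = AutoConfig.from_pretrained("path/to/openlm-checkpoint")
print(type(config).__name__)  # expected: OpenLMConfig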
194 changes: 194 additions & 0 deletions open_lm/hf/modeling_openlm.py
@@ -0,0 +1,194 @@
# Follows OLMo's HF template

import logging
from dataclasses import fields
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from open_lm.model import Params, Transformer
from open_lm.norms import get_norm_class
from open_lm.attention import get_attn_func

from .configuration_openlm import OpenLMConfig

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OpenLMConfig):
    """
    Utility function: build an open_lm `Params` object from the fields stored on an `OpenLMConfig`.
    """

    kwargs = {}
    for field in fields(Params):
        if hasattr(config, field.name):
            kwargs[field.name] = getattr(config, field.name)

    model_config = Params(**kwargs)

    if hasattr(config, "norm_type"):
        model_config.norm_type = get_norm_class(config.norm_type)

    if hasattr(config, "attn_name"):
        model_config.attn_func = get_attn_func(config.attn_name)

    return model_config


class OpenLMForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OpenLMConfig
    base_model_prefix = "model"

    def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None):
        super().__init__(config)

        if not model:
            self.model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            self.model_config.init_device = "cpu"
            self.model = Transformer(self.model_config)
        else:
            self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[
            Cache
        ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is not None:
            log.warning("inputs_embeds is set but OpenLM does not support it yet")
        if attention_bias is not None:
            log.warning("attention_bias is set but OpenLM does not support it yet")
        if use_cache is None:
            use_cache = True
        if output_attentions:
            raise ValueError("output_attentions is not yet supported in OpenLM")
        if output_hidden_states:
            raise ValueError("output_hidden_states is not yet supported in OpenLM")

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        # print("outer past_key_values: ", type(past_key_values))
        # if past_key_values is not None:
        #     print(len(past_key_values), type(past_key_values[0]))
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs[0]
        past_key_values = outputs[2]
        hidden_states = None

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.model_config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values[0][1], int):
                # This assumes the second item of past_key_values is the length of the past
                # (this is the case for linear attention).
                past_length = past_key_values[0][1]
            else:
                # This assumes the first item of past_key_values holds all the past keys, so
                # shape[1] is the length of the past (this is the case for attention without window).
                past_length = past_key_values[0][0].shape[1]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only the final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.pop("use_cache", True),
        }
        return model_inputs

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.tok_embeddings

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        if self.model_config.weight_tying:
            return self.model.tok_embeddings
        else:
            return self.model.output

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.model_config.weight_tying:
            self.model.tok_embeddings = value
        else:
            self.model.output = value

    def tie_weights(self):
        """
        Copied from OLMo (description below). Removing this override degraded generation results, so this no-op is needed.
        This function is intentionally left as a no-op.
        Weight tying is handled as follows:
        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
          See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
          See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
        Therefore, there is no need to explicitly tie the weights in this function.
        """
        pass

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> torch.nn.Embedding:
        raise NotImplementedError


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM)
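With the model registered, AutoModelForCausalLM can load an OpenLM checkpoint saved in HF format, and forward() computes the shifted next-token loss when labels are passed. A hedged sketch; the checkpoint path is a placeholder:

import torch
from transformers import AutoModelForCausalLM

import open_lm.hf  # noqa: F401 -- registers OpenLMConfig and OpenLMForCausalLM with the auto classes

# Hypothetical local checkpoint directory whose config.json has "model_type": "openlm".
model = AutoModelForCausalLM.from_pretrained("path/to/openlm-checkpoint")

# forward() shifts logits/labels internally, so labels == input_ids gives the standard
# next-token cross-entropy loss.
input_ids = torch.tensor([[1, 2, 3, 4, 5]])
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss.item(), out.logits.shape)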
18 changes: 18 additions & 0 deletions open_lm/hf/tokenization_openlm.py
@@ -0,0 +1,18 @@
# Follows OLMo's HF template

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from open_lm.hf.configuration_openlm import OpenLMConfig


class OpenLMTokenizerFast(PreTrainedTokenizerFast):
    # Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
    pass

    # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    #     # This is required to make the implementation complete.
    #     pass


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast)
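Putting the three registrations together, end-to-end generation through the auto classes should look roughly like this sketch (the checkpoint path and prompt are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import open_lm.hf  # noqa: F401 -- registers the config, model, and tokenizer classes

ckpt = "path/to/openlm-checkpoint"  # hypothetical HF-format OpenLM checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

inputs = tokenizer("open_lm is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(inputs["input_ids"], max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))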
