Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Nov 24, 2024
1 parent faaeb1b commit ad7819f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
Conv1D,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
#id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
prune_conv1d_layer,
prune_layer,
Expand Down Expand Up @@ -4440,11 +4440,13 @@ def _fix_key(key):
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
#id_tensor = id_tensor_storage(tensor)
#ptrs[id_tensor].append(name)
pass

# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
# tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_params = []
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from packaging import version
from torch import nn

#breakpoint()
#from safetensors.torch import storage_ptr, storage_size
#breakpoint()

from .utils import is_torch_greater_or_equal, is_torch_xla_available, logging


Expand Down
7 changes: 6 additions & 1 deletion temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
breakpoint()

ckpt = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
#tokenizer = AutoTokenizer.from_pretrained(ckpt)

# sentencepiece is not compatible with Python 3.13, and the GIL gets re-enabled
from transformers import GemmaTokenizer
# tokenizer = GemmaTokenizer.from_pretrained(ckpt)


model = model.to(device)
transformers.generation.utils.my_model = model
Expand Down

0 comments on commit ad7819f

Please sign in to comment.