
Commit

fix
xtinkt committed Jul 23, 2024
1 parent c142727 commit 8beaf65
Showing 5 changed files with 103 additions and 103 deletions.
3 changes: 1 addition & 2 deletions src/petals/models/llama/__init__.py
@@ -2,11 +2,10 @@
from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import (
    DistributedLlamaForCausalLM,
    DistributedLlamaForSpeculativeGeneration,
    DistributedLlamaForSequenceClassification,
    DistributedLlamaModel,
    DistributedLlamaForSpeculativeGeneration,
)
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
from petals.utils.auto_config import register_model_classes

register_model_classes(
99 changes: 0 additions & 99 deletions src/petals/models/llama/model.py
@@ -6,8 +6,6 @@
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
from transformers.generation import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
@@ -155,103 +153,6 @@ def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
        DistributedLlamaForCausalLM.__init__(self, config)
        self.small_model = small_model

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_batch_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        print(model_kwargs)

        pad_token_id = generation_config.pad_token_id
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

        assert not generation_config.do_sample, "sample is not working for speculative generation now"

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        firsts = True

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            speculative_batch_size = min(speculative_batch_size, self.active_session._max_length - input_ids.shape[1])
            with torch.no_grad():
                speculative_outputs = self.small_model.generate(
                    input_ids,
                    max_new_tokens=speculative_batch_size,
                    do_sample=False,
                    use_cache=False,
                )
            speculative_tokens = speculative_outputs[:, -speculative_batch_size:]

            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
            assert input_ids.shape[1] + speculative_batch_size == full_sequence.shape[1]

            with torch.no_grad():
                real_input = full_sequence
                if not firsts:
                    self.active_session.position = input_ids.shape[1] - 1
                    real_input = real_input[:, -speculative_batch_size - 1:]
                else:
                    firsts = False
                    real_input = real_input[:, :-1]

                precise_model_outputs = self(real_input, return_dict=True)
                full_token_logits = precise_model_outputs.logits[:, -speculative_batch_size:, :].clone()

            all_valid_tokens = []

            first_token = None
            for i in range(speculative_batch_size):
                token_logits = full_token_logits[:, i, :]
                valid_token = torch.argmax(token_logits, dim=-1)

                if first_token is None:
                    first_token = valid_token

                if valid_token.item() == speculative_tokens[:, i].item():
                    all_valid_tokens.append(valid_token.unsqueeze(-1))
                else:
                    break

            if not all_valid_tokens and first_token is not None:
                all_valid_tokens.append(first_token.unsqueeze(-1))
            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                all_valid_tokens = all_valid_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
            this_peer_finished = unfinished_sequences.max() == 0

            del precise_model_outputs

        return input_ids

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
101 changes: 101 additions & 0 deletions src/petals/models/llama/speculative_model.py
@@ -0,0 +1,101 @@
from typing import Optional, Union

Check failure on line 1 in src/petals/models/llama/speculative_model.py (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.

import torch

from transformers.generation import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM

from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import DistributedLlamaForCausalLM


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
        DistributedLlamaForCausalLM.__init__(self, config)
        self.small_model = small_model

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        speculative_batch_size: int = 10,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        print(model_kwargs)

        pad_token_id = generation_config.pad_token_id
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

        assert not generation_config.do_sample, "sample is not working for speculative generation now"

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        firsts = True

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            speculative_batch_size = min(speculative_batch_size, self.active_session._max_length - input_ids.shape[1])
            with torch.no_grad():
                speculative_outputs = self.small_model.generate(
                    input_ids,
                    max_new_tokens=speculative_batch_size,
                    do_sample=False,
                    use_cache=False,
                )
            speculative_tokens = speculative_outputs[:, -speculative_batch_size:]

            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
            assert input_ids.shape[1] + speculative_batch_size == full_sequence.shape[1]

            with torch.no_grad():
                real_input = full_sequence
                if not firsts:
                    self.active_session.position = input_ids.shape[1] - 1
                    real_input = real_input[:, -speculative_batch_size - 1:]
                else:
                    firsts = False
                    real_input = real_input[:, :-1]

                precise_model_outputs = self(real_input, return_dict=True)
                full_token_logits = precise_model_outputs.logits[:, -speculative_batch_size:, :].clone()

            all_valid_tokens = []

            first_token = None
            for i in range(speculative_batch_size):
                token_logits = full_token_logits[:, i, :]
                valid_token = torch.argmax(token_logits, dim=-1)

                if first_token is None:
                    first_token = valid_token

                if valid_token.item() == speculative_tokens[:, i].item():
                    all_valid_tokens.append(valid_token.unsqueeze(-1))
                else:
                    break

            if not all_valid_tokens and first_token is not None:
                all_valid_tokens.append(first_token.unsqueeze(-1))
            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                all_valid_tokens = all_valid_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
            this_peer_finished = unfinished_sequences.max() == 0

            del precise_model_outputs

        return input_ids
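
A minimal usage sketch of DistributedLlamaForSpeculativeGeneration as added above, with a small local Llama acting as the draft model and the distributed model verifying its proposals. The checkpoint names below are placeholders, and passing small_model through from_pretrained is an assumption about how FromPretrainedMixin forwards constructor arguments; what the diff itself requires is greedy decoding (the _sample override asserts do_sample is False) and an active inference session, whose max_length bounds each speculative batch.

from transformers import AutoTokenizer, LlamaForCausalLM

from petals import DistributedLlamaForSpeculativeGeneration

BIG_MODEL = "org/llama-70b-petals"  # placeholder: a Llama checkpoint served by a Petals swarm
DRAFT_MODEL = "org/llama-tiny"      # placeholder: a small local Llama used to propose tokens

tokenizer = AutoTokenizer.from_pretrained(BIG_MODEL)
draft = LlamaForCausalLM.from_pretrained(DRAFT_MODEL)

# Assumed wiring: forward the draft model to __init__(config, small_model) via from_pretrained.
model = DistributedLlamaForSpeculativeGeneration.from_pretrained(BIG_MODEL, small_model=draft)

inputs = tokenizer("Speculative decoding lets a small model draft tokens that", return_tensors="pt")

# Greedy decoding only: _sample asserts do_sample=False and reads active_session._max_length,
# so generation runs inside an inference session with an explicit max_length.
with model.inference_session(max_length=128):
    output_ids = model.generate(inputs["input_ids"], max_new_tokens=32, do_sample=False)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

In each loop iteration the draft model proposes up to speculative_batch_size tokens with its own generate() call, the distributed model scores them in a single forward pass, and only the prefix matching its argmax predictions is kept (or a single corrected token when the very first proposal is wrong).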
2 changes: 1 addition & 1 deletion src/petals/utils/__init__.py
@@ -2,7 +2,7 @@
    AutoDistributedConfig,
    AutoDistributedModel,
    AutoDistributedModelForCausalLM,
    AutoDistributedSpeculativeModel,
    AutoDistributedModelForSequenceClassification,
    AutoDistributedSpeculativeModel,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos
1 change: 0 additions & 1 deletion tests/test_speculative_generation.py
@@ -2,7 +2,6 @@

import pytest
import torch

import transformers

from petals import AutoDistributedConfig, RemoteSequential, DistributedLlamaForSpeculativeGeneration, AutoDistributedSpeculativeModel

0 comments on commit 8beaf65
