
Commit

speculation and test
xtinkt committed Jul 23, 2024
1 parent c0a4d2e commit 140ee9b
Showing 6 changed files with 171 additions and 18 deletions.
36 changes: 21 additions & 15 deletions src/petals/client/inference_session.py
@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"

@property
def position(self):
return self._position

@position.setter
def position(self, start_from_position: int):
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

def step(
self,
inputs: torch.Tensor,
prompts: torch.Tensor,
hypo_ids: torch.LongTensor,
*,
step_id: str,
start_from_position: int,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ def step(
if self.closed:
raise Exception("Session is closed, cannot perform step")

if start_from_position is not None:
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
@@ -127,8 +131,8 @@ def step(
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
if start_from_position is not None:
request_metadata["start_from_position"] = start_from_position
if self._position is not None:
request_metadata["start_from_position"] = self._position
elif self.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
@@ -235,6 +239,13 @@ def num_blocks(self) -> int:
def position(self) -> int:
return self._position

@position.setter
def position(self, start_from_position: int) -> None:
self._position = start_from_position
for session in self._server_sessions:
assert isinstance(session, _ServerInferenceSession)
session.position = start_from_position

def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
try:
@@ -275,12 +286,7 @@ def step(
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
start_from_position: Optional[int] = None,
) -> torch.Tensor:

if start_from_position is not None:
self._position = start_from_position

assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
Expand Down Expand Up @@ -324,12 +330,12 @@ def step(
self._update_sequence(server_idx, block_idx, attempt_no)

server_session = self._server_sessions[server_idx]
assert server_session.position == self.position, f"{server_session.position} and {self.position}"
inputs = server_session.step(
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids,
step_id=step_id,
start_from_position=start_from_position,
)

server_idx += 1
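A minimal usage sketch of the new position setter (assuming a running Petals swarm; blocks, hidden_states, and accepted_len are hypothetical names, mirroring the updated test further below): rolling the session back lets a client re-send only the tail of a sequence, which is exactly what speculative generation needs after rejecting draft tokens.

    # Hypothetical client-side rollback; blocks is a RemoteSequential over the served layers.
    with blocks.inference_session(max_length=128) as sess:
        outputs = sess.step(hidden_states)                        # run the full prefix once
        sess.position = accepted_len                              # roll the server-side state back
        outputs = sess.step(hidden_states[:, accepted_len:, :])   # re-send only the rolled-back tail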
3 changes: 3 additions & 0 deletions src/petals/models/llama/__init__.py
@@ -2,14 +2,17 @@
from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import (
    DistributedLlamaForCausalLM,
    DistributedLlamaForSequenceClassification,
    DistributedLlamaForSpeculativeGeneration,
    DistributedLlamaModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
config=DistributedLlamaConfig,
model=DistributedLlamaModel,
model_for_causal_lm=DistributedLlamaForCausalLM,
model_for_speculative=DistributedLlamaForSpeculativeGeneration,
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)
102 changes: 101 additions & 1 deletion src/petals/models/llama/model.py
@@ -1,11 +1,13 @@
from typing import Optional
from typing import Optional, Union

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
@@ -153,6 +155,104 @@ def transformer(self) -> DistributedLlamaModel: # For compatibility with Remote
return self.model


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
DistributedLlamaForCausalLM.__init__(self, config)
self.small_model = small_model

def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: Optional[LogitsProcessorList],
speculative_batch_size: int = 10,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:

pad_token_id = generation_config.pad_token_id
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

assert not generation_config.do_sample, "sampling is not yet supported for speculative generation"

# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
firsts = True

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
speculative_batch_size = min(speculative_batch_size, self.active_session._max_length - input_ids.shape[1])
with torch.no_grad():
speculative_outputs = self.small_model.generate(
input_ids,
max_new_tokens=speculative_batch_size,
do_sample=False,
use_cache=False
)
speculative_tokens = speculative_outputs[:, -speculative_batch_size:]

full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
assert input_ids.shape[1] + speculative_batch_size == full_sequence.shape[1]

with torch.no_grad():
real_input = full_sequence
if not firsts:
self.active_session.position = input_ids.shape[1] - 1
real_input = real_input[:, -speculative_batch_size - 1:]
else:
firsts = False
real_input = real_input[:, :-1]

precise_model_outputs = self(real_input, return_dict=True)
full_token_logits = precise_model_outputs.logits[:, -speculative_batch_size:, :].clone()

all_valid_tokens = []

first_token = None
for i in range(speculative_batch_size):
token_logits = full_token_logits[:, i, :]
valid_token = torch.argmax(token_logits, dim=-1)

if first_token is None:
first_token = valid_token

# Check that the target model's greedy token matches the speculative (draft) token
if valid_token.item() == speculative_tokens[:, i].item():
all_valid_tokens.append(valid_token.unsqueeze(-1))
else:
break

if not all_valid_tokens and first_token is not None:
all_valid_tokens.append(first_token.unsqueeze(-1))
all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
all_valid_tokens = all_valid_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
this_peer_finished = unfinished_sequences.max() == 0

del precise_model_outputs

return input_ids

def get_output_embeddings(self):
return self.lm_head

@property
def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin
return self.model


class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
_keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
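Taken together, _sample implements greedy speculative decoding: the small draft model proposes up to speculative_batch_size tokens, the distributed model scores the whole proposal in a single forward step, the longest agreeing prefix is accepted (falling back to the distributed model's own first token when nothing agrees), and active_session.position is rolled back before the next round. A simplified, self-contained sketch of that accept rule (a hypothetical helper, batch size 1, as the .item() comparisons above assume):

    import torch

    def accept_greedy(draft_tokens: torch.LongTensor, target_logits: torch.Tensor) -> torch.LongTensor:
        # draft_tokens: [1, k] tokens proposed by the small model
        # target_logits: [1, k, vocab_size] logits of the distributed model at the same positions
        # Accept the longest prefix on which the distributed model's greedy choice agrees with the
        # draft; if even the first token disagrees, keep the distributed model's first token so
        # that every round still makes progress.
        target_tokens = target_logits.argmax(dim=-1)  # [1, k] greedy choice of the target model
        accepted = []
        for i in range(draft_tokens.shape[1]):
            if target_tokens[0, i].item() == draft_tokens[0, i].item():
                accepted.append(target_tokens[:, i : i + 1])
            else:
                break
        if not accepted:
            accepted.append(target_tokens[:, :1])
        return torch.cat(accepted, dim=-1)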
1 change: 1 addition & 0 deletions src/petals/utils/__init__.py
@@ -2,6 +2,7 @@
AutoDistributedConfig,
AutoDistributedModel,
AutoDistributedModelForCausalLM,
AutoDistributedSpeculativeModel,
AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos
5 changes: 5 additions & 0 deletions src/petals/utils/auto_config.py
@@ -15,6 +15,7 @@ class _ModelClasses:
config: Type[PretrainedConfig]
model: Optional[Type[PreTrainedModel]] = None
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
model_for_speculative: Optional[Type[PreTrainedModel]] = None
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None


@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
_mapping_field = "model_for_causal_lm"


class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_speculative"


class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"
42 changes: 40 additions & 2 deletions tests/test_speculative_generation.py
@@ -3,7 +3,9 @@
import pytest
import torch

from petals import AutoDistributedConfig, RemoteSequential
import transformers

from petals import AutoDistributedConfig, RemoteSequential, DistributedLlamaForSpeculativeGeneration, AutoDistributedSpeculativeModel
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@@ -26,10 +28,46 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
initial_outputs_inference = sess.step(inputs)
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
sess.position = 2
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)

ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(short_inputs)

assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)

@pytest.fixture
def noisy_model():
noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
lm_head = noisy_model.get_output_embeddings()
assert isinstance(lm_head, torch.nn.Linear)
with torch.no_grad():
lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
return noisy_model

@pytest.fixture
def model():
return transformers.AutoModelForCausalLM.from_pretrained(
MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)

@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

@pytest.mark.forked
def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
)

inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)

assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)
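The noisy_model fixture adds small random noise to the lm_head of a local copy of the reference model and passes it as the draft (small_model), so the draft occasionally disagrees with the distributed target model and the rejection-and-rollback path is exercised; the final assert requires the speculative output to match local greedy generation token for token, which implicitly checks that the rolled-back positions and server-side caches stay consistent.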
