
Commit

fix cache and throughput
artek0chumak committed Apr 8, 2024
1 parent f06cfd2 commit 204855c
Showing 3 changed files with 18 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/petals/models/bloom/block.py
@@ -9,6 +9,8 @@
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
+from petals.utils.misc import is_dummy
+
 
 class WrappedBloomBlock(BloomBlock):
     def forward(
@@ -22,6 +24,10 @@ def forward(
     ):
         assert attention_mask is None, "Non-causal attention masks are not supported yet"
         batch_size, seq_length = hidden_states.shape[:2]
+        if layer_past is not None and is_dummy(layer_past[0]):
+            # Bloom cannot use the cache if it was misconstructed (e.g., dummy tensors).
+            # In this case, fall back to the old code path:
+            layer_past = None
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         seq_length_with_past = seq_length + past_length
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
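Note: the following is a minimal sketch (not part of the diff) of what the new fallback in block.py guards against. The is_dummy check and the zero-sized placeholder shape mirror petals.utils.misc; the rest is illustrative.

import torch

def is_dummy(tensor: torch.Tensor) -> bool:
    # Same check as petals.utils.misc.is_dummy: a tensor with no elements is a placeholder.
    return tensor.numel() == 0

# A placeholder "cache", such as a server may receive before any real keys/values exist.
layer_past = (torch.empty((0, 0, 0)), torch.empty((0, 0, 0)))

if layer_past is not None and is_dummy(layer_past[0]):
    # Bloom's attention cannot consume zero-sized past tensors,
    # so drop the cache and recompute attention from scratch.
    layer_past = None

past_length = 0 if layer_past is None else layer_past[0].shape[-1]
print(past_length)  # prints 0 -- the block behaves as if no cache was passed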
13 changes: 10 additions & 3 deletions src/petals/server/throughput.py
@@ -16,6 +16,7 @@
 from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
@@ -205,15 +206,21 @@ def measure_compute_rps(
         block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
+        # Skip the 1st step to exclude the initialization time
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
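The rewrite above routes both the warm-up call and the timed loop through one step() closure, so the benchmark exercises the same cache path as real inference. Below is a self-contained sketch of that timing pattern (not part of the diff); ToyBlock and all sizes are made up for illustration.

import time
import torch

class ToyBlock(torch.nn.Module):
    # Stand-in for a transformer block that optionally returns a cache tuple.
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x, use_cache=False, layer_past=None):
        out = self.proj(x)
        return (out, (out, out)) if use_cache else (out,)

inference = True
n_steps, n_tokens, hidden_size = 100, 16, 64
block = ToyBlock(hidden_size)
dummy_input = torch.randn(1, n_tokens, hidden_size)
cache = (torch.empty((0, 0, 0)), torch.empty((0, 0, 0)))  # plays the role of (DUMMY_KEY_PAST, DUMMY_KEY_PAST)

def step(cache_):
    outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
    return outputs[1] if inference else None

cache = step(cache)  # warm-up step, excluded from the measurement
start_time = time.perf_counter()
for _ in range(n_steps):
    cache = step(cache)
elapsed = time.perf_counter() - start_time
print(f"~{n_steps * n_tokens / elapsed:.0f} tokens/sec on this toy block")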
2 changes: 2 additions & 0 deletions src/petals/utils/misc.py
@@ -4,6 +4,8 @@
 
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
+DUMMY_KEY_PAST = torch.empty((0, 0, 0))
+
 
 def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
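As a quick sanity check (a sketch, not in the diff), the new sentinel is recognized by the existing is_dummy helper:

import torch

DUMMY_KEY_PAST = torch.empty((0, 0, 0))

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

assert is_dummy(DUMMY_KEY_PAST)            # zero elements -> treated as "no real cache yet"
assert not is_dummy(torch.zeros(1, 2, 3))  # any non-empty tensor counts as real cache data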
