add test
masahi committed Jan 10, 2024
1 parent e695ec8 commit 818bb7f
Showing 1 changed file with 130 additions and 0 deletions.
examples/python/run_llama_batched_vllm.py (130 additions, 0 deletions)
@@ -117,6 +117,13 @@ class SequenceGenerationResponse:
    token_id: int


@dataclass
class EvalQueryRequest:
    request_id: int
    past_token_ids: List[int]
    query_token_ids: List[int]
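# Illustrative example (values assumed, not part of the diff): a request whose full
# sequence is [p0, p1, p2, q0, q1], where the last two tokens are to be re-evaluated
# against the cached KV of the first three, would be expressed as
# EvalQueryRequest(request_id=0, past_token_ids=[p0, p1, p2], query_token_ids=[q0, q1]).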


def sample(logits):
    logits = torch.from_dlpack(logits)
    return torch.argmax(logits, -1).cpu().numpy()
@@ -241,6 +248,76 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]:
)


def _prepare_eval_queries(
    requests: List[EvalQueryRequest],
    all_slot_mappings,
    sliding_window,
    dev,
):
    seq_lens = []
    query_lens = []
    input_ids = []
    slot_mapping = []
    past_slot_mapping = []
    positions = []
    permute_map = []

    query_offset = sum([len(request.past_token_ids) for request in requests])
    past_offset = 0

    for request in requests:
        num_past_tokens = len(request.past_token_ids)
        num_queries = len(request.query_token_ids)
        query_lens.append(num_queries)
        request_id = request.request_id
        input_ids += request.query_token_ids

        positions += [num_past_tokens + i for i in range(num_queries)]

        if sliding_window:
            seq_lens.append(min(num_past_tokens + num_queries, sliding_window))
            # TODO: verify this
            past_slot_mapping += all_slot_mappings[request_id][
                : min(num_past_tokens, sliding_window)
            ]
            slot_mapping += all_slot_mappings[request_id][
                min(num_past_tokens, sliding_window) : min(num_past_tokens, sliding_window)
                + num_queries
            ]
        else:
            seq_lens.append(num_past_tokens + num_queries)
            past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens]
            slot_mapping += all_slot_mappings[request_id][
                num_past_tokens : num_past_tokens + num_queries
            ]

        permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list(
            range(query_offset, query_offset + num_queries)
        )

        query_offset += num_queries
        past_offset += num_past_tokens

    input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev)
    positions = tvm.nd.array(np.array(positions, dtype="int32"), dev)
    seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev)
    slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev)

    query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev)
    past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev)
    permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev)

    return (
        input_ids,
        positions,
        seq_lens,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
    )
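
# Layout sketch (illustrative, derived from the function above): the multi-query
# evaluation appears to expect the KV entries of all past tokens first, followed by
# the entries for all query tokens.  For two requests with (2 past, 1 query) and
# (3 past, 2 query) tokens, query_offset starts at 5 and permute_map comes out as
#     [0, 1, 5, 2, 3, 4, 6, 7]
# i.e. it gathers the concatenated [all past | all query] entries back into contiguous
# per-request order (request 0's past tokens, then its queries, then request 1's, ...).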


class Model:
    def __init__(
        self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
@@ -443,6 +520,59 @@ def run(args):
    for p, g in zip(prompts, generated):
        print("Prompt = '{}', generated text = '{}'".format(p, g))

    query_token_lens = [4, 3, 5, 2]

    eval_query_requests = []

    for request_id, query_token_len in zip(request_ids, query_token_lens):
        queries_to_eval = requests[request_id].token_ids[-query_token_len:]
        past_tokens = requests[request_id].token_ids[:-query_token_len]
        eval_query_requests.append(EvalQueryRequest(request_id, past_tokens, queries_to_eval))
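    # Illustrative split (token values assumed): with query_token_len = 4 and
    # token_ids = [t0, ..., t9], past_tokens = [t0, ..., t5] and queries_to_eval =
    # [t6, ..., t9]; the last four tokens of an already-generated sequence are
    # re-evaluated in parallel against the KV cache entries of the earlier tokens.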

    (
        input_ids,
        positions,
        seq_lens,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
    ) = _prepare_eval_queries(
        eval_query_requests,
        cache.slot_mappings,
        None,
        model.dev,
    )
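    # sliding_window is passed as None above, so this exercises the
    # non-sliding-window branch of _prepare_eval_queries.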

    logits = model.mod["evaluate_multi_query"](
        input_ids,
        positions,
        seq_lens,
        cache.cache,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
        model.params,
    )[0].numpy()

    assert logits.shape[0] == sum(query_token_lens)
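    # The logits rows appear to be concatenated per request: with query_token_lens =
    # [4, 3, 5, 2], rows [0, 4) belong to the first request, rows [4, 7) to the second,
    # and so on; the offset bookkeeping below relies on this layout.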

    logits_offset = 0

    for request_id, query_token_len in zip(request_ids, query_token_lens):
        for i in range(query_token_len - 1):
            # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
            # Doing argmax over multi-timestep logits computed in parallel should yield the same
            # tokens at the corresponding positions.
            past_tokens = requests[request_id].token_ids[:-query_token_len]
            assert (
                np.argmax(logits[logits_offset + i])
                == requests[request_id].token_ids[len(past_tokens) + i + 1]
            )

        logits_offset += query_token_len


if __name__ == "__main__":
    run(parse_args())
