From f61e148e537c2657fcd15b64198ea7b76eefe0e0 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 30 Jan 2024 13:05:31 +0400 Subject: [PATCH] slice tensor on gpu side --- serve/mlc_serve/model/model_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index ba12e8db2c..e4ce53aee3 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -155,7 +155,7 @@ def _is_safe_to_sample(prob_like): torch.cuda.nvtx.range_pop() return None - res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0] + res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy() if logits_random.shape[0] == num_seq: torch.cuda.nvtx.range_pop()