From ac54946786768a64351ab25bf8b070b124c88449 Mon Sep 17 00:00:00 2001 From: pkm99 Date: Mon, 2 Dec 2024 17:01:48 +0530 Subject: [PATCH 1/2] Filter out input ids not being used in inference. Fixes issue #1294 --- outlines/processors/structured.py | 1 + 1 file changed, 1 insertion(+) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index d2bc15f77..1b51a39af 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -110,6 +110,7 @@ def process_logits( allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( mask.device, non_blocking=True ) + allowed_tokens = allowed_tokens[allowed_tokens < mask.shape[-1]] # filter out input ids exceeding the mask length allowed_tokens_batch.append(allowed_tokens) batch_indices.append( torch.full_like(allowed_tokens, i) From a85055867fea4c7688f240a43d484c6595eccd14 Mon Sep 17 00:00:00 2001 From: pkm99 Date: Wed, 4 Dec 2024 10:25:04 +0530 Subject: [PATCH 2/2] black formatted code --- outlines/processors/structured.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index 1b51a39af..24ac6d7e4 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -23,6 +23,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union @@ -110,7 +111,9 @@ def process_logits( allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( mask.device, non_blocking=True ) - allowed_tokens = allowed_tokens[allowed_tokens < mask.shape[-1]] # filter out input ids exceeding the mask length + allowed_tokens = allowed_tokens[ + allowed_tokens < mask.shape[-1] + ] # filter out input ids exceeding the mask length allowed_tokens_batch.append(allowed_tokens) batch_indices.append( torch.full_like(allowed_tokens, i)