diff --git a/outlines/index/index.py b/outlines/index/index.py
index aa251ed80..525c164b3 100644
--- a/outlines/index/index.py
+++ b/outlines/index/index.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import NewType, Protocol, Union
+from typing import List, NewType, Optional, Protocol
 
 import torch
 
@@ -8,22 +8,11 @@
 
 @dataclass(frozen=True)
 class GenerateInstruction:
-    logits_mask: str
-    temperature: float
-    top_k: int
-    top_p: int
-
-
-@dataclass(frozen=True)
-class FillInstruction:
-    token_ids: int
-
-
-FSMInstruction = Union[GenerateInstruction, FillInstruction]
+    tokens_to_mask: List[int]
 
 
 class FSM(Protocol):
-    def next_instruction(self, state: FSMState) -> FSMInstruction:
+    def next_instruction(self, state: FSMState) -> GenerateInstruction:
         ...
 
     def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
@@ -31,3 +20,33 @@ def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
 
     def is_final_state(self, state: FSMState) -> bool:
         ...
+
+
+class StopAtTokenFSM:
+    def __init__(self, stop_token_id: int, max_tokens: Optional[int] = None):
+        self.stop_token_id = stop_token_id
+        self.max_tokens = max_tokens
+        self.num_tokens_generated = 0
+
+    def next_instruction(self, _: FSMState) -> GenerateInstruction:
+        return GenerateInstruction([])
+
+    def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
+        self.num_tokens_generated += 1
+
+        if token_id == self.stop_token_id:
+            return FSMState(1)
+        else:
+            return FSMState(0)
+
+    def is_final_state(self, state: FSMState) -> bool:
+        # Stop if the maximum number of tokens has been generated
+        # regardless of whether the stop token id has been found.
+        if self.max_tokens is not None:
+            if self.num_tokens_generated == self.max_tokens:
+                return True
+
+        if state == 1:
+            return True
+        else:
+            return False
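
A minimal usage sketch of the new `StopAtTokenFSM` (not part of the diff itself), showing how a generation loop could drive the three protocol methods. The stop token id, the hard-coded token sequence, and the assumption that `FSMState` behaves like the integer `NewType` implied by the `NewType` import are all illustrative choices rather than anything specified in this PR.

```python
import torch

# Illustrative only: FSMState is assumed to be NewType("FSMState", int),
# consistent with the NewType import in the module above.
fsm = StopAtTokenFSM(stop_token_id=2, max_tokens=8)
state = FSMState(0)
generated = []

for token_id in (torch.tensor(5), torch.tensor(7), torch.tensor(2)):
    instruction = fsm.next_instruction(state)  # tokens_to_mask is always [] for this FSM
    generated.append(int(token_id))
    state = fsm.next_state(state, token_id)
    if fsm.is_final_state(state):  # True once the stop token (2) is seen, or max_tokens is hit
        break

print(generated)  # [5, 7, 2]
```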