diff --git a/outlines/index/index.py b/outlines/index/index.py
index 525c164b3..9c61bb84b 100644
--- a/outlines/index/index.py
+++ b/outlines/index/index.py
@@ -1,8 +1,14 @@
 from dataclasses import dataclass
-from typing import List, NewType, Optional, Protocol
+from typing import TYPE_CHECKING, List, NewType, Optional, Protocol
 
+import interegular
 import torch
 
+from outlines.index.fsm import create_fsm_index_tokenizer, make_deterministic_fsm
+
+if TYPE_CHECKING:
+    from outlines.models.tokenizer import Tokenizer
+
 FSMState = NewType("FSMState", int)
 
 
@@ -50,3 +56,49 @@ def is_final_state(self, state: FSMState) -> bool:
             return True
         else:
             return False
+
+
+class RegexFSM:
+    def __init__(
+        self,
+        regex_string: str,
+        tokenizer: "Tokenizer",
+        max_tokens: Optional[int] = None,
+    ):
+        regex_pattern = interegular.parse_pattern(regex_string)
+        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+        (
+            self.states_to_token_maps,
+            self.empty_token_ids,
+        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
+
+        # Make sure the regex can be matched at all with this vocabulary:
+        # at least one transition must reach a final FSM state.
+        if not any(
+            regex_fsm.finals.intersection(v.values())
+            for v in self.states_to_token_maps.values()
+        ):
+            raise ValueError(
+                "The vocabulary does not allow us to build a sequence that matches the input regex"
+            )
+
+        self.max_tokens = max_tokens
+        self.num_tokens_generated = 0
+
+    def next_instructions(self, state: FSMState) -> GenerateInstruction:
+        return GenerateInstruction(list(self.states_to_token_maps[state].keys()))
+
+    def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
+        self.num_tokens_generated += 1
+        last_token_to_end_state = self.states_to_token_maps[state]
+        # Cast to a plain int: the transition maps are keyed by token ids,
+        # and a torch.Tensor does not hash equal to the int it wraps.
+        return FSMState(last_token_to_end_state[int(token_id)])
+
+    def is_final_state(self, state: FSMState) -> bool:
+        # Stop if the maximum number of tokens has been generated
+        # regardless of whether the stop token id has been found.
+        if self.max_tokens is not None:
+            if self.num_tokens_generated == self.max_tokens:
+                return True
+
+        return False
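For review purposes, here is a minimal sketch (not part of the diff) of the decoding loop `RegexFSM` is meant to drive. The `tokenizer` is a hypothetical stand-in for any object implementing the `outlines.models.tokenizer.Tokenizer` protocol, the greedy `allowed[0]` pick stands in for sampling under a logit mask, and the sketch assumes `make_deterministic_fsm` numbers the initial state 0; only `RegexFSM`, `FSMState`, and `states_to_token_maps` come from this change.

```python
# Sketch of a driver loop around RegexFSM (assumptions noted above).
from outlines.index.index import FSMState, RegexFSM

tokenizer = ...  # hypothetical: any outlines.models.tokenizer.Tokenizer implementation

fsm = RegexFSM(r"[0-9]{3}", tokenizer, max_tokens=10)

state = FSMState(0)  # assumes the deterministic FSM starts at state 0
token_ids = []
while not fsm.is_final_state(state):
    # Token ids that keep the partial sequence inside the regex language.
    allowed = list(fsm.states_to_token_maps.get(state, {}).keys())
    if not allowed:
        break  # dead end: no token can extend the match
    token_id = allowed[0]  # stand-in for real masked sampling
    token_ids.append(token_id)
    state = fsm.next_state(state, token_id)
```

Note that `is_final_state` currently returns `True` only through the `max_tokens` budget, as its comment says; reaching a final FSM state does not by itself stop generation, which is why the `if not allowed` guard above is needed to keep the sketch from walking off a state with no outgoing transitions.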