Skip to content

Commit

Permalink
WIP - Create Regex FSM
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 23, 2023
1 parent 6982349 commit be12c16
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion outlines/index/index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from dataclasses import dataclass
from typing import List, NewType, Optional, Protocol
from typing import TYPE_CHECKING, List, NewType, Optional, Protocol

import interegular
import torch

from outlines.index.fsm import create_fsm_index_tokenizer, make_deterministic_fsm

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer

FSMState = NewType("FSMState", int)


Expand Down Expand Up @@ -50,3 +56,46 @@ def is_final_state(self, state: FSMState) -> bool:
return True
else:
return False


class RegexFSM:
    """Finite-state machine that constrains token generation to a regex.

    Compiles ``regex_string`` into a deterministic FSM and precomputes, for
    every FSM state, the map from allowed token ids to successor states.
    During generation the FSM is queried for the allowed tokens at the
    current state and advanced one token at a time.
    """

    def __init__(
        self,
        regex_string: str,
        tokenizer: "Tokenizer",
        max_tokens: Optional[int] = None,
    ):
        """Build the token-level index for ``regex_string``.

        Parameters
        ----------
        regex_string
            The regular expression generated sequences must match.
        tokenizer
            The model's tokenizer, used to map FSM transitions to token ids.
        max_tokens
            Optional hard cap on the number of generated tokens; ``None``
            means no cap.

        Raises
        ------
        ValueError
            If no sequence of vocabulary tokens can reach a final FSM state,
            i.e. the vocabulary cannot produce a match for the regex.
        """
        regex_pattern = interegular.parse_pattern(regex_string)
        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
        (
            self.states_to_token_maps,
            self.empty_token_ids,
        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

        # Each value of states_to_token_maps is a {token_id: next_state}
        # dict; the regex is satisfiable by the vocabulary only if some
        # transition lands on a final FSM state.
        if not any(
            regex_fsm.finals.intersection(v.values())
            for v in self.states_to_token_maps.values()
        ):
            raise ValueError(
                "The vocabulary does not allow us to build a sequence that matches the input regex"
            )

        self.max_tokens = max_tokens
        self.num_tokens_generated = 0

    def next_instructions(self, state: FSMState) -> GenerateInstruction:
        """Return the instruction listing the token ids allowed at `state`."""
        return GenerateInstruction(list(self.states_to_token_maps[state].keys()))

    def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
        """Advance the FSM by one token and return the new state.

        Also increments the generated-token counter used by
        :meth:`is_final_state` to enforce ``max_tokens``.
        """
        self.num_tokens_generated += 1
        last_token_to_end_state = self.states_to_token_maps[state]
        # The transition maps are keyed by plain ints; a torch.Tensor hashes
        # by identity, so indexing with the raw tensor would always raise
        # KeyError. Convert the scalar tensor to int first.
        return FSMState(last_token_to_end_state[int(token_id)])

    def is_final_state(self, state: FSMState) -> bool:
        """Return True when generation must stop.

        Stops once the maximum number of tokens has been generated,
        regardless of whether a stop token has been found.
        """
        if self.max_tokens is not None:
            # Use >= rather than == so the FSM still terminates even if the
            # counter ever overshoots the limit.
            if self.num_tokens_generated >= self.max_tokens:
                return True

        return False

0 comments on commit be12c16

Please sign in to comment.