Create FSM that stops generation when token found
rlouf committed Nov 23, 2023
1 parent cafc075 commit 6982349
Showing 1 changed file with 33 additions and 14 deletions.
outlines/index/index.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import NewType, Protocol, Union
+from typing import List, NewType, Optional, Protocol
 
 import torch
 
@@ -8,26 +8,45 @@
 
 @dataclass(frozen=True)
 class GenerateInstruction:
-    logits_mask: str
-    temperature: float
-    top_k: int
-    top_p: int
-
-
-@dataclass(frozen=True)
-class FillInstruction:
-    token_ids: int
-
-
-FSMInstruction = Union[GenerateInstruction, FillInstruction]
+    tokens_to_mask: List[int]
 
 
 class FSM(Protocol):
-    def next_instruction(self, state: FSMState) -> FSMInstruction:
+    def next_instruction(self, state: FSMState) -> GenerateInstruction:
         ...
 
     def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
         ...
 
     def is_final_state(self, state: FSMState) -> bool:
         ...
+
+
+class StopAtTokenFSM:
+    def __init__(self, stop_token_id: int, max_tokens: Optional[int] = None):
+        self.stop_token_id = stop_token_id
+        self.max_tokens = max_tokens
+        self.num_tokens_generated = 0
+
+    def next_instruction(self, _: FSMState) -> GenerateInstruction:
+        return GenerateInstruction([])
+
+    def next_state(self, state: FSMState, token_id: torch.Tensor) -> FSMState:
+        self.num_tokens_generated += 1
+
+        if token_id == self.stop_token_id:
+            return FSMState(1)
+        else:
+            return FSMState(0)
+
+    def is_final_state(self, state: FSMState) -> bool:
+        # Stop if the maximum number of tokens has been generated
+        # regardless of whether the stop token id has been found.
+        if self.max_tokens is not None:
+            if self.num_tokens_generated == self.max_tokens:
+                return True
+
+        if state == 1:
+            return True
+        else:
+            return False
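
For orientation, below is a minimal sketch of how StopAtTokenFSM could drive a token-by-token sampling loop. Only StopAtTokenFSM and FSMState come from the file changed above; sample_next_token is a hypothetical stand-in for a model call, the stop token id and max_tokens values are arbitrary, and the import path simply mirrors the location of outlines/index/index.py at this commit.

import torch

from outlines.index.index import FSMState, StopAtTokenFSM


def sample_next_token(generated):
    # Hypothetical stand-in for a model forward pass plus sampling;
    # it just draws a random token id from a tiny vocabulary.
    return torch.randint(0, 10, (1,))


stop_token_id = 2  # assumed stop token id, e.g. the tokenizer's EOS id
fsm = StopAtTokenFSM(stop_token_id=stop_token_id, max_tokens=16)

state = FSMState(0)  # 0: keep generating, 1: the stop token was found
generated = []
while not fsm.is_final_state(state):
    instruction = fsm.next_instruction(state)  # tokens_to_mask is empty for this FSM
    token_id = sample_next_token(generated)
    generated.append(int(token_id))
    state = fsm.next_state(state, token_id)

Since StopAtTokenFSM never masks any token, the returned instruction carries an empty tokens_to_mask list; the FSM only decides when generation ends, either because the stop token was sampled or because max_tokens was reached.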
