Skip to content

Commit

Permalink
[draft] prepare outlines for outlines-core v0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 22, 2025
1 parent 063291d commit ae2a8a9
Showing 1 changed file with 75 additions and 45 deletions.
120 changes: 75 additions & 45 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import collections
import copy
import warnings
from typing import TYPE_CHECKING, Any, Generator, Union
from typing import TYPE_CHECKING, Any, Generator, Optional, Union

import torch
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken
from outlines_core.fsm.guide import Generate
from outlines_core.fsm.guide import Guide as CoreGuide
from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
from outlines_core.fsm.guide import Write
from outlines_core.fsm.guide import (
create_states_mapping as uncached_create_states_mapping,
)
from outlines_core.json_schema import build_regex_from_schema

from outlines import grammars
from outlines.fsm.parsing import PartialLark, PartialParserState
Expand All @@ -21,9 +14,6 @@
from outlines.models.tokenizer import Tokenizer


Instruction = Union[Write, Generate]


class Guide(CoreGuide):
"""Base definition of a generation guide.
Expand Down Expand Up @@ -54,10 +44,10 @@ def __init__(self, tokenizer: "Tokenizer"):
self.eos_token_id = tokenizer.eos_token_id
self.vocabulary = tokenizer.vocabulary.values()

def get_next_instruction(self, state: int) -> Instruction:
def get_next_instruction(self, state: int) -> Union[None, int]:
if self.is_final_state(state):
return Write([self.eos_token_id])
return Generate(None)
return self.eos_token_id
return None

def get_next_state(self, state: int, token_id: int) -> int:
if token_id == self.eos_token_id or state == self.final_state:
Expand All @@ -72,32 +62,12 @@ def copy(self):
return self


def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)


class RegexGuide(CoreRegexGuide):
"""
Guide to generate text in the language of a regular expression.
CoreRegexGuide with outlines cache
"""

@classmethod
def from_regex(
cls,
regex_string: str,
tokenizer,
**kwargs,
class CFGState:
def __init__(
self, parser_state: Optional[PartialParserState], prev_token: Union[None, int]
):
return super().from_regex(
regex_string,
tokenizer,
_create_states_mapping=cached_create_states_mapping,
**kwargs,
)


CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
self.parser_state = parser_state
self.prev_token = prev_token


class CFGGuide(Guide):
Expand All @@ -124,7 +94,7 @@ def __init__(self, cfg_string: str, tokenizer):
parser_state=self.parser.parse(""), prev_token=None
)

def get_next_instruction(self, state: CFGState) -> Instruction:
def get_next_instruction(self, state: CFGState) -> Union[None, int]:
"""Return the next instruction for guided generation.
Current lazy approach:
Expand All @@ -147,14 +117,14 @@ def get_next_instruction(self, state: CFGState) -> Instruction:
"""

if state.parser_state is None:
return Write(torch.tensor([self.eos_token_id]))
return self.eos_token_id

valid_tokens = list(
self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
)
if len(valid_tokens) == 1:
return Write(torch.tensor(valid_tokens))
return Generate(torch.tensor(valid_tokens))
return valid_tokens[0]
return None

def iter_valid_token_ids(
self, state: CFGState, candidate_token_ids: list
Expand Down Expand Up @@ -220,7 +190,7 @@ def get_next_state(self, state: CFGState, token_id: int) -> CFGState:

def _get_parser_state_token_applied(
self, state: CFGState, token_id: int
) -> PartialParserState:
) -> Optional[PartialParserState]:
"""
Don't mutate `parser_state`, copy to protect
Expand All @@ -230,6 +200,9 @@ def _get_parser_state_token_applied(
Don't allow empty ("") tokens, raise ValueError
"""
if state.parser_state is None:
return None

parser_state = copy.copy(state.parser_state) # prevent side effects

# normalize
Expand Down Expand Up @@ -274,3 +247,60 @@ def must_terminate_state(self, state: CFGState) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the Guide."""
return CFGGuide(self.cfg_string, self.tokenizer)


class Write:
    """Instruction carrying text to be emitted verbatim.

    NOTE(review): semantics inferred from the name — this appears to be a
    local stand-in for the ``Write`` instruction formerly imported from
    ``outlines_core.fsm.guide``; confirm against callers.
    """

    def __init__(self, text: str):
        # The literal text payload; stored as-is.
        self.text = text

    def __repr__(self) -> str:
        # Debug-friendly representation (not relied on by callers).
        return f"{type(self).__name__}(text={self.text!r})"


class Generate:
    """Instruction describing constrained generation.

    NOTE(review): semantics inferred from the name — appears to be a local
    stand-in for the ``Generate`` instruction formerly imported from
    ``outlines_core.fsm.guide``; confirm the ``model`` field is actually
    consumed by callers.
    """

    def __init__(self, model, allowed_token_ids):
        # Model handle (opaque here; shape/type not visible from this file).
        self.model = model
        # Token ids permitted at this step; ``None``/empty semantics are
        # defined by the consumer.
        self.allowed_token_ids = allowed_token_ids

    def __repr__(self) -> str:
        # Debug-friendly representation (not relied on by callers).
        return (
            f"{type(self).__name__}(model={self.model!r}, "
            f"allowed_token_ids={self.allowed_token_ids!r})"
        )


def create_states_mapping(fsm, tokenizer):
    """Build a per-state summary table for *fsm*.

    Parameters
    ----------
    fsm
        Object exposing ``states`` (iterable), ``transitions`` (mapping of
        state -> {token_id: next_state}) and ``is_final(state)``.
    tokenizer
        Unused here; kept so the signature matches existing callers.

    Returns
    -------
    dict
        Maps each state to ``{"allowed_token_ids": [...], "is_final": bool}``.
    """
    return {
        current: {
            "allowed_token_ids": list(fsm.transitions[current].keys()),
            "is_final": fsm.is_final(current),
        }
        for current in fsm.states
    }


def create_states_mapping_from_fsm(fsm, tokenizer):
    """Alias of ``create_states_mapping`` kept for API compatibility.

    The previous body was a byte-for-byte duplicate of
    ``create_states_mapping``; delegating instead guarantees the two
    functions cannot drift apart.

    Parameters and return value are identical to ``create_states_mapping``:
    a dict mapping each fsm state to
    ``{"allowed_token_ids": [...], "is_final": bool}``.
    """
    return create_states_mapping(fsm, tokenizer)


class RegexGuide(Guide):
    """Guide that constrains generation according to a regular expression.

    NOTE(review): ``build_regex_from_schema`` (from
    ``outlines_core.json_schema``) is, per its name, expected to return a
    regex *string* built from a JSON schema — yet ``self.fsm`` is used below
    as an object with ``.transitions``, ``.final_state`` and ``.is_final``.
    Verify the actual return type; this looks like it may need an explicit
    regex -> FSM compilation step.
    """

    def __init__(self, regex, tokenizer):
        # Kept so ``copy()`` can reconstruct an equivalent guide.
        self.regex = regex
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_token_id
        # NOTE(review): see class docstring — confirm this yields an FSM
        # object rather than a plain regex string.
        self.fsm = build_regex_from_schema(regex)
        self.states_mapping = create_states_mapping_from_fsm(self.fsm, tokenizer)
        # FSM start state is assumed to be 0 — TODO confirm.
        self.initial_state = 0

    def get_next_instruction(self, state):
        # Final state: force EOS; otherwise no constraint is signalled.
        if self.is_final_state(state):
            return self.eos_token_id
        return None

    def get_next_state(self, state, token_id):
        # EOS (or an already-final state) terminates in the FSM's final state.
        if token_id == self.eos_token_id or self.is_final_state(state):
            return self.fsm.final_state
        # NOTE(review): unknown tokens fall back to the *initial* state —
        # confirm this reset-on-miss behaviour is intended rather than an
        # error or a dead state.
        return self.fsm.transitions[state].get(token_id, self.initial_state)

    def is_final_state(self, state):
        # Delegates finality entirely to the underlying FSM.
        return self.fsm.is_final(state)

    def copy(self):
        # Fresh, independent guide rebuilt from the same inputs.
        return RegexGuide(self.regex, self.tokenizer)

0 comments on commit ae2a8a9

Please sign in to comment.