Skip to content

Commit

Permalink
Remove unused FSM code
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and rlouf committed Dec 7, 2023
1 parent 93fd91a commit 430ee4d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 245 deletions.
132 changes: 1 addition & 131 deletions outlines/text/fsm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from collections import namedtuple
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Dict, Generator, List, Sequence, Set, Tuple

import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from joblib import Parallel, delayed
from numba.typed.typedobjectutils import _nonoptional

if TYPE_CHECKING:
Expand Down Expand Up @@ -149,17 +147,6 @@ def create_fsm_info(
],
)

spec = [
numba.int64,
numba.types.Set(numba.int64),
numba.types.DictType(numba.types.UniTuple(numba.int64, 2), numba.int64),
numba.types.DictType(numba.int64, numba.types.ListType(numba.int64)),
numba.optional(numba.int64),
numba.types.DictType(numba.types.string, numba.int64),
]

FSMInfoNumbaType = numba.types.NamedTuple(spec, FSMInfo)


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
Expand Down Expand Up @@ -314,123 +301,6 @@ def walk_fsm(
return accepted_states


# TODO FIXME: Can't cache this due to https://github.com/numba/numba/issues/9177
@numba.njit(nogil=True)
def find_partial_matches(
fsm_info: FSMInfo,
input_string: str,
full_match: bool = True,
) -> Generator[Tuple[int, List[int]], None, None]:
"""Find the states in the finite state machine `fsm_info` that accept `input_string`.
This will consider all possible states in the finite state machine (FSM)
that accept the beginning of `input_string` as starting points, unless a
specific `start_state` is provided.
Parameters
----------
fsm_info
The finite state machine.
input_string
The string for which we generate partial matches.
full_match
Matches must cover the entire string.
Returns
-------
A set of tuples corresponding to each valid starting state in the FSM. The
first element of each tuple contains an integer indicating the position in
`input_string` at which the FSM stopped. The second element is the tuple
of states visited during execution of the FSM plus the next, unvisited
transition state.
"""

if len(input_string) == 0:
return

trans_key = fsm_info.alphabet_symbol_mapping.get(
input_string[0], fsm_info.alphabet_anything_value
)

for state in fsm_info.trans_key_to_states.get(
trans_key, numba.typed.List.empty_list(numba.int64) # type: ignore
):
path = _walk_fsm(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
input_string,
state,
full_match=full_match,
)
if path:
path.insert(0, state)
res = (len(path) - 2, path)
yield res


@numba.njit(nogil=True, cache=True)
def process_token_string(
fsm_info: FSMInfo,
token: str,
token_idx: int,
final_state_string: Optional[str] = None,
) -> Set[Tuple[int, int]]:
res = set()
vocab_string_len = len(token)

for end_idx, state_seq in find_partial_matches(fsm_info, token, full_match=False):
if end_idx is not None and end_idx < vocab_string_len - 1:
continue

res.add((state_seq[0], token_idx))

if token == final_state_string:
# Allow transitions to EOS from all terminal FSM states
for state in fsm_info.finals:
res.add((state, token_idx))

return res


def create_fsm_index(
fsm_info: FSMInfo,
vocabulary: Dict[str, int],
final_state_string: Optional[str] = None,
n_jobs=-1,
) -> Dict[int, Set[int]]:
"""Construct a map from FSM states to subsets of `vocabulary`.
The subsets of `vocabulary` consist of elements that are accepted by--or
transition to--the corresponding partial parse states.
Parameters
----------
fsm
The finite-state machine.
vocabulary
The vocabulary composed of token strings mapped to token IDs.
final_state_string
A string from `vocabulary` that is to be added to all the final states
in the FSM (e.g. ``"<EOS>"``).
"""

results = Parallel(backend="threading", n_jobs=n_jobs, return_as="generator")(
delayed(process_token_string)(fsm_info, token, token_idx, final_state_string)
for token, token_idx in vocabulary.items()
)

states_to_token_subsets: Dict[int, Set[int]] = {}

for fsm_state, token_idx in chain.from_iterable(results):
states_to_token_subsets.setdefault(fsm_state, set()).add(token_idx)

return states_to_token_subsets


def fsm_union(
fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
Expand Down
120 changes: 6 additions & 114 deletions tests/text/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from outlines.models.transformers import TransformerTokenizer
from outlines.text.fsm import (
_walk_fsm,
create_fsm_index,
create_fsm_index_end_to_end,
create_fsm_index_tokenizer,
find_partial_matches,
fsm_union,
get_sub_fsms_from_seq,
make_deterministic_fsm,
Expand Down Expand Up @@ -84,112 +82,6 @@ def test_walk_fsm(function):
assert res == tuple()


def test_partial_match():
name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce())
assert name_fsm.initial == 0

name_fsm = name_fsm.fsm_info

def_pattern = interegular.parse_pattern("def")
def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce())
assert def_fsm.initial == 0

def_fsm = def_fsm.fsm_info

def to_python(res):
return {(x, tuple(y)) for x, y in res}

res = to_python(find_partial_matches(def_fsm, "def"))
assert res == {(2, (0, 1, 2, 3))}
res = to_python(find_partial_matches(def_fsm, "de", full_match=False))
assert res == {(1, (0, 1, 2))}
res = to_python(find_partial_matches(def_fsm, "d", full_match=False))
assert res == {(0, (0, 1))}
res = to_python(find_partial_matches(def_fsm, ""))
assert res == set()
res = to_python(find_partial_matches(def_fsm, "df"))
assert res == set()
res = to_python(find_partial_matches(def_fsm, "ef", full_match=False))
assert res == {(1, (1, 2, 3))}
res = to_python(find_partial_matches(def_fsm, "e", full_match=False))
assert res == {(0, (1, 2))}
res = to_python(find_partial_matches(def_fsm, "f", full_match=False))
assert res == {(0, (2, 3))}
res = to_python(find_partial_matches(def_fsm, "ef foo", full_match=False))
assert res == {(1, (1, 2, 3))}

# This string has a `DEF` token in it, but should ultimately not lex one
res = to_python(find_partial_matches(def_fsm, "defb", full_match=False))
assert res == {(2, (0, 1, 2, 3))}

# `NAME` can have multiple start states for this input
res = to_python(find_partial_matches(name_fsm, "d", full_match=False))
assert res == {(0, (0, 1)), (0, (1, 1))}
# Not this case
res = to_python(find_partial_matches(name_fsm, "1d"))
assert res == {(1, (1, 1, 1))}

res = to_python(find_partial_matches(name_fsm, "blah"))
assert res == {
(3, (0, 1, 1, 1, 1)),
(3, (1, 1, 1, 1, 1)),
}

float_pattern = interegular.parse_pattern(
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))"
)
float_fsm, _ = make_deterministic_fsm(float_pattern.to_fsm().reduce())
assert 5 in float_fsm.finals
assert 2 not in float_fsm.finals

float_fsm = float_fsm.fsm_info

res = to_python(find_partial_matches(float_fsm, ".", full_match=False))
assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))}

joins_fsm, _ = make_deterministic_fsm(
interegular.parse_pattern(r"(JOIN LEFT|JOIN)").to_fsm().reduce()
)

joins_fsm = joins_fsm.fsm_info

res = to_python(find_partial_matches(joins_fsm, "JOIN BLAH", full_match=False))
assert res == {(3, (0, 1, 2, 3, 4))}

res = to_python(find_partial_matches(joins_fsm, "JOIN L", full_match=False))
assert res == {(5, (0, 1, 2, 3, 4, 5, 6))}

res = to_python(find_partial_matches(joins_fsm, "JOI", full_match=False))
assert res == {(2, (0, 1, 2, 3))}

regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*")
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

# State `1` has no transitions
assert not regex_fsm.map[1]

res = to_python(find_partial_matches(regex_fsm.fsm_info, "0", numba.int64(1)))
assert res == {(0, (0, 1))}


def test_create_fsm_index():
regex_str = "0|[1-9][0-9]*"

regex_pattern = interegular.parse_pattern(regex_str)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

vocabulary = {"blah": 0, "1a": 1, "2": 2, "0": 3, "<EOS>": 4}

res = create_fsm_index(regex_fsm.fsm_info, vocabulary)

assert res == {0: {2, 3}, 2: {2, 3}}

res = create_fsm_index(regex_fsm.fsm_info, vocabulary, "<EOS>")

assert res == {0: {2, 3}, 1: {4}, 2: {2, 3, 4}}


def test_get_sub_fsms_from_seq():
name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce())
Expand Down Expand Up @@ -329,18 +221,18 @@ def test_get_sub_fsms_from_seq():
]
fsm, fsms_to_trans_finals = fsm_union(join_fsms)

((_, state_seq),) = find_partial_matches(fsm.fsm_info, "OI", full_match=False)

# Matching "OI"
state_seq = [1, 2, 3]
res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
assert res == [(0, True, False), (1, True, False)]

((_, state_seq),) = find_partial_matches(fsm.fsm_info, "N", full_match=False)

# Matching "N"
state_seq = [3, 4]
res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
assert res == [(0, False, True), (1, True, False)]

((_, state_seq),) = find_partial_matches(fsm.fsm_info, " ", full_match=False)

# Matching " "
state_seq = [4, 5]
res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
assert res == [(1, True, False)]

Expand Down

0 comments on commit 430ee4d

Please sign in to comment.