diff --git a/outlines/text/fsm.py b/outlines/text/fsm.py index f78f33e49..a96fa8e24 100644 --- a/outlines/text/fsm.py +++ b/outlines/text/fsm.py @@ -1,12 +1,10 @@ from collections import namedtuple from functools import lru_cache -from itertools import chain -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Tuple +from typing import TYPE_CHECKING, Dict, Generator, List, Sequence, Set, Tuple import numba import numpy as np from interegular.fsm import FSM, Alphabet, OblivionError, anything_else -from joblib import Parallel, delayed from numba.typed.typedobjectutils import _nonoptional if TYPE_CHECKING: @@ -149,17 +147,6 @@ def create_fsm_info( ], ) -spec = [ - numba.int64, - numba.types.Set(numba.int64), - numba.types.DictType(numba.types.UniTuple(numba.int64, 2), numba.int64), - numba.types.DictType(numba.int64, numba.types.ListType(numba.int64)), - numba.optional(numba.int64), - numba.types.DictType(numba.types.string, numba.int64), -] - -FSMInfoNumbaType = numba.types.NamedTuple(spec, FSMInfo) - def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: """Construct an equivalent FSM with deterministic state labels.""" @@ -314,123 +301,6 @@ def walk_fsm( return accepted_states -# TODO FIXME: Can't cache this due to https://github.com/numba/numba/issues/9177 -@numba.njit(nogil=True) -def find_partial_matches( - fsm_info: FSMInfo, - input_string: str, - full_match: bool = True, -) -> Generator[Tuple[int, List[int]], None, None]: - """Find the states in the finite state machine `fsm_info` that accept `input_string`. - - This will consider all possible states in the finite state machine (FSM) - that accept the beginning of `input_string` as starting points, unless a - specific `start_state` is provided. - - Parameters - ---------- - fsm_info - The finite state machine. - input_string - The string for which we generate partial matches. - full_match - Matches must cover the entire string. - - Returns - ------- - A set of tuples corresponding to each valid starting state in the FSM. The - first element of each tuple contains an integer indicating the position in - `input_string` at which the FSM stopped. The second element is the tuple - of states visited during execution of the FSM plus the next, unvisited - transition state. - - """ - - if len(input_string) == 0: - return - - trans_key = fsm_info.alphabet_symbol_mapping.get( - input_string[0], fsm_info.alphabet_anything_value - ) - - for state in fsm_info.trans_key_to_states.get( - trans_key, numba.typed.List.empty_list(numba.int64) # type: ignore - ): - path = _walk_fsm( - fsm_info.transitions, - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - fsm_info.initial, - fsm_info.finals, - input_string, - state, - full_match=full_match, - ) - if path: - path.insert(0, state) - res = (len(path) - 2, path) - yield res - - -@numba.njit(nogil=True, cache=True) -def process_token_string( - fsm_info: FSMInfo, - token: str, - token_idx: int, - final_state_string: Optional[str] = None, -) -> Set[Tuple[int, int]]: - res = set() - vocab_string_len = len(token) - - for end_idx, state_seq in find_partial_matches(fsm_info, token, full_match=False): - if end_idx is not None and end_idx < vocab_string_len - 1: - continue - - res.add((state_seq[0], token_idx)) - - if token == final_state_string: - # Allow transitions to EOS from all terminals FSM states - for state in fsm_info.finals: - res.add((state, token_idx)) - - return res - - -def create_fsm_index( - fsm_info: FSMInfo, - vocabulary: Dict[str, int], - final_state_string: Optional[str] = None, - n_jobs=-1, -) -> Dict[int, Set[int]]: - """Construct a map from FSM states to subsets of `vocabulary`. - - The subsets of `vocabulary` consist of elements that are accepted by--or - transition to--the corresponding partial parse states. - - Parameters - ---------- - fsm - The finite-state machine. - vocabulary - The vocabulary composed of token strings mapped to token IDs. - final_state_string - A string from `vocabulary` that is to be added to all the final states - in the FSM (e.g. ``""``). - """ - - results = Parallel(backend="threading", n_jobs=n_jobs, return_as="generator")( - delayed(process_token_string)(fsm_info, token, token_idx, final_state_string) - for token, token_idx in vocabulary.items() - ) - - states_to_token_subsets: Dict[int, Set[int]] = {} - - for fsm_state, token_idx in chain.from_iterable(results): - states_to_token_subsets.setdefault(fsm_state, set()).add(token_idx) - - return states_to_token_subsets - - def fsm_union( fsms: Sequence[FSM], ) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: diff --git a/tests/text/test_fsm.py b/tests/text/test_fsm.py index 10c18eecc..7091fd0b9 100644 --- a/tests/text/test_fsm.py +++ b/tests/text/test_fsm.py @@ -5,10 +5,8 @@ from outlines.models.transformers import TransformerTokenizer from outlines.text.fsm import ( _walk_fsm, - create_fsm_index, create_fsm_index_end_to_end, create_fsm_index_tokenizer, - find_partial_matches, fsm_union, get_sub_fsms_from_seq, make_deterministic_fsm, @@ -84,112 +82,6 @@ def test_walk_fsm(function): assert res == tuple() -def test_partial_match(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - assert name_fsm.initial == 0 - - name_fsm = name_fsm.fsm_info - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - assert def_fsm.initial == 0 - - def_fsm = def_fsm.fsm_info - - def to_python(res): - return {(x, tuple(y)) for x, y in res} - - res = to_python(find_partial_matches(def_fsm, "def")) - assert res == {(2, (0, 1, 2, 3))} - res = to_python(find_partial_matches(def_fsm, "de", full_match=False)) - assert res == {(1, (0, 1, 2))} - res = to_python(find_partial_matches(def_fsm, "d", full_match=False)) - assert res == {(0, (0, 1))} - res = to_python(find_partial_matches(def_fsm, "")) - assert res == set() - res = to_python(find_partial_matches(def_fsm, "df")) - assert res == set() - res = to_python(find_partial_matches(def_fsm, "ef", full_match=False)) - assert res == {(1, (1, 2, 3))} - res = to_python(find_partial_matches(def_fsm, "e", full_match=False)) - assert res == {(0, (1, 2))} - res = to_python(find_partial_matches(def_fsm, "f", full_match=False)) - assert res == {(0, (2, 3))} - res = to_python(find_partial_matches(def_fsm, "ef foo", full_match=False)) - assert res == {(1, (1, 2, 3))} - - # This string has a `DEF` token in it, but should ultimately not lex one - res = to_python(find_partial_matches(def_fsm, "defb", full_match=False)) - assert res == {(2, (0, 1, 2, 3))} - - # `NAME` can have multiple start states for this input - res = to_python(find_partial_matches(name_fsm, "d", full_match=False)) - assert res == {(0, (0, 1)), (0, (1, 1))} - # Not this case - res = to_python(find_partial_matches(name_fsm, "1d")) - assert res == {(1, (1, 1, 1))} - - res = to_python(find_partial_matches(name_fsm, "blah")) - assert res == { - (3, (0, 1, 1, 1, 1)), - (3, (1, 1, 1, 1, 1)), - } - - float_pattern = interegular.parse_pattern( - r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))" - ) - float_fsm, _ = make_deterministic_fsm(float_pattern.to_fsm().reduce()) - assert 5 in float_fsm.finals - assert 2 not in float_fsm.finals - - float_fsm = float_fsm.fsm_info - - res = to_python(find_partial_matches(float_fsm, ".", full_match=False)) - assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))} - - joins_fsm, _ = make_deterministic_fsm( - interegular.parse_pattern(r"(JOIN LEFT|JOIN)").to_fsm().reduce() - ) - - joins_fsm = joins_fsm.fsm_info - - res = to_python(find_partial_matches(joins_fsm, "JOIN BLAH", full_match=False)) - assert res == {(3, (0, 1, 2, 3, 4))} - - res = to_python(find_partial_matches(joins_fsm, "JOIN L", full_match=False)) - assert res == {(5, (0, 1, 2, 3, 4, 5, 6))} - - res = to_python(find_partial_matches(joins_fsm, "JOI", full_match=False)) - assert res == {(2, (0, 1, 2, 3))} - - regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - # State `1` has no transitions - assert not regex_fsm.map[1] - - res = to_python(find_partial_matches(regex_fsm.fsm_info, "0", numba.int64(1))) - assert res == {(0, (0, 1))} - - -def test_create_fsm_index(): - regex_str = "0|[1-9][0-9]*" - - regex_pattern = interegular.parse_pattern(regex_str) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - - vocabulary = {"blah": 0, "1a": 1, "2": 2, "0": 3, "": 4} - - res = create_fsm_index(regex_fsm.fsm_info, vocabulary) - - assert res == {0: {2, 3}, 2: {2, 3}} - - res = create_fsm_index(regex_fsm.fsm_info, vocabulary, "") - - assert res == {0: {2, 3}, 1: {4}, 2: {2, 3, 4}} - - def test_get_sub_fsms_from_seq(): name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) @@ -329,18 +221,18 @@ def test_get_sub_fsms_from_seq(): ] fsm, fsms_to_trans_finals = fsm_union(join_fsms) - ((_, state_seq),) = find_partial_matches(fsm.fsm_info, "OI", full_match=False) - + # Matching "OI" + state_seq = [1, 2, 3] res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, True, False), (1, True, False)] - ((_, state_seq),) = find_partial_matches(fsm.fsm_info, "N", full_match=False) - + # Matching "N" + state_seq = [3, 4] res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(0, False, True), (1, True, False)] - ((_, state_seq),) = find_partial_matches(fsm.fsm_info, " ", full_match=False) - + # Matching " " + state_seq = [4, 5] res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) assert res == [(1, True, False)]