From d47bd6b2f19588e1e74e41e189b5f973d807d0f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Tue, 12 Mar 2024 17:09:56 +0100
Subject: [PATCH] Restore FSM interface for backward compatibility

---
 outlines/fsm/fsm.py   |  69 +++++++++
 tests/fsm/test_fsm.py | 347 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 416 insertions(+)
 create mode 100644 outlines/fsm/fsm.py
 create mode 100644 tests/fsm/test_fsm.py

diff --git a/outlines/fsm/fsm.py b/outlines/fsm/fsm.py
new file mode 100644
index 000000000..d0340a1ad
--- /dev/null
+++ b/outlines/fsm/fsm.py
@@ -0,0 +1,69 @@
+import warnings
+from typing import TYPE_CHECKING, List, NewType
+
+from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide
+
+if TYPE_CHECKING:
+    from outlines.models.tokenizer import Tokenizer
+
+FSMState = NewType("FSMState", int)
+
+
+class StopAtEosFSM(StopAtEOSGuide):
+    """FSM to generate text until EOS has been generated."""
+
+    def __init__(self, tokenizer: "Tokenizer"):
+        warnings.warn(
+            UserWarning(
+                "The `StopAtEosFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead."
+            )
+        )
+        super().__init__(tokenizer)
+
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
+        next_instruction = self.get_next_instruction(state)
+        return next_instruction.tokens
+
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
+        return FSMState(self.get_next_state(state, token_id))
+
+
+class RegexFSM(RegexGuide):
+    """FSM to generate text that is in the language of a regular expression."""
+
+    def __init__(self, regex_string: str, tokenizer):
+        warnings.warn(
+            UserWarning(
+                "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead."
+            )
+        )
+        super().__init__(regex_string, tokenizer)
+
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
+        next_instruction = self.get_next_instruction(state)
+        return next_instruction.tokens
+
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
+        return FSMState(self.get_next_state(state, token_id))
+
+
+class CFGFSM(CFGGuide):
+    """FSM to generate text that is in the language of a context-free grammar."""
+
+    def __init__(self, cfg_string: str, tokenizer):
+        warnings.warn(
+            UserWarning(
+                "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead."
+            )
+        )
+        super().__init__(cfg_string, tokenizer)
+
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
+        return self.get_next_instruction(state).tokens
+
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
+        return FSMState(self.get_next_state(state, token_id))
+
+    def copy(self) -> "CFGFSM":
+        """Create a copy of the FSM."""
+        return CFGFSM(self.cfg_string, self.tokenizer)
diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py
new file mode 100644
index 000000000..8ce17c6eb
--- /dev/null
+++ b/tests/fsm/test_fsm.py
@@ -0,0 +1,347 @@
+import pytest
+
+from outlines.fsm.fsm import CFGFSM, RegexFSM, StopAtEosFSM
+
+
+def test_stop_at_eos():
+    class MockTokenizer:
+        vocabulary = {"a": 1, "eos": 2}
+        eos_token_id = 2
+
+    with pytest.warns(UserWarning):
+        fsm = StopAtEosFSM(MockTokenizer())
+
+    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
+    assert fsm.allowed_token_ids(fsm.final_state) == [2]
+    assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
+    assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
+    assert fsm.is_final_state(fsm.start_state) is False
+    assert fsm.is_final_state(fsm.final_state) is True
+
+
+def test_regex_vocabulary_error():
+    class MockTokenizer:
+        vocabulary = {"a": 1}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = "[1-9]"
+
+    with pytest.raises(ValueError, match="The vocabulary"):
+        RegexFSM(regex_str, MockTokenizer())
+
+
+def test_regex():
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = "[1-9]"
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = RegexFSM(regex_str, tokenizer)
+
+    assert fsm.states_to_token_maps == {0: {1: 1}}
+    assert fsm.allowed_token_ids(state=0) == [1]
+    assert fsm.next_state(state=0, token_id=1) == 1
+    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1
+
+    assert fsm.is_final_state(0) is False
+
+    for state in fsm.final_states:
+        assert fsm.is_final_state(state) is True
+
+
+def test_regex_final_state():
+    """Make sure that the FSM stays in the final state as we keep generating"""
+
+    class MockTokenizer:
+        vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
+        special_tokens = {"eos"}
+        eos_token_id = 104
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = r"`\n(\.\n)?`\n"
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = RegexFSM(regex_str, tokenizer)
+
+    state = fsm.next_state(state=4, token_id=103)
+    assert state == 5
+    assert fsm.is_final_state(state)
+
+    state = fsm.next_state(state=5, token_id=103)
+    assert state == 5
+
+    assert fsm.is_final_state(-1)
+
+
+def test_cfg():
+    class MockTokenizer:
+        vocabulary = {"{": 1, "}": 2, "[": 3, "]": 4, "eos": 5}
+        special_tokens = {"eos"}
+        eos_token = "eos"
+        eos_token_id = 5
+
+        def convert_token_to_string(self, token):
+            return token
+
+        @property
+        def inverse_vocabulary(self):
+            return {v: k for k, v in self.vocabulary.items()}
+
+        def decode(self, token_ids):
+            return [self.inverse_vocabulary[t] for t in token_ids]
+
+    cfg_str = """
+    start: expr
+    expr: "{" expr "}" | "[" expr "]" |
+    """
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = CFGFSM(cfg_str, tokenizer)
+
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3, 5}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
+    assert fsm.generation == "{"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    state = fsm.next_state(state=state, token_id=3)
+    assert fsm.generation == "{["
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 3, 4}
+    state = fsm.next_state(state=state, token_id=4)
+    assert fsm.generation == "{[]"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {2}
+    state = fsm.next_state(state=state, token_id=2)
+    assert fsm.generation == "{[]}"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {5}
+    state = fsm.next_state(state=state, token_id=5)
+    assert fsm.generation == "{[]}"
+    assert fsm.is_final_state(state)
+
+
+def test_cfg_early_termination():
+    class MockTokenizer:
+        vocabulary = {"(": 1, ")": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token = "eos"
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+        @property
+        def inverse_vocabulary(self):
+            return {v: k for k, v in self.vocabulary.items()}
+
+        def decode(self, token_ids):
+            return [self.inverse_vocabulary[t] for t in token_ids]
+
+    cfg_str = """
+    start: expr+
+    expr: "(" subexpr ")"
+    subexpr: expr |
+    """
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = CFGFSM(cfg_str, tokenizer)
+
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
+    assert fsm.generation == "("
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2}
+    state = fsm.next_state(state=state, token_id=2)
+    assert fsm.generation == "()"
+    assert not fsm.is_final_state(state)
+
+    # possible to continue or terminate
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
+    state = fsm.next_state(state=state, token_id=3)  # feed eos
+    assert fsm.generation == "()"
+    assert fsm.is_final_state(state)
+
+    # once eos generated, can only terminate
+    assert set(fsm.allowed_token_ids(state=state)) == {3}
+
+
+def test_cfg_ignore_directive():
+    class MockTokenizer:
+        vocabulary = {"a": 1, " ": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token = "eos"
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+        @property
+        def inverse_vocabulary(self):
+            return {v: k for k, v in self.vocabulary.items()}
+
+        def decode(self, token_ids):
+            return [self.inverse_vocabulary[t] for t in token_ids]
+
+    cfg_str = """
+    start: LETTER+
+    LETTER: "a"
+    WS: " "
+    %ignore WS
+    """
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = CFGFSM(cfg_str, tokenizer)
+
+    state = 0
+
+    assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
+    state = fsm.next_state(state=0, token_id=2)
+    assert fsm.generation == " "
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=0)) == {1, 2}
+    state = fsm.next_state(state=0, token_id=1)
+    assert fsm.generation == " a"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    state = fsm.next_state(state=state, token_id=2)
+    assert fsm.generation == " a "
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    state = fsm.next_state(state=state, token_id=2)
+    assert fsm.generation == " a  "
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    state = fsm.next_state(state=state, token_id=1)
+    assert fsm.generation == " a  a"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 2, 3}
+    state = fsm.next_state(state=state, token_id=3)
+    assert fsm.generation == " a  a"
+    assert fsm.is_final_state(state)
+
+    # once eos generated, can only terminate
+    assert set(fsm.allowed_token_ids(state=state)) == {3}
+
+
+def test_cfg_multitoken_terminal():
+    class MockTokenizer:
+        vocabulary = {"a": 1, "b": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token = "eos"
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+        @property
+        def inverse_vocabulary(self):
+            return {v: k for k, v in self.vocabulary.items()}
+
+        def decode(self, token_ids):
+            return [self.inverse_vocabulary[t] for t in token_ids]
+
+    cfg_str = """
+    start: S
+    S: "aa" | "bb"
+    """
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = CFGFSM(cfg_str, tokenizer)
+
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 2}
+    assert fsm.reset_state  # starting new regex
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
+    assert fsm.generation == "a"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1}
+    assert not fsm.reset_state  # continuing current regex
+    state = fsm.next_state(state=state, token_id=1)
+    assert fsm.generation == "aa"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {3}
+    assert not fsm.reset_state  # completing current regex
+    state = fsm.next_state(state=state, token_id=3)
+    assert fsm.generation == "aa"
+    assert fsm.is_final_state(state)
+
+
+def test_cfg_allow_both_extend_and_shift_terminal():
+    class MockTokenizer:
+        vocabulary = {"(": 1, ")": 2, "a": 3, "eos": 4}
+        special_tokens = {"eos"}
+        eos_token = "eos"
+        eos_token_id = 4
+
+        def convert_token_to_string(self, token):
+            return token
+
+        @property
+        def inverse_vocabulary(self):
+            return {v: k for k, v in self.vocabulary.items()}
+
+        def decode(self, token_ids):
+            return [self.inverse_vocabulary[t] for t in token_ids]
+
+    cfg_str = """
+    start: s
+    s: "(" s ")" | /a+/
+    """
+    tokenizer = MockTokenizer()
+
+    with pytest.warns(UserWarning):
+        fsm = CFGFSM(cfg_str, tokenizer)
+
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
+    assert fsm.generation == "("
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {1, 3}
+    state = fsm.next_state(state=state, token_id=3)
+    assert fsm.generation == "(a"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
+    state = fsm.next_state(state=state, token_id=3)
+    assert fsm.generation == "(aa"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {2, 3}
+    state = fsm.next_state(state=state, token_id=2)
+    assert fsm.generation == "(aa)"
+    assert not fsm.is_final_state(state)
+
+    assert set(fsm.allowed_token_ids(state=state)) == {4}
+    state = fsm.next_state(state=state, token_id=4)
+    assert fsm.generation == "(aa)"
+    assert fsm.is_final_state(state)
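
The classes added in outlines/fsm/fsm.py are thin shims over the new Guide API:
`allowed_token_ids(state)` returns `get_next_instruction(state).tokens`, and
`next_state(state, token_id)` wraps `get_next_state(state, token_id)`. The sketch
below shows a call site migrating from the deprecated interface to the replacement,
under those assumptions. It reuses the MockTokenizer fixture from test_regex above;
it is illustrative only, not part of the patch.

    # migration_sketch.py -- illustrative, not part of this patch
    import warnings

    from outlines.fsm.fsm import RegexFSM      # deprecated shim
    from outlines.fsm.guide import RegexGuide  # replacement


    class MockTokenizer:
        # Same fixture as test_regex above.
        vocabulary = {"1": 1, "a": 2, "eos": 3}
        special_tokens = {"eos"}
        eos_token_id = 3

        def convert_token_to_string(self, token):
            return token


    tokenizer = MockTokenizer()

    # Old interface: instantiating the shim emits the deprecation UserWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fsm = RegexFSM("[1-9]", tokenizer)
    assert any(issubclass(w.category, UserWarning) for w in caught)

    assert fsm.allowed_token_ids(state=0) == [1]
    assert fsm.next_state(state=0, token_id=1) == 1

    # New interface: the same two steps, renamed.
    guide = RegexGuide("[1-9]", tokenizer)
    assert guide.get_next_instruction(0).tokens == [1]
    assert guide.get_next_state(0, token_id=1) == 1

The values asserted here match test_regex, so the sketch doubles as a check that
the shim and the Guide agree on the same inputs.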