Skip to content

Commit

Permalink
fixed types
Browse files Browse the repository at this point in the history
  • Loading branch information
vegaluisjose committed Jan 26, 2024
1 parent 4d9d51f commit 3a65684
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/constrained_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def __init__(self, regex_string, tokenizer):
self.fsm_state: DefaultDict[SequenceId, int] = defaultdict(int)

def __call__(
self, seq_id: int, input_ids: List[int], scores: torch.Tensor
self, seq_id: SequenceId, input_ids: List[int], scores: torch.Tensor
) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""

if len(input_ids) == 0: # Initialize the fsm states
self.fsm_state: DefaultDict[SequenceId, int] = defaultdict(int)
self.fsm_state = defaultdict(int)
else:
last_token = input_ids[-1]
self.fsm_state[seq_id] = self.fsm.next_state(
Expand Down

0 comments on commit 3a65684

Please sign in to comment.