diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 444492500..50d75558b 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -79,18 +79,6 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs): self.vocabulary = self.tokenizer.get_vocab() self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) - def encode( - self, prompt: Union[str, List[str]], **kwargs - ) -> Tuple["torch.LongTensor", "torch.LongTensor"]: - kwargs["padding"] = True - kwargs["return_tensors"] = "pt" - output = self.tokenizer(prompt, **kwargs) - return output["input_ids"], output["attention_mask"] - - def decode(self, token_ids: "torch.LongTensor") -> List[str]: - text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) - return text - def convert_token_to_string(self, token: str) -> str: from transformers.file_utils import SPIECE_UNDERLINE @@ -137,62 +125,30 @@ def __init__( self.model = model self.tokenizer = TransformerTokenizer(tokenizer) - def forward( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]: - """Compute a forward pass through the transformer model. - - Parameters - ---------- - input_ids - The input token ids. Must be one or two dimensional. - attention_mask - The attention mask. Must be one or two dimensional. - past_key_values - A tuple of tuples containing the cached key and value tensors for each - attention head. + def generate( + self, prompt: Union[str, List[str]], logits_processor, **inference_kwargs + ): + from transformers import LogitsProcessorList, GenerationConfig - Returns - ------- - The computed logits and the new cached key and value tensors. - """ - try: - import torch - except ImportError: - ImportError( - "The `torch` library needs to be installed to use `transformers` models." - ) - assert 0 < input_ids.ndim < 3 - - if past_key_values: - input_ids = input_ids[..., -1].unsqueeze(-1) - - with torch.inference_mode(): - output = self.model( - input_ids, - attention_mask=attention_mask, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - past_key_values=past_key_values, - ) + if isinstance(prompts, str): + prompts = [prompts] - return output.logits, output.past_key_values + input_ids, attention_mask = self.tokenizer.encode([prompts]) - def __call__( - self, - input_ids: "torch.LongTensor", - attention_mask: "torch.LongTensor", - past_key_values: Optional[Tuple] = None, - ) -> "torch.FloatTensor": - logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) - next_token_logits = logits[..., -1, :] + inputs = { + "input_ids": input_ids.to(self.model.device), + "attention_mask": attention_mask.to(self.model.device), + } - return next_token_logits, kv_cache + if logits_processor is not None: + logits_processor_list = LogitsProcessorList([logits_processor]) + else: + logits_processor_list = None + + output_ids = self.model.generate( + **inputs, generation_config=generation_config + ) def generate( self, @@ -223,27 +179,15 @@ def generate( The generated text """ if isinstance(prompts, str): - # convert to 2d - input_ids, attention_mask = self.tokenizer.encode([prompts]) - else: - input_ids, attention_mask = self.tokenizer.encode(prompts) + prompts = [prompts] + + input_ids, attention_mask = self.tokenizer.encode([prompts]) inputs = { "input_ids": input_ids.to(self.model.device), "attention_mask": attention_mask.to(self.model.device), } - if ( - "attention_mask" - not in inspect.signature(self.model.forward).parameters.keys() - ): - del inputs["attention_mask"] - generation_kwargs = self._get_generation_kwargs( - prompts, - generation_parameters, - logits_processor, - sampling_parameters, - ) generated_ids = self._generate_output_seq(prompts, inputs, **generation_kwargs) # if single str input and single sample per input, convert to a 1D output @@ -296,53 +240,6 @@ def stream( output_group_ids = generated_ids.select(-1, i).unsqueeze(-1) yield self._decode_generation(output_group_ids) - def _get_generation_kwargs( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - logits_processor: Optional["OutlinesLogitsProcessor"], - sampling_parameters: SamplingParameters, - ) -> dict: - """ - Conert outlines generation parameters into model.generate kwargs - """ - from transformers import GenerationConfig, LogitsProcessorList, set_seed - - max_new_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters - ) - if max_new_tokens is None: - max_new_tokens = int(2**30) - - # global seed, not desirable - if seed is not None: - set_seed(seed) - - if logits_processor is not None: - logits_processor_list = LogitsProcessorList([logits_processor]) - else: - logits_processor_list = None - - generation_config = GenerationConfig( - max_new_tokens=max_new_tokens, - stop_strings=stop_at, - num_return_sequences=(num_samples or 1), - top_p=top_p, - top_k=top_k, - temperature=temperature, - do_sample=(sampler == "multinomial"), - num_beams=(num_samples if sampler == "beam_search" else 1), - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.pad_token_id, - ) - - return dict( - logits_processor=logits_processor_list, - generation_config=generation_config, - tokenizer=self.tokenizer.tokenizer, - ) - def _generate_output_seq( self, prompts, inputs, generation_config, **generation_kwargs ):