From 0f9365a14f290066d1ee1be30a4b6e2875279bbc Mon Sep 17 00:00:00 2001 From: philheller Date: Fri, 6 Sep 2024 12:45:22 +0200 Subject: [PATCH] Cleanup, comment out some print statements --- src/transformers/generation/utils.py | 29 ++++++++++------------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 19535b23d950..723a3feba754 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -276,6 +276,7 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): # for continuation attention_mask: Optional[torch.LongTensor] = None last_beam_scores: Optional[torch.FloatTensor] = None + # group beam search next_input_ids: Optional[torch.LongTensor] = None @@ -341,6 +342,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): # for continuation attention_mask: Optional[torch.LongTensor] = None last_beam_scores: Optional[torch.FloatTensor] = None + # group beam search next_input_ids: Optional[torch.LongTensor] = None @@ -1727,9 +1729,9 @@ def generate( tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) # todo:phil remove in prod - print(generation_config, - {k: v for k, v in model_kwargs.items() if k != 'past_key_values'} - ) + # print(generation_config, + # {k: v for k, v in model_kwargs.items() if k != 'past_key_values'} + # ) self._validate_model_kwargs(model_kwargs.copy()) self._validate_assistant(assistant_model) @@ -1935,7 +1937,7 @@ def generate( # todo:phil figure out for every mode, what continuation would look like # ? input for every configuration step would have to be the output from the previous step # ! leaving out speculative decoding for now - print("GENERATION_MODE", generation_mode) + # print("GENERATION_MODE", generation_mode) if generation_mode == GenerationMode.ASSISTED_GENERATION: if generation_config.num_return_sequences > 1: @@ -2190,9 +2192,9 @@ def generate( ) elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: - if generation_config.resume_generation: + # if generation_config.resume_generation: # todo:phil implement - print("Resuming generation") + # print("Resuming generation") final_constraints = [] if generation_config.constraints is not None: final_constraints = generation_config.constraints @@ -2993,12 +2995,8 @@ def _sample( generation_config: GenerationConfig, synced_gpus: bool, streamer: Optional["BaseStreamer"], -<<<<<<< HEAD - logits_warper: Optional[LogitsProcessorList] = None, - last_scores: Optional[Tuple[torch.FloatTensor]] = None, -======= logits_warper: Optional[LogitsProcessorList], ->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f + last_scores: Optional[Tuple[torch.FloatTensor]] = None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -3204,7 +3202,6 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx): past_key_values.reorder_cache(beam_idx) return past_key_values -<<<<<<< HEAD def _reorder_attention_mask(self, beam_indices: Tuple[Tuple[int]], **model_kwargs): last_tokens = [inner_tuple[-1] for inner_tuple in beam_indices] last_indices = torch.tensor([token.item() for token in last_tokens]) @@ -3219,8 +3216,6 @@ def _reorder_attention_mask(self, beam_indices: Tuple[Tuple[int]], **model_kwarg return model_kwargs # TODO (joao, v4.42): remove default for `logits_warper` -======= ->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f def _beam_search( self, input_ids: torch.LongTensor, @@ -3229,14 +3224,10 @@ def _beam_search( stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, -<<<<<<< HEAD - logits_warper: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList], last_beam_scores: Optional[torch.FloatTensor] = None, last_scores: Optional[Tuple[torch.FloatTensor]] = None, original_prompt_len: int = None, -======= - logits_warper: Optional[LogitsProcessorList], ->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r"""