Cleanup, comment out some print statements
philheller committed Sep 6, 2024
1 parent df84617 commit 0f9365a
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions src/transformers/generation/utils.py
@@ -276,6 +276,7 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
     # for continuation
     attention_mask: Optional[torch.LongTensor] = None
     last_beam_scores: Optional[torch.FloatTensor] = None
+    # group beam search
     next_input_ids: Optional[torch.LongTensor] = None


@@ -341,6 +342,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
     # for continuation
     attention_mask: Optional[torch.LongTensor] = None
     last_beam_scores: Optional[torch.FloatTensor] = None
+    # group beam search
     next_input_ids: Optional[torch.LongTensor] = None
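Both hunks above add the same clarifying comment: next_input_ids is part of the group-beam-search continuation state carried on the output dataclasses, next to attention_mask and last_beam_scores. A minimal sketch of how a caller might feed these fields back into a follow-up generate() call; the keyword wiring on the second call is an assumption about this fork's API, not something shown in this diff:

    # first segment: run a beam search and keep the continuation state
    # (model and input_ids assumed to be set up beforehand)
    out = model.generate(input_ids, num_beams=4, max_new_tokens=16,
                         return_dict_in_generate=True)
    # second segment: resume from the returned state (hypothetical keywords)
    out2 = model.generate(out.next_input_ids,
                          attention_mask=out.attention_mask,
                          last_beam_scores=out.last_beam_scores,
                          num_beams=4, max_new_tokens=16,
                          return_dict_in_generate=True)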


@@ -1727,9 +1729,9 @@ def generate(
         tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
         generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         # todo:phil remove in prod
-        print(generation_config,
-              {k: v for k, v in model_kwargs.items() if k != 'past_key_values'}
-              )
+        # print(generation_config,
+        #       {k: v for k, v in model_kwargs.items() if k != 'past_key_values'}
+        #       )
         self._validate_model_kwargs(model_kwargs.copy())
         self._validate_assistant(assistant_model)
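The commit silences this ad-hoc debug print rather than deleting it, leaving the todo:phil marker in place. A sketch of the same output gated behind stdlib logging instead of source edits, assuming generation_config and model_kwargs as bound in generate() above; this is a suggestion, not what the commit does:

    import logging

    logger = logging.getLogger(__name__)

    # same payload as the commented-out print, emitted only at DEBUG level
    logger.debug(
        "generation_config=%s model_kwargs=%s",
        generation_config,
        {k: v for k, v in model_kwargs.items() if k != "past_key_values"},
    )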

@@ -1935,7 +1937,7 @@ def generate(
         # todo:phil figure out for every mode, what continuation would look like
         # ? input for every configuration step would have to be the output from the previous step
         # ! leaving out speculative decoding for now
-        print("GENERATION_MODE", generation_mode)
+        # print("GENERATION_MODE", generation_mode)
 
         if generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
@@ -2190,9 +2192,9 @@ def generate(
             )
 
         elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
-            if generation_config.resume_generation:
+            # if generation_config.resume_generation:
                 # todo:phil implement
-                print("Resuming generation")
+            # print("Resuming generation")
             final_constraints = []
             if generation_config.constraints is not None:
                 final_constraints = generation_config.constraints
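The unfinished resume_generation branch is likewise commented out rather than removed, keeping its todo:phil implement marker. A small self-contained sketch of the defensive gate such a flag usually needs, so that configs predating the flag keep working; the flag name comes from this diff, everything else here is an assumption:

    from dataclasses import dataclass

    @dataclass
    class Cfg:  # stand-in for GenerationConfig
        resume_generation: bool = False

    def should_resume(generation_config) -> bool:
        # default to False when the attribute is missing entirely
        return getattr(generation_config, "resume_generation", False)

    assert should_resume(Cfg()) is False
    assert should_resume(Cfg(resume_generation=True)) is True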
@@ -2993,12 +2995,8 @@ def _sample(
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
-<<<<<<< HEAD
-        logits_warper: Optional[LogitsProcessorList] = None,
-        last_scores: Optional[Tuple[torch.FloatTensor]] = None,
-=======
         logits_warper: Optional[LogitsProcessorList],
->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f
+        last_scores: Optional[Tuple[torch.FloatTensor]] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3204,7 +3202,6 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
             past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
-<<<<<<< HEAD
     def _reorder_attention_mask(self, beam_indices: Tuple[Tuple[int]], **model_kwargs):
         last_tokens = [inner_tuple[-1] for inner_tuple in beam_indices]
         last_indices = torch.tensor([token.item() for token in last_tokens])
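Only the first lines of _reorder_attention_mask are visible; the middle of the method is folded between this hunk and the next. A minimal sketch of the idea the visible fragment points at, selecting the attention-mask rows of the surviving beams; the index_select step is an assumption about the folded code, not part of this commit:

    import torch

    def reorder_attention_mask_sketch(beam_indices, **model_kwargs):
        # final beam index chosen for each returned sequence
        last_tokens = [inner_tuple[-1] for inner_tuple in beam_indices]
        last_indices = torch.tensor([int(t) for t in last_tokens])
        # keep only the mask rows belonging to those beams (assumed step)
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = mask.index_select(0, last_indices)
        return model_kwargs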
@@ -3219,8 +3216,6 @@ def _reorder_attention_mask(self, beam_indices: Tuple[Tuple[int]], **model_kwargs):
         return model_kwargs
 
     # TODO (joao, v4.42): remove default for `logits_warper`
-=======
->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -3229,14 +3224,10 @@
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-<<<<<<< HEAD
-        logits_warper: Optional[LogitsProcessorList] = None,
+        logits_warper: Optional[LogitsProcessorList],
         last_beam_scores: Optional[torch.FloatTensor] = None,
         last_scores: Optional[Tuple[torch.FloatTensor]] = None,
         original_prompt_len: int = None,
-=======
-        logits_warper: Optional[LogitsProcessorList],
->>>>>>> 7789317cb66e3040bb99a9359259d750e30a276f
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
