# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
    layer_outputs = (None, None)
else:
    layer_outputs = encoder_layer(
        hidden_states,
        attention_mask,
        output_attentions,
        deterministic,
    )
hidden_states = layer_outputs[0]
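A toy sketch of the failure mode (the names `encoder_layer`, `forward_buggy`, and `forward_fixed` are hypothetical, not from transformers): because `hidden_states = layer_outputs[0]` runs unconditionally, a skipped layer replaces the hidden states with `None`, and the next layer call crashes. One possible fix is to only update `hidden_states` when the layer actually ran:

```python
# Toy reproduction of the LayerDrop bug: hidden_states becomes None
# when a layer is skipped, so the next layer call fails.
def encoder_layer(hidden_states):
    return (hidden_states + 1,)  # stand-in for a real transformer layer

def forward_buggy(hidden_states, drop_mask):
    for drop in drop_mask:
        if drop:
            layer_outputs = (None, None)
        else:
            layer_outputs = encoder_layer(hidden_states)
        hidden_states = layer_outputs[0]  # None after a dropped layer
    return hidden_states

def forward_fixed(hidden_states, drop_mask):
    for drop in drop_mask:
        if not drop:
            # only update hidden_states when the layer actually ran,
            # so a dropped layer acts as the identity
            hidden_states = encoder_layer(hidden_states)[0]
    return hidden_states

print(forward_fixed(0, [False, True, False]))  # 2: dropped layer is a no-op
try:
    forward_buggy(0, [False, True, False])
except TypeError as err:
    print("buggy:", err)  # None + 1 fails inside the next layer
```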
System Info

transformers version: 4.44.2

Who can help?
@sanchit-gandhi
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Train a FlaxWhisperForConditionalGeneration model with encoder/decoder layerdrop activated.
Expected behavior
I'm using FlaxWhisperForConditionalGeneration, but I see the same code in a bunch of models. Here `hidden_states` is set to `None` if the layer is dropped, causing the error.
transformers/src/transformers/models/whisper/modeling_flax_whisper.py
Lines 442 to 453 in d5aebc6
While fixing that, I also noticed that `dropout_probability = random.uniform(0, 1)` only runs during tracing, so looping a compiled training step will always drop the same layers.
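A minimal sketch of that tracing gotcha (the function names `step_buggy` and `step_ok` are hypothetical, and this is plain jax, not the Flax module code): `random.uniform` is Python-level randomness, so under `jax.jit` it is sampled once at trace time and baked into the compiled function, whereas `jax.random.uniform` with a threaded PRNG key stays random across calls.

```python
# Why Python-level randomness is frozen under jit: the value is sampled
# once during tracing and embedded as a constant in the compiled function.
import random
import jax

@jax.jit
def step_buggy(x):
    p = random.uniform(0, 1)  # evaluated once, at trace time
    return x + p

vals = {float(step_buggy(0.0)) for _ in range(5)}
print(len(vals))  # 1: every call reuses the traced constant

# Traced randomness instead: thread a PRNG key through each step
# (in Flax, rng collections / make_rng serve this purpose).
@jax.jit
def step_ok(x, key):
    p = jax.random.uniform(key)  # a traced random op, fresh per key
    return x + p

key = jax.random.PRNGKey(0)
vals = set()
for _ in range(5):
    key, sub = jax.random.split(key)
    vals.add(float(step_ok(0.0, sub)))
print(len(vals))  # almost surely 5: a new value per call
```

With the key-threaded version, each compiled step sees a different subkey, so LayerDrop would actually sample a new decision every step instead of replaying the trace-time one.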