[WIP] Add CLAP onnx export for zero-shot-audio-classification #1552

Draft · wants to merge 3 commits into main
8 changes: 8 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -179,6 +179,14 @@ class OnnxConfig(ExportConfig, ABC):
"text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"zero-shot-audio-classification": OrderedDict(
{
"logits_per_audio": {0: "audio_batch_size", 1: "text_batch_size"},
"logits_per_text": {0: "text_batch_size", 1: "audio_batch_size"},
"text_embeds": {0: "text_batch_size"},
"audio_embeds": {0: "audio_batch_size"},
}
),
"zero-shot-image-classification": OrderedDict(
{
"logits_per_image": {0: "image_batch_size", 1: "text_batch_size"},
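These dynamic axes follow the CLIP-style contrastive layout: logits_per_audio is the audio-to-text similarity matrix and logits_per_text its transpose. A minimal sketch of what that means for a consumer of the exported graph; the model path, prompt count, dtypes, and feature shape are illustrative assumptions, not part of this PR:

import numpy as np
import onnxruntime as ort

# Hypothetical location of an exported CLAP model.
session = ort.InferenceSession("clap_onnx/model.onnx")
(logits_per_audio,) = session.run(
    ["logits_per_audio"],
    {
        "input_ids": np.ones((3, 12), dtype=np.int64),  # 3 candidate label prompts
        "attention_mask": np.ones((3, 12), dtype=np.int64),
        "input_features": np.zeros((1, 1, 1001, 64), dtype=np.float32),  # 1 audio clip
    },
)
assert logits_per_audio.shape == (1, 3)  # (audio_batch_size, text_batch_size)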
7 changes: 7 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -271,6 +271,13 @@ class AudioOnnxConfig(OnnxConfig):
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {"input_values": {0: "batch_size", 1: "sequence_length"}}

class TextAndAudioOnnxConfig(OnnxConfig):
    """
    Handles multi-modal text and audio architectures.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyAudioInputGenerator)


class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
    DUMMY_INPUT_GENERATOR_CLASSES = (
88 changes: 88 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -43,6 +43,7 @@
    NormalizedEncoderDecoderConfig,
    NormalizedSeq2SeqConfig,
    NormalizedTextAndVisionConfig,
    NormalizedTextAndAudioConfig,
    NormalizedTextConfig,
    NormalizedVisionConfig,
    logging,
@@ -54,6 +55,7 @@
    AudioToTextOnnxConfig,
    EncoderDecoderBaseOnnxConfig,
    TextAndVisionOnnxConfig,
    TextAndAudioOnnxConfig,
    TextDecoderOnnxConfig,
    TextDecoderWithPositionIdsOnnxConfig,
    TextEncoderOnnxConfig,
@@ -876,6 +878,92 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        return dummy_inputs


class CLAPNormalizedConfig(NormalizedTextAndAudioConfig):
    TEXT_CONFIG = "text_config"
    AUDIO_CONFIG = "audio_config"


class CLAPOnnxConfig(TextAndAudioOnnxConfig):
    NORMALIZED_CONFIG_CLASS = CLAPNormalizedConfig
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        inputs = {
            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
            "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
            # 4D mel-spectrogram input, as described in modeling_clap.py
            "input_features": {0: "audio_batch_size", 1: "num_channels", 2: "height", 3: "width"},
        }
        if self._normalized_config.audio_config.enable_fusion:
            inputs["is_longer"] = {0: "audio_batch_size"}

        return inputs
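Note: when fusion is enabled, is_longer comes straight from the feature extractor, one boolean per clip. A sketch of producing matching real inputs, assuming the stock ClapProcessor call signature; the checkpoint and audio array are illustrative:

import numpy as np
from transformers import ClapProcessor

processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
audio = np.zeros(48_000, dtype=np.float32)  # one second of silence at 48 kHz
inputs = processor(text="dog barking", audios=audio, sampling_rate=48_000, return_tensors="np")
print(inputs["input_features"].shape)  # 4 channels here, since this checkpoint enables fusion
print(inputs["is_longer"].shape)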

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "logits_per_audio": {0: "audio_batch_size", 1: "text_batch_size"},
            "logits_per_text": {0: "text_batch_size", 1: "audio_batch_size"},
            "text_embeds": {0: "text_batch_size"},
            "audio_embeds": {0: "audio_batch_size"},
        }


class CLAPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        sequence_length="max_position_embeddings",
        num_layers="num_hidden_layers",
        allow_new=True,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = {
            "text_embeds": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        }
        if self._normalized_config.output_hidden_states:
            for i in range(self._normalized_config.num_layers + 1):
                common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

        return common_outputs


class CLAPTextOnnxConfig(CLAPTextWithProjectionOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = {
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
            "pooler_output": {0: "batch_size"},
        }
        if self._normalized_config.output_hidden_states:
            for i in range(self._normalized_config.num_layers + 1):
                common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

        return common_outputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

        if framework == "pt":
            import torch

            # Cast token ids to int32, matching the analogous CLIPText ONNX export.
            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
        return dummy_inputs


class UNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
6 changes: 6 additions & 0 deletions optimum/exporters/tasks.py
@@ -182,6 +182,7 @@ class TasksManager:
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
"token-classification": "AutoModelForTokenClassification",
"zero-shot-audio-classification": "AutoModel",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}
@@ -389,6 +390,11 @@ class TasksManager:
onnx="CamembertOnnxConfig",
tflite="CamembertTFLiteConfig",
),
"clap": supported_tasks_mapping(
"feature-extraction",
"zero-shot-audio-classification",
onnx="CLAPOnnxConfig",
),
"clip": supported_tasks_mapping(
"feature-extraction",
"zero-shot-image-classification",
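With the clap entry registered above, the export becomes reachable through optimum's regular programmatic entry point. A hedged sketch; the checkpoint and output directory are illustrative, and the fused variant laion/clap-htsat-fused discussed in the review comments below would additionally exercise the is_longer input:

from optimum.exporters.onnx import main_export

main_export(
    "laion/clap-htsat-unfused",  # non-fusion checkpoint
    output="clap_onnx",
    task="zero-shot-audio-classification",
)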
2 changes: 2 additions & 0 deletions optimum/utils/__init__.py
@@ -73,6 +73,8 @@
    NormalizedEncoderDecoderConfig,
    NormalizedSeq2SeqConfig,
    NormalizedTextAndVisionConfig,
    NormalizedTextAndAudioConfig,
    NormalizedTextConfig,
    NormalizedVisionConfig,
    NormalizedAudioConfig,
)
28 changes: 26 additions & 2 deletions optimum/utils/input_generators.py
@@ -66,6 +66,7 @@ def wrapper(*args, **kwargs):
"feature_size": 80,
"nb_max_frames": 3000,
"audio_sequence_length": 16000,
"num_mel_bins": 64,
}


@@ -636,7 +637,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):


class DummyAudioInputGenerator(DummyInputGenerator):
    SUPPORTED_INPUT_NAMES = ("input_features", "input_values", "is_longer")

    def __init__(
        self,
@@ -646,6 +647,7 @@ def __init__(
        feature_size: int = DEFAULT_DUMMY_SHAPES["feature_size"],
        nb_max_frames: int = DEFAULT_DUMMY_SHAPES["nb_max_frames"],
        audio_sequence_length: int = DEFAULT_DUMMY_SHAPES["audio_sequence_length"],
        num_mel_bins: int = DEFAULT_DUMMY_SHAPES["num_mel_bins"],
        **kwargs,
    ):
        self.task = task
@@ -658,8 +660,19 @@ def __init__(
        self.nb_max_frames = nb_max_frames
        self.batch_size = batch_size
        self.sequence_length = audio_sequence_length
        if hasattr(self.normalized_config, "num_mel_bins"):
            self.num_mel_bins = self.normalized_config.num_mel_bins
        else:
            self.num_mel_bins = num_mel_bins
        if hasattr(self.normalized_config, "enable_fusion"):
            self.enable_fusion = self.normalized_config.enable_fusion
        else:
            self.enable_fusion = False

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "input_values":  # raw waveform
            return self.random_float_tensor(
                shape=[self.batch_size, self.sequence_length],
@@ -668,9 +681,20 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
                framework=framework,
                dtype=float_dtype,
            )
        elif input_name == "is_longer":
            return self.constant_tensor(shape=[self.batch_size, 1], value=self.enable_fusion, framework=framework)
        else:
            if self.normalized_config.model_type == "clap":
                # TODO: figure out what this value should be;
                # https://huggingface.co/laion/clap-htsat-fused uses 4.
                num_channels = 1
                shape = [self.batch_size, num_channels, self.feature_size, self.num_mel_bins]
Comment on lines +688 to +692
Contributor Author:
update: When fusion is enabled, num_channels becomes 4 due to stacking of the fbank features.
Also, self.feature_size is incorrect for this; it should be self.nb_max_frames + 1, e.g. [1, 1, 1001, 64] is a valid shape.

Contributor Author:
if self.normalized_config.model_type in ['clap', 'clap_audio_model']:
    shape = [self.batch_size, 1, 1001, 64]
else:
    shape = [self.batch_size, self.feature_size, self.nb_max_frames]

            else:
                shape = [self.batch_size, self.feature_size, self.nb_max_frames]

            return self.random_float_tensor(
                shape=shape,
                min_value=-1,
                max_value=1,
                framework=framework,
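As a quick sanity check of the new is_longer branch, a sketch using the CLAP normalized config defined in model_configs.py above, and assuming the generator keeps the usual (task, normalized_config) leading arguments of optimum's dummy input generators:

from transformers import ClapConfig

from optimum.exporters.onnx.model_configs import CLAPNormalizedConfig
from optimum.utils.input_generators import DummyAudioInputGenerator

gen = DummyAudioInputGenerator(
    "zero-shot-audio-classification",
    CLAPNormalizedConfig(ClapConfig()),
    batch_size=2,
)
is_longer = gen.generate("is_longer", framework="pt")
assert tuple(is_longer.shape) == (2, 1)  # one flag per audio clip in the batch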
16 changes: 16 additions & 0 deletions optimum/utils/normalized_config.py
@@ -101,6 +101,10 @@ class NormalizedVisionConfig(NormalizedConfig):
    NUM_CHANNELS = "num_channels"


class NormalizedAudioConfig(NormalizedConfig):
    NUM_MEL_BINS = "num_mel_bins"


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
    TEXT_CONFIG = None
    VISION_CONFIG = None
@@ -118,6 +122,18 @@ def __getattr__(self, attr_name):
        )


class NormalizedTextAndAudioConfig(NormalizedTextConfig, NormalizedAudioConfig):
    TEXT_CONFIG = None
    AUDIO_CONFIG = None

    def __getattr__(self, attr_name):
        if self.TEXT_CONFIG is not None and attr_name.upper() in dir(NormalizedTextConfig):
            attr_name = f"{self.TEXT_CONFIG}.{attr_name}"
        elif self.AUDIO_CONFIG is not None and attr_name.upper() in dir(NormalizedAudioConfig):
            attr_name = f"{self.AUDIO_CONFIG}.{attr_name}"
        return super().__getattr__(attr_name)
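
The delegation mirrors NormalizedTextAndVisionConfig: any attribute declared on NormalizedTextConfig resolves against the text sub-config, and any attribute declared on NormalizedAudioConfig against the audio one. A minimal sketch, assuming a stock ClapConfig:

from transformers import ClapConfig

from optimum.utils import NormalizedTextAndAudioConfig


class CLAPNormalizedConfig(NormalizedTextAndAudioConfig):
    TEXT_CONFIG = "text_config"
    AUDIO_CONFIG = "audio_config"


normalized = CLAPNormalizedConfig(ClapConfig())
print(normalized.vocab_size)    # resolved as text_config.vocab_size
print(normalized.num_mel_bins)  # resolved as audio_config.num_mel_bins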


class NormalizedEncoderDecoderConfig(NormalizedConfig):
    ENCODER_NORMALIZED_CONFIG_CLASS = None
    DECODER_NORMALIZED_CONFIG_CLASS = None