diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 8bf93feb29..21bb6bd0ec 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -179,6 +179,14 @@ class OnnxConfig(ExportConfig, ABC):
         "text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "zero-shot-audio-classification": OrderedDict(
+            {
+                "logits_per_audio": {0: "audio_batch_size", 1: "text_batch_size"},
+                "logits_per_text": {0: "text_batch_size", 1: "audio_batch_size"},
+                "text_embeds": {0: "text_batch_size"},
+                "audio_embeds": {0: "audio_batch_size"},
+            }
+        ),
         "zero-shot-image-classification": OrderedDict(
             {
                 "logits_per_image": {0: "image_batch_size", 1: "text_batch_size"},
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 2eaa78d85e..ed5902f529 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -271,6 +271,13 @@ class AudioOnnxConfig(OnnxConfig):
     def inputs(self) -> Dict[str, Dict[int, str]]:
         return {"input_values": {0: "batch_size", 1: "sequence_length"}}
 
+
+class TextAndAudioOnnxConfig(OnnxConfig):
+    """
+    Handles multi-modal text and audio architectures.
+    """
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyAudioInputGenerator)
+
 
 class AudioToTextOnnxConfig(OnnxSeq2SeqConfigWithPast):
     DUMMY_INPUT_GENERATOR_CLASSES = (
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index f4d50ad58d..e742efcd8b 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -43,6 +43,7 @@
     NormalizedEncoderDecoderConfig,
     NormalizedSeq2SeqConfig,
+    NormalizedTextAndAudioConfig,
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedVisionConfig,
     logging,
@@ -54,6 +55,7 @@
     AudioToTextOnnxConfig,
     EncoderDecoderBaseOnnxConfig,
+    TextAndAudioOnnxConfig,
     TextAndVisionOnnxConfig,
     TextDecoderOnnxConfig,
     TextDecoderWithPositionIdsOnnxConfig,
     TextEncoderOnnxConfig,
@@ -876,6 +878,92 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         return dummy_inputs
 
 
+class CLAPNormalizedConfig(NormalizedTextAndAudioConfig):
+    TEXT_CONFIG = "text_config"
+    AUDIO_CONFIG = "audio_config"
+
+
+class CLAPOnnxConfig(TextAndAudioOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = CLAPNormalizedConfig
+    DEFAULT_ONNX_OPSET = 14
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+            # Axis names follow the input_features shape described in modeling_clap.py.
+            "input_features": {0: "audio_batch_size", 1: "num_channels", 2: "height", 3: "width"},
+        }
+        if self._normalized_config.audio_config.enable_fusion:
+            common_inputs["is_longer"] = {0: "audio_batch_size"}
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "logits_per_audio": {0: "audio_batch_size", 1: "text_batch_size"},
+            "logits_per_text": {0: "text_batch_size", 1: "audio_batch_size"},
+            "text_embeds": {0: "text_batch_size"},
+            "audio_embeds": {0: "audio_batch_size"},
+        }
+
+
+class CLAPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-3
+    # The ONNX export of this architecture needs the Trilu operator, available since opset 14.
+    DEFAULT_ONNX_OPSET = 14
+
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        vocab_size="vocab_size",
+        sequence_length="max_position_embeddings",
+        num_layers="num_hidden_layers",
+        allow_new=True,
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        common_outputs = {
+            "text_embeds": {0: "batch_size", 1: "sequence_length"},
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+        if self._normalized_config.output_hidden_states:
+            for i in range(self._normalized_config.num_layers + 1):
+                common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
+
+        return common_outputs
+
+
+class CLAPTextOnnxConfig(CLAPTextWithProjectionOnnxConfig):
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        common_outputs = {
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+            "pooler_output": {0: "batch_size"},
+        }
+        if self._normalized_config.output_hidden_states:
+            for i in range(self._normalized_config.num_layers + 1):
+                common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}
+
+        return common_outputs
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
+
+        if framework == "pt":
+            import torch
+
+            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
+        return dummy_inputs
+
+
 class UNetOnnxConfig(VisionOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 7545c72d6c..181f2bbadb 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -182,6 +182,7 @@ class TasksManager:
         "text2text-generation": "AutoModelForSeq2SeqLM",
         "text-classification": "AutoModelForSequenceClassification",
         "token-classification": "AutoModelForTokenClassification",
+        "zero-shot-audio-classification": "AutoModel",
         "zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
         "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
     }
@@ -389,6 +390,11 @@ class TasksManager:
         "camembert": supported_tasks_mapping(
             onnx="CamembertOnnxConfig",
             tflite="CamembertTFLiteConfig",
         ),
+        "clap": supported_tasks_mapping(
+            "feature-extraction",
+            "zero-shot-audio-classification",
+            onnx="CLAPOnnxConfig",
+        ),
         "clip": supported_tasks_mapping(
             "feature-extraction",
             "zero-shot-image-classification",
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 73eb86bdae..fb8d6711ef 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -73,6 +73,8 @@
+    NormalizedAudioConfig,
     NormalizedEncoderDecoderConfig,
     NormalizedSeq2SeqConfig,
+    NormalizedTextAndAudioConfig,
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedVisionConfig,
 )
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index aa1f785309..84b6189845 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -66,6 +66,7 @@ def wrapper(*args, **kwargs):
     "feature_size": 80,
     "nb_max_frames": 3000,
     "audio_sequence_length": 16000,
+    "num_mel_bins": 64,
 }
 
 
@@ -636,7 +637,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
 
 
 class DummyAudioInputGenerator(DummyInputGenerator):
-    SUPPORTED_INPUT_NAMES = ("input_features", "input_values")
+    SUPPORTED_INPUT_NAMES = ("input_features", "input_values", "is_longer")
 
     def __init__(
         self,
@@ -646,6 +647,7 @@ def __init__(
         feature_size: int = DEFAULT_DUMMY_SHAPES["feature_size"],
         nb_max_frames: int = DEFAULT_DUMMY_SHAPES["nb_max_frames"],
         audio_sequence_length: int = DEFAULT_DUMMY_SHAPES["audio_sequence_length"],
+        num_mel_bins: int = DEFAULT_DUMMY_SHAPES["num_mel_bins"],
         **kwargs,
     ):
         self.task = task
@@ -658,8 +660,19 @@ def __init__(
         self.nb_max_frames = nb_max_frames
         self.batch_size = batch_size
         self.sequence_length = audio_sequence_length
+        self.num_mel_bins = getattr(self.normalized_config, "num_mel_bins", num_mel_bins)
+        self.enable_fusion = getattr(self.normalized_config, "enable_fusion", False)
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         if input_name == "input_values":  # raw waveform
             return self.random_float_tensor(
                 shape=[self.batch_size, self.sequence_length],
                 min_value=-1,
                 max_value=1,
                 framework=framework,
                 dtype=float_dtype,
             )
+        elif input_name == "is_longer":
+            return self.constant_tensor(shape=[self.batch_size, 1], value=self.enable_fusion, framework=framework)
+        else:
+            if self.normalized_config.model_type == "clap":
+                # TODO: figure out the right number of channels here;
+                # https://huggingface.co/laion/clap-htsat-fused uses 4.
+                num_channels = 1
+                shape = [self.batch_size, num_channels, self.feature_size, self.num_mel_bins]
+            else:
+                shape = [self.batch_size, self.feature_size, self.nb_max_frames]
+
         return self.random_float_tensor(
-            shape=[self.batch_size, self.feature_size, self.nb_max_frames],
+            shape=shape,
             min_value=-1,
             max_value=1,
             framework=framework,
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 7a0af9a1a4..e364b7d8ef 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -101,6 +101,10 @@ class NormalizedVisionConfig(NormalizedConfig):
     NUM_CHANNELS = "num_channels"
 
 
+class NormalizedAudioConfig(NormalizedConfig):
+    NUM_MEL_BINS = "num_mel_bins"
+
+
 class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
     TEXT_CONFIG = None
     VISION_CONFIG = None
@@ -118,6 +122,18 @@ def __getattr__(self, attr_name):
         )
 
 
+class NormalizedTextAndAudioConfig(NormalizedTextConfig, NormalizedAudioConfig):
+    TEXT_CONFIG = None
+    AUDIO_CONFIG = None
+
+    def __getattr__(self, attr_name):
+        if self.TEXT_CONFIG is not None and attr_name.upper() in dir(NormalizedTextConfig):
+            attr_name = f"{self.TEXT_CONFIG}.{attr_name}"
+        elif self.AUDIO_CONFIG is not None and attr_name.upper() in dir(NormalizedAudioConfig):
+            attr_name = f"{self.AUDIO_CONFIG}.{attr_name}"
+        return super().__getattr__(attr_name)
+
+
 class NormalizedEncoderDecoderConfig(NormalizedConfig):
     ENCODER_NORMALIZED_CONFIG_CLASS = None
     DECODER_NORMALIZED_CONFIG_CLASS = None
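
A quick end-to-end smoke test for the new task (a sketch, not part of this diff: the checkpoint name, output directory, file name, and tensor shapes below are illustrative assumptions; the text length and spectrogram dimensions are declared as dynamic axes, and the output order follows CLAPOnnxConfig.outputs):

    optimum-cli export onnx --model laion/clap-htsat-unfused --task zero-shot-audio-classification clap_onnx/

    # check_clap_onnx.py: run the exported graph once with onnxruntime.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("clap_onnx/model.onnx")

    # Two candidate captions scored against one audio clip. input_features uses the
    # (audio_batch_size, num_channels, height, width) layout from CLAPOnnxConfig.inputs,
    # with num_channels=1 and width=num_mel_bins=64 as in DummyAudioInputGenerator;
    # the unfused checkpoint does not take is_longer.
    outputs = session.run(
        None,
        {
            "input_ids": np.ones((2, 12), dtype=np.int64),
            "attention_mask": np.ones((2, 12), dtype=np.int64),
            "input_features": np.random.rand(1, 1, 1001, 64).astype(np.float32),
        },
    )
    logits_per_audio = outputs[0]  # shape (1, 2): one row of similarity scores per audio clip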