diff --git a/pyvene/models/esm/__init__.py b/pyvene/models/esm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pyvene/models/esm/modelings_intervenable_esm.py b/pyvene/models/esm/modelings_intervenable_esm.py
new file mode 100644
index 00000000..5a5c21fa
--- /dev/null
+++ b/pyvene/models/esm/modelings_intervenable_esm.py
@@ -0,0 +1,63 @@
+"""
+Each modeling file in this library is a mapping between
+abstract naming of intervention anchor points and actual
+model module defined in the huggingface library.
+
+We also want to let the intervention library know how to
+config the dimensions of intervention based on model config
+defined in the huggingface library.
+"""
+
+from ..constants import (
+    CONST_INPUT_HOOK,
+    CONST_OUTPUT_HOOK
+)
+
+"""esm base model"""
+esm_type_to_module_mapping = dict(
+    block_input=("encoder.layer[%s]", CONST_INPUT_HOOK),
+    block_output=("encoder.layer[%s]", CONST_OUTPUT_HOOK),
+
+    mlp_input=("encoder.layer[%s].intermediate", CONST_INPUT_HOOK),
+    mlp_activation=("encoder.layer[%s].intermediate", CONST_OUTPUT_HOOK),
+    mlp_output=("encoder.layer[%s].output", CONST_OUTPUT_HOOK),
+
+    attention_value_output=("encoder.layer[%s].attention.output", CONST_INPUT_HOOK),
+    head_attention_value_output=("encoder.layer[%s].attention.output", CONST_INPUT_HOOK),
+
+    attention_input=("encoder.layer[%s].attention", CONST_INPUT_HOOK),
+    attention_output=("encoder.layer[%s].attention", CONST_OUTPUT_HOOK),
+
+    query_output=("encoder.layer[%s].attention.self.query", CONST_OUTPUT_HOOK),
+    head_query_output=("encoder.layer[%s].attention.self.query", CONST_OUTPUT_HOOK),
+    key_output=("encoder.layer[%s].attention.self.key", CONST_OUTPUT_HOOK),
+    head_key_output=("encoder.layer[%s].attention.self.key", CONST_OUTPUT_HOOK),
+    value_output=("encoder.layer[%s].attention.self.value", CONST_OUTPUT_HOOK),
+    head_value_output=("encoder.layer[%s].attention.self.value", CONST_OUTPUT_HOOK),
+)
+esm_type_to_dimension_mapping = dict(
+    block_input=("hidden_size",),
+    block_output=("hidden_size",),
+
+    mlp_input=("hidden_size",),
+    mlp_activation=("intermediate_size",),
+    mlp_output=("hidden_size",),
+
+    attention_value_output=("hidden_size",),
+    head_attention_value_output=("hidden_size/num_attention_heads",),
+
+    attention_input=("hidden_size",),
+    attention_output=("hidden_size",),
+
+    query_output=("hidden_size",),
+    head_query_output=("hidden_size/num_attention_heads",),
+    key_output=("hidden_size",),
+    head_key_output=("hidden_size/num_attention_heads",),
+    value_output=("hidden_size",),
+    head_value_output=("hidden_size/num_attention_heads",),
+)
+
+"""esm for mlm model"""
+esm_mlm_type_to_module_mapping = {
+    k: ("esm." + i, j) for k, (i, j) in esm_type_to_module_mapping.items()}
+esm_mlm_type_to_dimension_mapping = esm_type_to_dimension_mapping.copy()
diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py
index 1f7329f4..c500639b 100644
--- a/pyvene/models/intervenable_modelcard.py
+++ b/pyvene/models/intervenable_modelcard.py
@@ -13,6 +13,7 @@
 from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import *
 from .llava.modelings_intervenable_llava import *
 from .olmo.modelings_intervenable_olmo import *
+from .esm.modelings_intervenable_esm import *
 
 #########################################################################
 """
@@ -63,6 +64,8 @@
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_module_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_module_mapping,
+    hf_models.esm.modeling_esm.EsmModel: esm_type_to_module_mapping,
+    hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
     hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
     MLPModel: mlp_type_to_module_mapping,
@@ -98,6 +101,8 @@
     hf_models.gemma2.modeling_gemma2.Gemma2ForCausalLM: gemma2_lm_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoModel: olmo_type_to_dimension_mapping,
     hf_models.olmo.modeling_olmo.OlmoForCausalLM: olmo_lm_type_to_dimension_mapping,
+    hf_models.esm.modeling_esm.EsmModel: esm_type_to_dimension_mapping,
+    hf_models.esm.modeling_esm.EsmForMaskedLM: esm_mlm_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
     hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
     MLPModel: mlp_type_to_dimension_mapping,
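
For reviewers, a minimal usage sketch of what this registration enables (not part of the diff). The checkpoint name (facebook/esm2_t6_8M_UR50D), the layer index, and the intervened token position are illustrative assumptions; the snippet assumes pyvene's standard interchange-intervention API and only exercises the block_output mapping added above.

import pyvene as pv
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Illustrative checkpoint; any ESM-2 model that loads as EsmForMaskedLM should
# resolve to the esm_mlm_* mappings registered in intervenable_modelcard.py.
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Wrap the model with a single interchange intervention on the residual
# stream ("block_output") of encoder layer 2 (the layer choice is arbitrary).
pv_model = pv.IntervenableModel(
    {"layer": 2, "component": "block_output", "intervention_type": pv.VanillaIntervention},
    model=model,
)

# Two toy protein sequences differing only at residue 5 (Y vs. W); token
# position 5 below accounts for the leading <cls> token added by the tokenizer.
base = tokenizer("MKTAYIAKQR", return_tensors="pt")
source = tokenizer("MKTAWIAKQR", return_tensors="pt")

# Run the base sequence, overwriting its layer-2 block_output at position 5
# with the source sequence's activation at the same position.
_, counterfactual = pv_model(
    base,
    [source],
    {"sources->base": ([[[5]]], [[[5]]])},
)
print(counterfactual.logits.shape)  # (batch, sequence_length, vocab_size)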