Update get_input_samples to include task_ids when n_tasks is given
Francois Ledoyen authored and Francois Ledoyen committed Feb 12, 2025
1 parent 9c969b4 commit f74a1a1
Showing 3 changed files with 24 additions and 15 deletions.
26 changes: 17 additions & 9 deletions tests/test_methods/base.py
@@ -22,10 +22,6 @@ class AbstractAdapterTestBase:
do_run_train_tests = True
num_labels = 2

def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
"""Creates a dummy batch of samples in the format required for the model."""
raise NotImplementedError("get_input_samples() must be implemented in the subclass.")

def add_head(self, model, name, **kwargs):
"""Adds a dummy head to the model."""
raise NotImplementedError("add_head() must be implemented in the subclass.")
@@ -42,6 +38,12 @@ def attach_labels(self, inputs):
"""Attaches labels to the input samples."""
raise NotImplementedError("attach_labels() with respective label shape must be implemented in the subclass.")

def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
in_data = {}
if "n_tasks" in kwargs:
in_data["task_ids"] = torch.randint(0, kwargs["n_tasks"], (shape[0],))
return in_data

def get_model(self):
"""Builds a model instance for testing based on the provied model configuration."""
if self.model_class == AutoAdapterModel:
@@ -91,20 +93,21 @@ class TextAdapterTestBase(AbstractAdapterTestBase):

def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
shape = shape or self.input_shape
in_data = super().get_input_samples(shape, vocab_size, config, **kwargs)
input_ids = self.build_rand_ids_tensor(shape, vocab_size=vocab_size)

# Ensures that only the last token in each sample is the eos token (needed e.g. for BART)
if config and config.eos_token_id is not None and config.eos_token_id < vocab_size:
input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1)
input_ids[:, -1] = config.eos_token_id
in_data = {"input_ids": input_ids}
in_data["input_ids"] = input_ids

# Add decoder input ids for models with a decoder
if config and config.is_encoder_decoder:
in_data["decoder_input_ids"] = input_ids.clone()

if "num_labels" in kwargs:
in_data["labels"] = self.build_rand_ids_tensor(shape[:-1], vocab_size=kwargs["num_labels"])

return in_data

def add_head(self, model, name, **kwargs):
Expand All @@ -118,7 +121,9 @@ def get_dataset(self, tokenizer=None):
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
data_args = GlueDataTrainingArguments(
task_name="mrpc", data_dir="./hf_transformers/tests/fixtures/tests_samples/MRPC", overwrite_cache=True
task_name="mrpc",
data_dir="./hf_transformers/tests/fixtures/tests_samples/MRPC",
overwrite_cache=True,
)
return GlueDataset(data_args, tokenizer=tokenizer, mode="train")

@@ -143,8 +148,10 @@ class VisionAdapterTestBase(AbstractAdapterTestBase):

def get_input_samples(self, shape=None, config=None, dtype=torch.float, **kwargs):
shape = shape or self.input_shape
in_data = super().get_input_samples(shape, config=config, **kwargs)
pixel_values = self.build_rand_tensor(shape, dtype=dtype)
return {"pixel_values": pixel_values}
in_data["pixel_values"] = pixel_values
return in_data

def add_head(self, model, name, **kwargs):
kwargs["num_labels"] = 10 if "num_labels" not in kwargs else kwargs["num_labels"]
@@ -198,7 +205,8 @@ def add_head(self, model, name, head_type="seq2seq_lm", **kwargs):

def get_input_samples(self, shape=None, config=None, **kwargs):
shape = shape or self.input_shape
in_data = {"input_features": self.build_rand_tensor(shape, dtype=torch.float)}
in_data = super().get_input_samples(shape, config=config, **kwargs)
in_data["input_features"] = self.build_rand_tensor(shape, dtype=torch.float)

# Add decoder input ids for models with a decoder
if config and config.is_encoder_decoder:
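
Taken together, the base class now seeds the input dict (adding task_ids whenever n_tasks is passed) and each modality-specific get_input_samples merges its own tensors into that dict via super(). A minimal, runnable sketch of the pattern, using stand-in classes and made-up shapes rather than the real test base classes:

import torch

class BaseInputs:
    # Stand-in for AbstractAdapterTestBase.get_input_samples after this commit.
    def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
        in_data = {}
        if "n_tasks" in kwargs:
            # one task id per sample in the batch
            in_data["task_ids"] = torch.randint(0, kwargs["n_tasks"], (shape[0],))
        return in_data

class TextInputs(BaseInputs):
    # Stand-in for TextAdapterTestBase.get_input_samples: start from the base dict,
    # then add the modality-specific tensors.
    def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs):
        in_data = super().get_input_samples(shape, vocab_size, config, **kwargs)
        in_data["input_ids"] = torch.randint(0, vocab_size, shape)
        return in_data

batch = TextInputs().get_input_samples(shape=(3, 16), n_tasks=4)
print(sorted(batch))  # ['input_ids', 'task_ids']

Calls that never pass n_tasks get an empty base dict back, so existing tests keep producing the same batches as before.
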
10 changes: 5 additions & 5 deletions tests/test_methods/method_test_impl/base.py
@@ -157,16 +157,16 @@ def run_get_test(self, model, adapter_config, num_expected_modules):

model.delete_adapter("first")

def run_forward_test(self, model, adapter_config, dtype=torch.float32, adapter_setup=None):
def run_forward_test(self, model, adapter_config, dtype=torch.float32, **kwargs):
model.eval()

name = adapter_config.__class__.__name__
adapter_setup = adapter_setup or name
adapter_setup = kwargs.get("adapter_setup") or name
if name not in model.adapters_config:
model.add_adapter(name, config=adapter_config)
model.to(torch_device).to(dtype)

input_data = self.get_input_samples(config=model.config, dtype=dtype)
input_data = self.get_input_samples(config=model.config, dtype=dtype, **kwargs)

# pass 1: set adapter via property
model.set_active_adapters(adapter_setup)
@@ -192,7 +192,7 @@ def run_forward_test(self, model, adapter_config, dtype=torch.float32, adapter_s
model.set_active_adapters(None)
model.delete_adapter(name)

def run_load_test(self, adapter_config):
def run_load_test(self, adapter_config, **kwargs):
model1, model2 = create_twin_models(self.model_class, self.config)

name = "dummy_adapter"
@@ -221,7 +221,7 @@ def run_load_test(self, adapter_config):
self.assertTrue(name in model2.adapters_config)

# check equal output
input_data = self.get_input_samples(config=model1.config)
input_data = self.get_input_samples(config=model1.config, **kwargs)
model1.to(torch_device)
model2.to(torch_device)
output1 = model1(**input_data)
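
Because run_forward_test and run_load_test now take **kwargs and forward them to get_input_samples, callers can opt into task ids without another signature change. A small runnable illustration of that pass-through, with stand-in functions instead of the real test mixin (the adapter name and shapes are made up):

import torch

def get_input_samples(config=None, dtype=torch.float32, **kwargs):
    # Stand-in batch builder mirroring the base class behaviour above.
    in_data = {}
    if "n_tasks" in kwargs:
        in_data["task_ids"] = torch.randint(0, kwargs["n_tasks"], (2,))
    return in_data

def run_forward_test(adapter_config_name, dtype=torch.float32, **kwargs):
    # adapter_setup is now read from kwargs instead of being a named parameter,
    # and the remaining kwargs (e.g. n_tasks) reach the batch builder unchanged.
    adapter_setup = kwargs.get("adapter_setup") or adapter_config_name
    input_data = get_input_samples(config=None, dtype=dtype, **kwargs)
    return adapter_setup, input_data

setup, batch = run_forward_test("LoRAConfig", n_tasks=3)
print(setup, sorted(batch))  # LoRAConfig ['task_ids']
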
3 changes: 2 additions & 1 deletion tests/test_methods/test_on_clip/test_model.py
@@ -36,6 +36,7 @@ class CLIPAdapterTestBase(TextAdapterTestBase):
def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **kwargs):
# text inputs
shape = self.default_text_input_samples_shape
in_data = super().get_input_samples(shape, vocab_size, config, **kwargs)
total_dims = 1
for dim in shape:
total_dims *= dim
@@ -47,7 +48,7 @@ def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **k
if config and config.eos_token_id is not None and config.eos_token_id < vocab_size:
input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1)
input_ids[:, -1] = config.eos_token_id
in_data = {"input_ids": input_ids}
in_data["input_ids"] = input_ids

# vision inputs
shape = self.default_vision_input_samples_shape
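
For CLIP, the batch built by this combined get_input_samples carries both modalities plus the optional task ids: input_ids from the text path, pixel_values from the vision path (CLIP's standard image input key), and task_ids when n_tasks is given. A short illustration of the expected keys, with illustrative shapes rather than the test base's real default_text_input_samples_shape / default_vision_input_samples_shape values:

import torch

batch = {
    "task_ids": torch.randint(0, 3, (2,)),         # added by the base class for n_tasks=3
    "input_ids": torch.randint(0, 5000, (2, 16)),  # text inputs
    "pixel_values": torch.rand(2, 3, 30, 30),      # vision inputs
}
print(sorted(batch))  # ['input_ids', 'pixel_values', 'task_ids']
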
