You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Originally posted by san-deep-reddy April 14, 2024
I have tried many models and adapter types, custom head types and configs but I always end up with the same error -
File "testing\adapter2_testing.py", line 173, in <module>
trainer.train()
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer.py", line 1537, in train
return inner_training_loop(
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer.py", line 1772, in _inner_training_loop
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer_callback.py", line 370, in on_train_begin
return self.call_event("on_train_begin", args, state, control)
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\trainer_callback.py", line 414, in call_event
result = getattr(callback, event)(
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\integrations\integration_utils.py", line 635, in on_train_begin
model_config_json = model.config.to_json_string()
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\site-packages\transformers\configuration_utils.py", line 951, in to_json_string
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\__init__.py", line 234, in dumps
return cls(
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 201, in encode
chunks = list(chunks)
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 431, in _iterencode
yield from _iterencode_dict(o, _current_indent_level)
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 405, in _iterencode_dict
yield from chunks
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 405, in _iterencode_dict
yield from chunks
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 438, in _iterencode
o = _default(o)
File "C:\Users\dlais\AppData\Local\Programs\Python\Python39\lib\json\encoder.py", line 179, in default
raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type type is not JSON serializable
Here is my simple code
import adapters, torch
import torch.nn as nn
import numpy as np
from adapters.heads import PredictionHead
from adapters import AutoAdapterModel, AdapterTrainer, SeqBnConfig
from transformers import AutoTokenizer, TrainingArguments, EvalPrediction
from datasets import load_dataset
# Checkpoint identifier; the same value is used for both loads below so the
# model and tokenizer come from the same source.
model_path = "bert-base-uncased"
# Module-level model object; later statements add an adapter and a custom
# head to it and hand it to the trainer.
model = AutoAdapterModel.from_pretrained(model_path)
# Module-level tokenizer; read by encode_batch().
tokenizer = AutoTokenizer.from_pretrained(model_path)
class CustomHead(PredictionHead):
    """Custom prediction head; same as ClassificationHead.

    The head keeps the keyword arguments it was created with as its
    ``config`` dict and reads ``config["num_labels"]`` when computing the
    loss.
    """

    def __init__(self, model, head_name, **config):
        """Store the head configuration and build the head for ``model``.

        Args:
            model: Model object forwarded to ``self.build(model=...)``.
            head_name: Name passed to the ``PredictionHead`` constructor.
            **config: Head options, kept verbatim as ``self.config``.
        """
        super().__init__(head_name)
        self.config = config
        self.build(model=model)

    def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
        """Compute logits and, when labels are supplied, a cross-entropy loss.

        Args:
            outputs: Upstream model outputs; everything after the first
                element is appended unchanged to the returned tuple.
            cls_output: Input to the head. When ``None`` it is obtained via
                ``self._get_cls_output(outputs, **kwargs)``.
            attention_mask: Accepted for interface compatibility; not used.
            return_dict: Accepted for interface compatibility; not used, a
                tuple is always returned.
            **kwargs: May contain ``labels``; it is removed from ``kwargs``
                after ``cls_output`` has been resolved.

        Returns:
            ``(logits, *outputs[1:])`` when no labels are given, otherwise
            ``(loss, logits, *outputs[1:])``.
        """
        if cls_output is None:
            cls_output = self._get_cls_output(outputs, **kwargs)
        logits = super().forward(cls_output)

        labels = kwargs.pop("labels", None)
        loss = None
        # Only compute the loss when labels exist: calling labels.view() on
        # None would raise AttributeError on label-free (inference) calls.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config["num_labels"]), labels.view(-1))

        outputs = (logits,) + outputs[1:]
        if labels is not None:
            outputs = (loss,) + outputs
        return outputs
# Adapter configuration, then register it on the model under the name "adapter2".
seq_config = SeqBnConfig(reduction_factor=16, use_gating=True)
model.add_adapter("adapter2", config=seq_config)
# Drop the head named 'default' before attaching the custom one.
# NOTE(review): presumably that head is created by from_pretrained - confirm.
model.delete_head('default')
# Make the CustomHead class available under the head type "my_custom_head".
# NOTE(review): the traceback reported with this script ends in
# "Object of type type is not JSON serializable" while serializing
# model.config, and the follow-up text says the custom heads dictionary lives
# in the config. Registering the class here is presumably what places a
# non-serializable `type` object into the config - confirm against the library.
model.register_custom_head("my_custom_head", CustomHead)
# Options passed as keyword arguments to the head; CustomHead stores them as
# self.config and reads "num_labels" in forward().
config = {"num_labels": 2, "layers": 1, "activation_function": "tanh"}
# Instantiate a head of type "my_custom_head" named "adapter2".
model.add_custom_head("my_custom_head", "adapter2", **config)
model.train_adapter(['adapter2'])
# NOTE(review): the author marks the next call as redundant, presumably
# because train_adapter() already activates the adapter - confirm.
model.set_active_adapters(['adapter2']) #This line is redundant
def encode_batch(batch):
    """Tokenize the ``"text"`` entries of ``batch`` with the module-level tokenizer.

    Inputs are padded, truncated to at most 512 tokens, and returned as
    PyTorch tensors (``return_tensors="pt"``).
    """
    texts = batch["text"]
    tokenizer_options = {
        "max_length": 512,
        "padding": True,
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer(texts, **tokenizer_options)
    return encoded
# Load the dataset; the result is indexed by split name further down
# ("train" and "validation").
dataset = load_dataset("rotten_tomatoes")
# Tokenize batch-wise with encode_batch().
dataset = dataset.map(encode_batch, batched=True)
# Rename the target column to "labels"; CustomHead.forward() reads its
# targets from the "labels" keyword argument.
dataset = dataset.rename_column(original_column_name="label", new_column_name="labels")
# Expose only the listed columns, formatted as torch tensors.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
# Training hyperparameters handed to the trainer below.
# NOTE(review): the reported traceback fails inside an integration callback's
# on_train_begin while it serializes model.config; presumably that callback is
# enabled by the report_to setting here - confirm.
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
learning_rate=2e-5,
evaluation_strategy="epoch",
report_to="tensorboard"
)
def compute_accuracy(p: EvalPrediction):
    """Return ``{"acc": ...}``: the fraction of argmax predictions equal to the labels.

    The predicted class is the argmax of ``p.predictions`` along axis 1 and
    is compared element-wise with ``p.label_ids``.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    matches = predicted_classes == p.label_ids
    return {"acc": matches.mean()}
# Wire the model, training arguments, train/validation splits and the
# accuracy metric into the trainer.
trainer = AdapterTrainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
compute_metrics=compute_accuracy,
)
# Start training; this is the call at which the reported traceback originates.
trainer.train()
model.add_classification_head works fine though.
The text was updated successfully, but these errors were encountered:
To make the model_config serializable and prevent the error mentioned in
#680, move the custom_heads dictionary out of the config and make it a
separate attribute of the model class.
To make the model_config serializable and prevent the error mentioned in
adapter-hub#680, move the custom_heads dictionary out of the config and make it a
separate attribute of the model class.
Discussed in #679
Originally posted by san-deep-reddy April 14, 2024
I have tried many models and adapter types, custom head types and configs but I always end up with the same error -
Here is my simple code
model.add_classification_head works fine though.
The text was updated successfully, but these errors were encountered: