Skip to content

Commit

Permalink
Minor fixes of finetuning examples
Browse files Browse the repository at this point in the history
  • Loading branch information
kasnerz committed Jun 1, 2023
1 parent fc6b0a4 commit 33755f8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 30 deletions.
13 changes: 4 additions & 9 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

You can load the datasets in a unified way for training the models.


## Requirements
Please first install the additional requirements listed in `./requirements.txt`.

## Fine-tuning using `transformers` library
`finetuning_transformers.py` is an example of fine-tuning a seq2seq model and running inference using the `transformers` library. It can be run from the command line, e.g.:
```
Expand All @@ -23,12 +27,3 @@ The example `multitasking.py` is almost equivalent to `finetuning_transformers.p
Dataset-specific task description is prepended to each input item before training. <br>
In this example, custom linearization functions are implemented for E2E and WebNLG datasets.


## Requirements

Requirements to run the examples (in addition to `tabgenie`):
* `numpy`
* `evaluate`
* `transformers==4.25.1`
* `torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113`
* `sacrebleu`
14 changes: 3 additions & 11 deletions examples/finetuning_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@
from tabgenie import load_dataset


# extra requirements:
# numpy
# evaluate
# transformers==4.25.1
# torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# sacrebleu


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
Expand All @@ -36,7 +28,7 @@

MAX_LENGTH = 512
LABEL_PAD_TOKEN_ID = -100
PATIENCE = 5
PATIENCE = 3

BLEU_METRIC = evaluate.load("sacrebleu")

Expand Down Expand Up @@ -77,8 +69,8 @@ def compute_bleu(eval_preds, tokenizer):
@click.command()
@click.option("--dataset", "-d", required=True, type=str, help="Dataset to train on")
@click.option("--base-model", "-m", default="t5-small", type=str, help="Base model to finetune")
@click.option("--epochs", "-e", default=30, type=int, help="Maximum number of epochs")
@click.option("--batch-size", "-b", default=16, type=int, help="Path to the output directory")
@click.option("--epochs", "-e", default=10, type=int, help="Maximum number of epochs")
@click.option("--batch-size", "-b", default=16, type=int, help="Number of examples in a batch")
@click.option("--ckpt-dir", "-c", default=os.path.join(ROOT_DIR, "checkpoints"), type=str, help="Directory to store checkpoints")
@click.option("--output-dir", "-o", default=os.path.join(ROOT_DIR, "models"), type=str, help="Directory to store models and their outputs")
def main(dataset, base_model, epochs, batch_size, ckpt_dir, output_dir):
Expand Down
12 changes: 2 additions & 10 deletions examples/multitasking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@
from tabgenie import load_dataset


# extra requirements:
# numpy
# evaluate
# transformers==4.25.1
# torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# sacrebleu


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
Expand All @@ -37,7 +29,7 @@

MAX_LENGTH = 512
LABEL_PAD_TOKEN_ID = -100
PATIENCE = 5
PATIENCE = 3

BLEU_METRIC = evaluate.load("sacrebleu")

Expand Down Expand Up @@ -92,7 +84,7 @@ def compute_bleu(eval_preds, tokenizer):
@click.command()
@click.option("--datasets", "-d", required=True, type=str, help="Datasets to train on")
@click.option("--base-model", "-m", default="t5-small", type=str, help="Base model to finetune")
@click.option("--epochs", "-e", default=30, type=int, help="Maximum number of epochs")
@click.option("--epochs", "-e", default=10, type=int, help="Maximum number of epochs")
@click.option("--batch-size", "-b", default=16, type=int, help="Number of examples in a batch")
@click.option("--ckpt-dir", "-c", default=os.path.join(ROOT_DIR, "checkpoints"), type=str, help="Directory to store checkpoints")
@click.option("--output-dir", "-o", default=os.path.join(ROOT_DIR, "models"), type=str, help="Directory to store models and their outputs")
Expand Down

0 comments on commit 33755f8

Please sign in to comment.