Skip to content

Commit

Permalink
🧪 Add num_samples parameter to test_trainer.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
arxyzan committed Feb 9, 2024
1 parent 2ecef57 commit bd24ac6
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/test_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,6 +25,7 @@
"text_classification": {
"dataset": {
"path": "hezarai/sentiment-dksf",
"num_samples": 4,
"config": {
"tokenizer_path": "hezarai/bert-base-fa",
}
Expand All @@ -40,6 +41,7 @@
"sequence_labeling": {
"dataset": {
"path": "hezarai/lscp-pos-500k",
"num_samples": 4,
"config": {
"tokenizer_path": "hezarai/bert-base-fa",
}
Expand All @@ -55,6 +57,7 @@
"text_summarization": {
"dataset": {
"path": "hezarai/xlsum-fa",
"num_samples": 4,
"config": {
"tokenizer_path": "hezarai/t5-base-fa",
"max_length": 32
Expand All @@ -71,6 +74,7 @@
"ocr": {
"dataset": {
"path": "hezarai/persian-license-plate-v1",
"num_samples": 4,
"config": {
"max_length": 8,
"reverse_digits": True,
Expand All @@ -87,6 +91,7 @@
"image-captioning": {
"dataset": {
"path": "hezarai/flickr30k-fa",
"num_samples": 2,
"config": {
"max_length": 16,
"tokenizer_path": "hezarai/vit-roberta-fa-base"
Expand All @@ -104,6 +109,7 @@
"speech-recognition": {
"dataset": {
"path": "hezarai/common-voice-13-fa",
"num_samples": 2,
"config": {
"labels_max_length": 16,
"tokenizer_path": "hezarai/whisper-small",
Expand All @@ -117,7 +123,7 @@
"task": "speech_recognition",
"mixed_precision": "fp16",
"metrics": ["wer", "cer"]
}
},
}
}

Expand All @@ -132,10 +138,11 @@
@pytest.mark.parametrize("task", tasks_setups.keys())
def test_trainer(task):
setup = tasks_setups[task]
num_samples = setup["dataset"]["num_samples"]

# Datasets
train_dataset = Dataset.load(setup["dataset"]["path"], split="train[:4]", **setup["dataset"]["config"])
eval_dataset = Dataset.load(setup["dataset"]["path"], split="test[:4]", **setup["dataset"]["config"])
train_dataset = Dataset.load(setup["dataset"]["path"], split=f"train[:{num_samples}]", **setup["dataset"]["config"])
eval_dataset = Dataset.load(setup["dataset"]["path"], split=f"test[:{num_samples}]", **setup["dataset"]["config"])

# Model & Preprocessor
model_config = ModelConfig.load(setup["model"]["path"])
Expand Down

0 comments on commit bd24ac6

Please sign in to comment.