diff --git a/configs/1B/H100_llama2_edu_no_feat.toml b/configs/1B/H100_llama2_edu_no_feat.toml deleted file mode 100644 index 0afd432f..00000000 --- a/configs/1B/H100_llama2_edu_no_feat.toml +++ /dev/null @@ -1,23 +0,0 @@ -name_model = "1B" -project = "debug_1B_zero_band" -type_model = "llama2" - -[train] -micro_bs = 4 -reshard_after_forward = true -attn_fn = "sdpa" -sequence_packing = false - -[data] -seq_length = 8192 -num_workers = 4 -dataset_name_or_paths = "/data/datasets/fineweb-edu" -reverse_data_files = true - -[optim] -batch_size = 256 -warmup_steps = 1000 -total_steps = 1_000_000_000_000 -sched_type = "wsd-sqrt" -lr = 2e-4 -z_loss = false diff --git a/scripts/export_dcp.py b/scripts/export_dcp.py index cf7460dd..82019e05 100644 --- a/scripts/export_dcp.py +++ b/scripts/export_dcp.py @@ -138,7 +138,6 @@ def main(config: ExportConfig): config.type_model, vocab_size=len(tokenizer), seq_length=config.data.seq_length, - attn_fn=config.train.attn_fn, ) # Convert ZeroBand config to HuggingFace config diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 7ab7cb8d..ae82469e 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,7 +1,6 @@ import os from typing import Literal import time -import warnings import psutil from pydantic import model_validator from multiprocessing.process import _children @@ -77,14 +76,6 @@ class TrainConfig(BaseConfig): memory_profiler: MemoryProfilerConfig | None = None sequence_packing: bool = True - attn_fn: Literal["flash", "sdpa"] | None = None - - @model_validator(mode="after") - def validate_attn_fn(self): - if self.attn_fn is not None: - warnings.warn("attn_fn argument is deprecated") - - return self class MonitorConfig(BaseConfig):