Skip to content

Commit

Permalink
small fixes, improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
djsaunde committed Jan 24, 2025
1 parent ef38f10 commit 0e9bfa6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 41 deletions.
4 changes: 2 additions & 2 deletions src/axolotl/cli/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
LOG = logging.getLogger(__name__)


def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
Expand All @@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

evaluate(cfg=cfg, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, dataset_meta=dataset_meta)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import sys
from pathlib import Path
from typing import Dict, Optional
from typing import Optional

import torch
from accelerate.logging import get_logger
Expand All @@ -26,7 +26,7 @@

def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
) -> Optional[dict[str, float]]:
"""Helper function to evaluate a single dataset safely.
Args:
Expand Down Expand Up @@ -61,7 +61,7 @@ def evaluate_dataset(
return metrics


def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
"""
Evaluate a model on training and validation datasets
Expand Down
46 changes: 10 additions & 36 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,45 +709,19 @@ def set_attention_config(self) -> None:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass

if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
"eager"
)

if self.cfg.low_cpu_mem_usage:
Expand Down

0 comments on commit 0e9bfa6

Please sign in to comment.