From 55715812b64d6c07a2cace5868dd9ce87015db2c Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Wed, 27 Nov 2024 18:39:04 +0200 Subject: [PATCH] Separate out framework def as dataclass --- amlb/benchmark.py | 11 +++----- amlb/frameworks/definitions.py | 50 +++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/amlb/benchmark.py b/amlb/benchmark.py index ff078a67a..90a010053 100644 --- a/amlb/benchmark.py +++ b/amlb/benchmark.py @@ -22,6 +22,7 @@ import pandas as pd +from .frameworks.definitions import load_framework_definition from .job import Job, JobError, SimpleJobRunner, MultiThreadingJobRunner from .datasets import DataLoader, DataSourceType from .data import DatasetType @@ -119,13 +120,8 @@ def __init__( Benchmark.data_loader = DataLoader(rconfig()) self._job_history = self._load_job_history(job_history=job_history) - - fsplits = framework_name.split(":", 1) - framework_name = fsplits[0] - tag = fsplits[1] if len(fsplits) > 1 else None - self.framework_def, self.framework_name = rget().framework_definition( - framework_name, tag - ) + framework = load_framework_definition(framework_name, rget()) + self.framework_def, self.framework_name = framework, framework.name log.debug("Using framework definition: %s.", self.framework_def) self.constraint_def, self.constraint_name = rget().constraint_definition( @@ -658,6 +654,7 @@ def handle_unfulfilled(message, on_auto="warn"): class BenchmarkTask: + def __init__(self, benchmark: Benchmark, task_def, fold): """ diff --git a/amlb/frameworks/definitions.py b/amlb/frameworks/definitions.py index e096876ac..bc3dcb446 100644 --- a/amlb/frameworks/definitions.py +++ b/amlb/frameworks/definitions.py @@ -1,11 +1,17 @@ +from __future__ import annotations + import copy import itertools import logging import os -from typing import List, Optional, Union +from dataclasses import dataclass, field +from typing import List, Optional, Union, TYPE_CHECKING from amlb.utils import Namespace, config_load, str_sanitize +if TYPE_CHECKING: + from amlb import Resources + log = logging.getLogger(__name__) default_tag = "_" @@ -233,3 +239,45 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace): "Removing framework %s as parent %s doesn't exist.", framework, parent ) del frameworks[framework] + + +@dataclass +class Image: + author: str + image: str + tag: str + + +@dataclass +class Framework: + name: str + abstract: bool + module: str + version: str + # Image + image: Image + # Setup + _setup_cmd: str | None + setup_cmd: str | None + setup_script: str | None + setup_env: dict = field(default_factory=dict) + setup_args: list[str] = field(default_factory=list) + # more optionals + params: dict = field(default_factory=dict) + refs: list = field(default_factory=list) + description: str | None = None + project: str | None = None + + def __post_init__(self): + if isinstance(self.image, dict): + self.image = Image(**self.image) + + +def load_framework_definition( + framework_name: str, configuration: "Resources" +) -> Framework: + tag = None + if ":" in framework_name: + framework_name, tag = framework_name.split(":", 1) + definition_ns, name = configuration.framework_definition(framework_name, tag) + return Framework(**Namespace.dict(definition_ns))