From 55715812b64d6c07a2cace5868dd9ce87015db2c Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Wed, 27 Nov 2024 18:39:04 +0200
Subject: [PATCH] Separate out framework def as dataclass
---
amlb/benchmark.py | 11 +++-----
amlb/frameworks/definitions.py | 50 +++++++++++++++++++++++++++++++++-
2 files changed, 53 insertions(+), 8 deletions(-)
diff --git a/amlb/benchmark.py b/amlb/benchmark.py
index ff078a67a..90a010053 100644
--- a/amlb/benchmark.py
+++ b/amlb/benchmark.py
@@ -22,6 +22,7 @@
import pandas as pd
+from .frameworks.definitions import load_framework_definition
from .job import Job, JobError, SimpleJobRunner, MultiThreadingJobRunner
from .datasets import DataLoader, DataSourceType
from .data import DatasetType
@@ -119,13 +120,8 @@ def __init__(
Benchmark.data_loader = DataLoader(rconfig())
self._job_history = self._load_job_history(job_history=job_history)
-
- fsplits = framework_name.split(":", 1)
- framework_name = fsplits[0]
- tag = fsplits[1] if len(fsplits) > 1 else None
- self.framework_def, self.framework_name = rget().framework_definition(
- framework_name, tag
- )
+ framework = load_framework_definition(framework_name, rget())
+ self.framework_def, self.framework_name = framework, framework.name
log.debug("Using framework definition: %s.", self.framework_def)
self.constraint_def, self.constraint_name = rget().constraint_definition(
@@ -658,6 +654,7 @@ def handle_unfulfilled(message, on_auto="warn"):
class BenchmarkTask:
+
def __init__(self, benchmark: Benchmark, task_def, fold):
"""
diff --git a/amlb/frameworks/definitions.py b/amlb/frameworks/definitions.py
index e096876ac..bc3dcb446 100644
--- a/amlb/frameworks/definitions.py
+++ b/amlb/frameworks/definitions.py
@@ -1,11 +1,17 @@
+from __future__ import annotations
+
import copy
import itertools
import logging
import os
-from typing import List, Optional, Union
+from dataclasses import dataclass, field
+from typing import List, Optional, Union, TYPE_CHECKING
from amlb.utils import Namespace, config_load, str_sanitize
+if TYPE_CHECKING:
+ from amlb import Resources
+
log = logging.getLogger(__name__)
default_tag = "_"
@@ -233,3 +239,45 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace):
"Removing framework %s as parent %s doesn't exist.", framework, parent
)
del frameworks[framework]
+
+
+@dataclass
+class Image:  # three required str fields; Framework.__post_init__ builds this via Image(**dict)
+ author: str  # presumably the owner/namespace of the image -- TODO confirm
+ image: str  # presumably the image name itself -- TODO confirm
+ tag: str  # presumably the image tag (version label) -- TODO confirm
+
+
+@dataclass
+class Framework:
+ """One framework definition as typed fields; keyword names match definition keys."""
+ name: str
+ abstract: bool
+ module: str
+ version: str
+ image: Image  # a dict of Image fields is also accepted; __post_init__ converts it
+ # Setup: the next three fields have no default, so they must always be supplied.
+ _setup_cmd: str | None  # NOTE(review): relation to `setup_cmd` is not visible here
+ setup_cmd: str | None
+ setup_script: str | None
+ setup_env: dict = field(default_factory=dict)
+ setup_args: list[str] = field(default_factory=list)
+ # Optional entries; default_factory gives every instance its own dict/list.
+ params: dict = field(default_factory=dict)
+ refs: list = field(default_factory=list)
+ description: str | None = None
+ project: str | None = None
+
+ def __post_init__(self) -> None:  # runs right after the generated __init__
+ if isinstance(self.image, dict):  # only a plain dict is converted; an Image passes through
+ self.image = Image(**self.image)  # dict keys must be exactly Image's field names
+
+
+def load_framework_definition(
+    framework_name: str, configuration: "Resources"
+) -> Framework:
+    """Look up `framework_name` (`name` or `name:tag`) and return its `Framework`."""
+    # Split on the first ":" only; without any ":" the tag passed on is None.
+    name, sep, tag = framework_name.partition(":")
+    definition_ns, _ = configuration.framework_definition(name, tag if sep else None)
+    return Framework(**Namespace.dict(definition_ns))