Skip to content

Commit

Permalink
Separate out framework def as dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Dec 8, 2024
1 parent d493e92 commit 5571581
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
11 changes: 4 additions & 7 deletions amlb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pandas as pd

from .frameworks.definitions import load_framework_definition
from .job import Job, JobError, SimpleJobRunner, MultiThreadingJobRunner
from .datasets import DataLoader, DataSourceType
from .data import DatasetType
Expand Down Expand Up @@ -119,13 +120,8 @@ def __init__(
Benchmark.data_loader = DataLoader(rconfig())

self._job_history = self._load_job_history(job_history=job_history)

fsplits = framework_name.split(":", 1)
framework_name = fsplits[0]
tag = fsplits[1] if len(fsplits) > 1 else None
self.framework_def, self.framework_name = rget().framework_definition(
framework_name, tag
)
framework = load_framework_definition(framework_name, rget())
self.framework_def, self.framework_name = framework, framework.name
log.debug("Using framework definition: %s.", self.framework_def)

self.constraint_def, self.constraint_name = rget().constraint_definition(
Expand Down Expand Up @@ -658,6 +654,7 @@ def handle_unfulfilled(message, on_auto="warn"):


class BenchmarkTask:

def __init__(self, benchmark: Benchmark, task_def, fold):
"""
Expand Down
50 changes: 49 additions & 1 deletion amlb/frameworks/definitions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import copy
import itertools
import logging
import os
from typing import List, Optional, Union
from dataclasses import dataclass, field
from typing import List, Optional, Union, TYPE_CHECKING

from amlb.utils import Namespace, config_load, str_sanitize

if TYPE_CHECKING:
from amlb import Resources

log = logging.getLogger(__name__)

default_tag = "_"
Expand Down Expand Up @@ -233,3 +239,45 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace):
"Removing framework %s as parent %s doesn't exist.", framework, parent
)
del frameworks[framework]


@dataclass
class Image:
author: str
image: str
tag: str


@dataclass
class Framework:
name: str
abstract: bool
module: str
version: str
# Image
image: Image
# Setup
_setup_cmd: str | None
setup_cmd: str | None
setup_script: str | None
setup_env: dict = field(default_factory=dict)
setup_args: list[str] = field(default_factory=list)
# more optionals
params: dict = field(default_factory=dict)
refs: list = field(default_factory=list)
description: str | None = None
project: str | None = None

def __post_init__(self):
if isinstance(self.image, dict):
self.image = Image(**self.image)


def load_framework_definition(
framework_name: str, configuration: "Resources"
) -> Framework:
tag = None
if ":" in framework_name:
framework_name, tag = framework_name.split(":", 1)
definition_ns, name = configuration.framework_definition(framework_name, tag)
return Framework(**Namespace.dict(definition_ns))

0 comments on commit 5571581

Please sign in to comment.