Merge pull request #69 from mathysgrapotte/learner-tuner-init-refactor
refactor(raytune_learner): removed check_ressources function
mathysgrapotte authored Jan 30, 2025
2 parents 324d40d + 088f85c commit 3acd852
src/stimulus/learner/raytune_learner.py (83 changes: 15 additions & 68 deletions)
@@ -42,6 +42,7 @@ def __init__(
         tune_run_name: Optional[str] = None,
         *,
         debug: bool = False,
+        autoscaler: bool = False,
     ) -> None:
         """Initialize the TuneWrapper with the paths to the config, model, and data."""
         self.config = config.model_dump()
@@ -94,17 +95,27 @@ def __init__(
         self.config["encoder_loader"] = encoder_loader
         self.config["ray_worker_seed"] = tune.randint(0, 1000)
 
-        self.tuner = self.tuner_initialization()
+        self.gpu_per_trial = config.tune.tune_params.gpu_per_trial
+        self.cpu_per_trial = config.tune.tune_params.cpu_per_trial
 
-    def tuner_initialization(self) -> tune.Tuner:
+        self.tuner = self.tuner_initialization(autoscaler=autoscaler)
+
+    def tuner_initialization(self, *, autoscaler: bool = False) -> tune.Tuner:
         """Prepare the tuner with the configs."""
         # Get available resources from Ray cluster
         cluster_res = cluster_resources()
         logging.info(f"CLUSTER resources -> {cluster_res}")
 
         # Check per-trial resources
-        self.gpu_per_trial = self._check_per_trial_resources("gpu_per_trial", cluster_res, "GPU")
-        self.cpu_per_trial = self._check_per_trial_resources("cpu_per_trial", cluster_res, "CPU")
+        if self.gpu_per_trial > cluster_res["GPU"] and not autoscaler:
+            raise ValueError(
+                "GPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.",
+            )
+
+        if self.cpu_per_trial > cluster_res["CPU"] and not autoscaler:
+            raise ValueError(
+                "CPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.",
+            )
 
         logging.info(f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}")
 
@@ -122,70 +133,6 @@ def tune(self) -> None:
         """Run the tuning process."""
         self.tuner.fit()
 
-    def _check_per_trial_resources(
-        self,
-        resource_key: str,
-        cluster_max_resources: dict[str, float],
-        resource_type: str,
-    ) -> float:
-        """Check requested per-trial resources against available cluster resources.
-
-        This function validates and adjusts the requested per-trial resource allocation based on the
-        available cluster resources. It handles three cases:
-        1. Resource request is within cluster limits - uses requested amount
-        2. Resource request exceeds cluster limits - warns and uses maximum available
-        3. No resource request specified - calculates reasonable default based on cluster capacity
-
-        Args:
-            resource_key: Key in config for the resource (e.g. "gpu_per_trial")
-            cluster_max_resources: Dictionary of maximum available cluster resources
-            resource_type: Type of resource being checked ("GPU" or "CPU")
-
-        Returns:
-            float: Number of resources to allocate per trial
-
-        Note:
-            For GPUs, returns 0.0 if no GPUs are available in the cluster.
-        """
-        if resource_type == "GPU" and resource_type not in cluster_max_resources:
-            return 0.0
-
-        per_trial_resource: float = 0.0
-
-        # Check if resource is specified in config and within limits
-        if (
-            resource_key in self.config["tune"]
-            and self.config["tune"][resource_key] <= cluster_max_resources[resource_type]
-        ):
-            per_trial_resource = float(self.config["tune"][resource_key])
-
-        # Warn if requested resources exceed available
-        elif (
-            resource_key in self.config["tune"]
-            and self.config["tune"][resource_key] > cluster_max_resources[resource_type]
-        ):
-            logging.warning(
-                f"\n\n#### WARNING - {resource_type} per trial are more than what is available. "
-                f"{resource_type} per trial: {self.config['tune'][resource_key]} "
-                f"available: {cluster_max_resources[resource_type]} "
-                "overwriting value to max available",
-            )
-            per_trial_resource = float(cluster_max_resources[resource_type])
-
-        # Set default if not specified
-        elif resource_key not in self.config["tune"]:
-            if cluster_max_resources[resource_type] == 0.0:
-                per_trial_resource = 0.0
-            else:
-                per_trial_resource = float(
-                    max(
-                        1,
-                        (cluster_max_resources[resource_type] // self.config["tune"]["tune_params"]["num_samples"]),
-                    ),
-                )
-
-        return per_trial_resource
-
 
 class TuneModel(Trainable):
     """Trainable model class for Ray Tune."""
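
Summary of the change: the removed _check_per_trial_resources helper silently clamped oversized requests to the cluster maximum and invented defaults when nothing was configured. After this commit, TuneWrapper reads gpu_per_trial and cpu_per_trial directly from config.tune.tune_params and fails fast when a trial requests more than the connected cluster reports, unless the caller passes autoscaler=True to signal that an autoscaler will add capacity.

A minimal, self-contained sketch of that validation pattern follows. It is an illustration, not code from the repository: the function name check_trial_resources is invented, and it uses cluster_res.get(..., 0.0) so a GPU-less cluster yields 0.0 instead of a KeyError, whereas the committed code indexes cluster_res["GPU"] directly.

import logging

import ray


def check_trial_resources(
    gpu_per_trial: float,
    cpu_per_trial: float,
    *,
    autoscaler: bool = False,
) -> None:
    """Fail fast if a trial asks for more than the cluster currently has."""
    # Snapshot of the resources the connected Ray cluster reports right now.
    cluster_res = ray.cluster_resources()
    logging.info("CLUSTER resources -> %s", cluster_res)

    # With autoscaler=True an oversized request is allowed through, on the
    # assumption that the autoscaler provisions capacity before trials run.
    if gpu_per_trial > cluster_res.get("GPU", 0.0) and not autoscaler:
        raise ValueError("GPU per trial exceeds available cluster GPUs; pass autoscaler=True to defer to autoscaling.")
    if cpu_per_trial > cluster_res.get("CPU", 0.0) and not autoscaler:
        raise ValueError("CPU per trial exceeds available cluster CPUs; pass autoscaler=True to defer to autoscaling.")


if __name__ == "__main__":
    ray.init()  # start or connect to a local Ray cluster
    check_trial_resources(gpu_per_trial=0.0, cpu_per_trial=1.0)

The trade-off behind the refactor: the old helper never stopped a run, so a misconfigured gpu_per_trial could quietly shrink to whatever happened to be available; the new guard surfaces the mismatch immediately and makes autoscaling an explicit opt-in.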
