diff --git a/src/stimulus/learner/raytune_learner.py b/src/stimulus/learner/raytune_learner.py index e147641..d689661 100644 --- a/src/stimulus/learner/raytune_learner.py +++ b/src/stimulus/learner/raytune_learner.py @@ -42,6 +42,7 @@ def __init__( tune_run_name: Optional[str] = None, *, debug: bool = False, + autoscaler: bool = False, ) -> None: """Initialize the TuneWrapper with the paths to the config, model, and data.""" self.config = config.model_dump() @@ -94,17 +95,27 @@ def __init__( self.config["encoder_loader"] = encoder_loader self.config["ray_worker_seed"] = tune.randint(0, 1000) - self.tuner = self.tuner_initialization() + self.gpu_per_trial = config.tune.tune_params.gpu_per_trial + self.cpu_per_trial = config.tune.tune_params.cpu_per_trial - def tuner_initialization(self) -> tune.Tuner: + self.tuner = self.tuner_initialization(autoscaler=autoscaler) + + def tuner_initialization(self, *, autoscaler: bool = False) -> tune.Tuner: """Prepare the tuner with the configs.""" # Get available resources from Ray cluster cluster_res = cluster_resources() logging.info(f"CLUSTER resources -> {cluster_res}") # Check per-trial resources - self.gpu_per_trial = self._check_per_trial_resources("gpu_per_trial", cluster_res, "GPU") - self.cpu_per_trial = self._check_per_trial_resources("cpu_per_trial", cluster_res, "CPU") + if self.gpu_per_trial > cluster_res["GPU"] and not autoscaler: + raise ValueError( + "GPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", + ) + + if self.cpu_per_trial > cluster_res["CPU"] and not autoscaler: + raise ValueError( + "CPU per trial is more than what is available in the cluster, set autoscaler to True to allow for autoscaler to be used.", + ) logging.info(f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}") @@ -122,70 +133,6 @@ def tune(self) -> None: """Run the tuning process.""" self.tuner.fit() - def _check_per_trial_resources( - self, - resource_key: str, - cluster_max_resources: dict[str, float], - resource_type: str, - ) -> float: - """Check requested per-trial resources against available cluster resources. - - This function validates and adjusts the requested per-trial resource allocation based on the - available cluster resources. It handles three cases: - 1. Resource request is within cluster limits - uses requested amount - 2. Resource request exceeds cluster limits - warns and uses maximum available - 3. No resource request specified - calculates reasonable default based on cluster capacity - - Args: - resource_key: Key in config for the resource (e.g. "gpu_per_trial") - cluster_max_resources: Dictionary of maximum available cluster resources - resource_type: Type of resource being checked ("GPU" or "CPU") - - Returns: - float: Number of resources to allocate per trial - - Note: - For GPUs, returns 0.0 if no GPUs are available in the cluster. - """ - if resource_type == "GPU" and resource_type not in cluster_max_resources: - return 0.0 - - per_trial_resource: float = 0.0 - - # Check if resource is specified in config and within limits - if ( - resource_key in self.config["tune"] - and self.config["tune"][resource_key] <= cluster_max_resources[resource_type] - ): - per_trial_resource = float(self.config["tune"][resource_key]) - - # Warn if requested resources exceed available - elif ( - resource_key in self.config["tune"] - and self.config["tune"][resource_key] > cluster_max_resources[resource_type] - ): - logging.warning( - f"\n\n#### WARNING - {resource_type} per trial are more than what is available. " - f"{resource_type} per trial: {self.config['tune'][resource_key]} " - f"available: {cluster_max_resources[resource_type]} " - "overwriting value to max available", - ) - per_trial_resource = float(cluster_max_resources[resource_type]) - - # Set default if not specified - elif resource_key not in self.config["tune"]: - if cluster_max_resources[resource_type] == 0.0: - per_trial_resource = 0.0 - else: - per_trial_resource = float( - max( - 1, - (cluster_max_resources[resource_type] // self.config["tune"]["tune_params"]["num_samples"]), - ), - ) - - return per_trial_resource - class TuneModel(Trainable): """Trainable model class for Ray Tune."""