Commit

feat(learner): added function to set per_trial resources
mathysgrapotte committed Jan 24, 2025
1 parent 353a685 commit 71d2b94
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions src/stimulus/learner/raytune_learner_new.py
@@ -39,19 +39,7 @@ class TuneWrapper:
"ray_worker_seed": int,
"data_path": str,
"tune_run_path": str,
Args:
config_path (str): path to the configuration file
model_class (nn.Module): A PyTorch model.
data_path (str): path to the data to use.
encode_loader (EncoderLoader): An EncoderLoader object used by TorchDataset.
max_gpus (int): Maximum number of GPUs to use.
max_cpus (int): Maximum number of CPUs to use.
max_mem (int): Maximum memory to use.
max_object_store_mem (float, optional): Maximum object store memory. Defaults to None.
ray_result_dir (str, optional): Directory to store the results. Defaults to None.
tune_run_name (str, optional): Name to give to the ray tune runs. Defaults to None.
}.
"""

def __init__(
@@ -87,22 +75,20 @@ def __init__(
# Set all general seeds: python, numpy and pytorch
set_general_seeds(self.config["seed"])

self.config["model"]: nn.Module = model_class
self.config["EncoderLoader"]: EncoderLoader = encode_loader
self.config["model"] = model_class
self.config["EncoderLoader"] = encode_loader
# add the ray method for number generation to the config so it can be passed to the trainable class, which will in turn set per-worker seeds in a reproducible manner.
self.config["ray_worker_seed"] = ray.tune.randint(0, 1000)
self.config["data_path"]: str = check_path(data_path)
self.config["data_path"] = check_path(data_path)

# Set the tune run name and dir
if tune_run_name is None:
tune_run_name = "TuneModel_" + datetime.datetime.now(
tz=datetime.timezone.utc,
).strftime("%Y-%m-%d-%H-%M-%S")
if ray_result_dir is None:
ray_result_dir = os.environ.get(
"HOME",
) # If None, Ray puts it under HOME, so we do too
self.config["tune_run_path"]: str = os.path.join(ray_result_dir, tune_run_name)
ray_result_dir = os.environ.get("HOME")
self.config["tune_run_path"] = os.path.join(ray_result_dir, tune_run_name)

# Create the tune configuration
scheduler_params: dict = self.config["tune"]["scheduler"]
@@ -117,16 +103,16 @@ def __init__(

# Set the hardware resources
# TODO: if there's a check for these params, check it here during init
self.max_gpus: int = check_not_none(max_gpus, "max_gpus")
self.max_gpus: int = check_not_none(max_gpus, "max_gpus") # TODO: put check_not_none in this file to avoid abstractions
self.max_cpus: int = check_not_none(max_cpus, "max_cpus")
self.max_object_store_mem: int = max_object_store_mem # this is a special subset of the total usable memory that Ray needs for its internal work; by default it is set to 30% of the total usable memory
self.max_object_store_mem: int = max_object_store_mem # memory allocated to the Ray object store (on the head node), set to 30% of total memory by default.
self.max_mem: int = max_mem

# TODO: implement checkpointing
self.checkpoint_config: ray.train.CheckpointConfig = ray.train.CheckpointConfig(
checkpoint_at_end=True,
)
self.run_config: ray.train.RunConfig(
self.run_config: ray.train.RunConfig = ray.train.RunConfig(
name=tune_run_name,
storage_path=ray_result_dir,
checkpoint_config=self.checkpoint_config,
@@ -136,7 +122,8 @@

def tuner_initiazilation(self) -> ray.tune.Tuner:
"""Prepare the tuner with the configs."""
# in Ray 3.0.0 the following issue is fixed: sometimes Ray reports that it is already initialized, so in that case shut it down and start anew. TODO: update to Ray 3.0.0
# NOTE: updating to Ray 3.0.0 requires Python 3.12 (so would break support for Python 3.10 and 3.11)
if ray.is_initialized():
ray.shutdown()

@@ -152,3 +139,28 @@ def tuner_initiazilation(self) -> ray.tune.Tuner:
logging.info(f"CLUSTER resources\t->\t{cluster_ressources}")

self.gpu_per_trial = self._set_per_trial_ressources(cluster_ressources, "gpu")
self.cpu_per_trial = self._set_per_trial_ressources(cluster_ressources, "cpu")

logging.info(f"PER_TRIAL resources -> GPU: {self.gpu_per_trial} CPU: {self.cpu_per_trial}")

def _set_per_trial_ressources(self, cluster_ressources: dict, resource_type: str) -> float:
"""Set the per-trial resources.
Args:
cluster_ressources (dict): The available cluster resources.
resource_type (str): The type of resource to set (e.g. "cpu" or "gpu").
Returns:
float: The per-trial resource allocation.
Raises:
ValueError: If the requested resource per trial is greater than the available cluster resource.
"""
if self.config[f"{resource_type}_per_trial"] > cluster_ressources[resource_type]:
raise ValueError(f"The requested {resource_type} ({self.config[f'{resource_type}_per_trial']}) per trial is greater than the cluster {resource_type} ({cluster_ressources[resource_type]}).")
return self.config[f"{resource_type}_per_trial"]
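
For context, a minimal, hypothetical sketch (not part of this commit) of how the per-trial values returned by _set_per_trial_ressources could be attached to a trainable via ray.tune.with_resources; the trainable and parameter space below are placeholders, since the actual Tuner construction sits outside this hunk.

# Hypothetical usage sketch, not part of this commit: attaching computed
# per-trial resources to a trainable. `dummy_trainable` is a placeholder.
from ray import tune

def dummy_trainable(config: dict) -> dict:
    # A function trainable can simply return its final metrics as a dict.
    return {"loss": config["lr"]}

cpu_per_trial = 2.0  # e.g. the value returned by _set_per_trial_ressources for "cpu"
gpu_per_trial = 0.0  # e.g. the value returned by _set_per_trial_ressources for "gpu"

tuner = tune.Tuner(
    tune.with_resources(
        dummy_trainable,
        resources={"cpu": cpu_per_trial, "gpu": gpu_per_trial},
    ),
    param_space={"lr": tune.loguniform(1e-4, 1e-1)},
)
results = tuner.fit()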

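The cluster-wide limits (max_cpus, max_gpus, max_object_store_mem) are stored in __init__ above, and tuner_initiazilation later reads the cluster resources to derive per-trial allocations; the ray.init call itself is outside the shown hunks. Below is a runnable sketch of that cluster-level setup with placeholder values, assuming the limits are simply forwarded to ray.init.

# Hypothetical sketch, not shown in this diff: forwarding the cluster-wide limits
# to ray.init and reading back the registered resources. Values are placeholders.
import ray

max_cpus = 8
max_gpus = 1
max_object_store_mem = 2 * 1024**3  # bytes

if ray.is_initialized():
    ray.shutdown()

ray.init(
    num_cpus=max_cpus,
    num_gpus=max_gpus,
    object_store_memory=max_object_store_mem,
)

print(ray.cluster_resources())  # e.g. {"CPU": 8.0, "GPU": 1.0, ...}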