From 3895d239dc69c82352c159cf543502f4ff078c63 Mon Sep 17 00:00:00 2001
From: wenkelf
Date: Thu, 12 Oct 2023 04:38:59 +0000
Subject: [PATCH] Reformatting...

---
 graphium/cli/finetune_utils.py                 | 38 ++++++++++---------
 graphium/cli/parameters.py                     |  9 +++--
 graphium/cli/train_finetune_test.py            |  6 +--
 graphium/config/_loader.py                     |  4 +-
 graphium/finetuning/finetuning.py              |  2 +-
 .../finetuning/finetuning_architecture.py      |  2 +-
 6 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/graphium/cli/finetune_utils.py b/graphium/cli/finetune_utils.py
index 847b991f4..5566e7961 100644
--- a/graphium/cli/finetune_utils.py
+++ b/graphium/cli/finetune_utils.py
@@ -135,36 +135,38 @@ def get_fingerprints_from_model(
     logger.info(f"Saving fingerprints to {path}")
     torch.save(fps, path)
 
 
-
-def get_tdc_task_specific(task: str, output: Literal['name', 'mode', 'last_activation']):
-    if output == 'last_activation':
-        config_arch_path="expts/hydra-configs/tasks/task_heads/admet.yaml"
-        with open(config_arch_path, 'r') as yaml_file:
+def get_tdc_task_specific(task: str, output: Literal["name", "mode", "last_activation"]):
+    if output == "last_activation":
+        config_arch_path = "expts/hydra-configs/tasks/task_heads/admet.yaml"
+        with open(config_arch_path, "r") as yaml_file:
             config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader)
 
-        return config_tdc_arch['architecture']['task_heads'][task]['last_activation']
+        return config_tdc_arch["architecture"]["task_heads"][task]["last_activation"]
 
     else:
-        config_metrics_path="expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
-        with open(config_metrics_path, 'r') as yaml_file:
+        config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
+        with open(config_metrics_path, "r") as yaml_file:
             config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader)
 
-        metric = config_tdc_task_metric['predictor']['metrics_on_progress_bar'][task][0]
+        metric = config_tdc_task_metric["predictor"]["metrics_on_progress_bar"][task][0]
 
         metric_mode_map = {
-            'mae': 'min',
-            'auroc': 'max',
-            'auprc': 'max',
-            'spearman': 'max',
+            "mae": "min",
+            "auroc": "max",
+            "auprc": "max",
+            "spearman": "max",
         }
 
-        if output == 'name':
+        if output == "name":
             return metric
-        elif output == 'mode':
+        elif output == "mode":
             return metric_mode_map[metric]
 
-OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output='name'))
-OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output='mode'))
-OmegaConf.register_new_resolver("get_last_activation", lambda x: get_tdc_task_specific(x, output='last_activation'))
+
+OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output="name"))
+OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output="mode"))
+OmegaConf.register_new_resolver(
+    "get_last_activation", lambda x: get_tdc_task_specific(x, output="last_activation")
+)
 OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))
diff --git a/graphium/cli/parameters.py b/graphium/cli/parameters.py
index e0b31ac91..389ac286c 100644
--- a/graphium/cli/parameters.py
+++ b/graphium/cli/parameters.py
@@ -21,6 +21,7 @@
 param_app = typer.Typer(help="Parameter counts.")
 app.add_typer(param_app, name="params")
 
+
 @param_app.command(name="infer", help="Infer parameter count.")
 def infer_parameter_count(overrides: List[str] = []) -> int:
     with initialize(version_base=None, config_path="../../expts/hydra-configs"):
@@ -52,8 +53,11 @@ def infer_parameter_count(overrides: List[str] = []) -> int:
 
     return num_params
 
+
 @param_app.command(name="balance", help="Balance parameter count.")
-def balance_parameter_count(overrides: List[str], target_param_count: int, max_rel_diff: float, rel_step: float, old_dim: int) -> None:
+def balance_parameter_count(
+    overrides: List[str], target_param_count: int, max_rel_diff: float, rel_step: float, old_dim: int
+) -> None:
     with initialize(version_base=None, config_path="../../expts/hydra-configs"):
         cfg = compose(
             config_name="main",
@@ -88,11 +92,10 @@ def balance_parameter_count(overrides: List[str], target_param_count: int, max_r
             logger.info(f"Hidden edge dim unchanged: {tmp_edge_dim}.")
         print(tmp_dim, tmp_edge_dim, rel_step, "true")
         return
-    
+
     # Reduce step size when overshooting
     if np.sign(old_dim - tmp_dim) != np.sign(tmp_dim - new_dim) and old_dim > 0:
         rel_step /= 2
         logger.info(f"Relative step changed: {2*rel_step} -> {rel_step}.")
 
     print(new_dim, new_edge_dim, rel_step, "false")
-
diff --git a/graphium/cli/train_finetune_test.py b/graphium/cli/train_finetune_test.py
index 4e7fe2961..5d9a29b97 100644
--- a/graphium/cli/train_finetune_test.py
+++ b/graphium/cli/train_finetune_test.py
@@ -159,11 +159,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     # When checkpoints are logged during training, we can, e.g., use the best or last checkpoint for testing
     test_ckpt_path = None
-    test_ckpt_name = cfg['trainer'].pop('test_from_checkpoint', None)
-    test_ckpt_dir = cfg['trainer']['model_checkpoint'].pop('dirpath', None)
+    test_ckpt_name = cfg["trainer"].pop("test_from_checkpoint", None)
+    test_ckpt_dir = cfg["trainer"]["model_checkpoint"].pop("dirpath", None)
     if test_ckpt_name is not None and test_ckpt_dir is not None:
         test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name)
-    
+
     # Run the model testing
     with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
         trainer.test(model=predictor, datamodule=datamodule, ckpt_path=test_ckpt_path)
diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py
index 000cd59f1..48d7e9078 100644
--- a/graphium/config/_loader.py
+++ b/graphium/config/_loader.py
@@ -189,7 +189,7 @@ def load_architecture(
         architecture: The datamodule used to process and load the data
     """
 
-    if isinstance(config, dict) and 'finetuning' not in config:
+    if isinstance(config, dict) and "finetuning" not in config:
         config = omegaconf.OmegaConf.create(config)
 
     cfg_arch = config["architecture"]
@@ -249,7 +249,7 @@ def load_architecture(
         gnn_kwargs.setdefault("in_dim", edge_in_dim)
 
     # Set the parameters for the full network
-    if 'finetuning' not in config:
+    if "finetuning" not in config:
         task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs)
 
     # Set all the input arguments for the model
diff --git a/graphium/finetuning/finetuning.py b/graphium/finetuning/finetuning.py
index 3829db9e4..63927850e 100644
--- a/graphium/finetuning/finetuning.py
+++ b/graphium/finetuning/finetuning.py
@@ -66,7 +66,7 @@ def freeze_module(self, pl_module, module_name: str, module_map: Dict[str, Union
             module_map: Dictionary mapping from module_name to corresponding module(s)
         """
         modules = module_map[module_name]
-        
+
         if module_name == "pe_encoders":
             for param in pl_module.model.pretrained_model.net.encoder_manager.parameters():
                 param.requires_grad = False
diff --git a/graphium/finetuning/finetuning_architecture.py b/graphium/finetuning/finetuning_architecture.py
index 55718fed0..a4c2b34c1 100644
--- a/graphium/finetuning/finetuning_architecture.py
+++ b/graphium/finetuning/finetuning_architecture.py
@@ -199,7 +199,7 @@ def __init__(
         super().__init__()
 
         # Load pretrained model
-        pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device='cpu').model
+        pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device="cpu").model
         pretrained_model.create_module_map()
 
         # Initialize new model with architecture after desired modifications to architecture.
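
Usage sketch (illustrative, not part of the patch): the get_tdc_task_specific helper and the OmegaConf
resolvers touched in graphium/cli/finetune_utils.py can be exercised as sketched below. A minimal sketch,
assuming it is run from the repository root (the YAML paths in the helper are relative) and that
"caco2_wang" is an assumed example of a task key present in the ADMET configs; adjust it to a key that
actually appears in expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml.

    # Sketch only: "caco2_wang" is an assumed task key, not guaranteed by this patch.
    from omegaconf import OmegaConf

    # Importing the module registers the resolvers as a side effect.
    from graphium.cli.finetune_utils import get_tdc_task_specific

    # Direct calls to the helper.
    metric = get_tdc_task_specific("caco2_wang", output="name")  # e.g. "mae" for a regression task
    mode = get_tdc_task_specific("caco2_wang", output="mode")    # "min" when the metric is "mae"

    # The same lookups via OmegaConf interpolation, as a Hydra config would use them.
    cfg = OmegaConf.create(
        {
            "metric": "${get_metric_name:caco2_wang}",
            "mode": "${get_metric_mode:caco2_wang}",
        }
    )
    print(OmegaConf.to_container(cfg, resolve=True))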