
Commit

Reformatting...
WenkelF committed Oct 12, 2023
1 parent cae7c0a commit 3895d23
Showing 6 changed files with 33 additions and 28 deletions.
38 changes: 20 additions & 18 deletions graphium/cli/finetune_utils.py
@@ -135,36 +135,38 @@ def get_fingerprints_from_model(
     logger.info(f"Saving fingerprints to {path}")
     torch.save(fps, path)


-def get_tdc_task_specific(task: str, output: Literal['name', 'mode', 'last_activation']):
-
-    if output == 'last_activation':
-        config_arch_path="expts/hydra-configs/tasks/task_heads/admet.yaml"
-        with open(config_arch_path, 'r') as yaml_file:
+def get_tdc_task_specific(task: str, output: Literal["name", "mode", "last_activation"]):
+    if output == "last_activation":
+        config_arch_path = "expts/hydra-configs/tasks/task_heads/admet.yaml"
+        with open(config_arch_path, "r") as yaml_file:
             config_tdc_arch = yaml.load(yaml_file, Loader=yaml.FullLoader)

-        return config_tdc_arch['architecture']['task_heads'][task]['last_activation']
+        return config_tdc_arch["architecture"]["task_heads"][task]["last_activation"]

     else:
-        config_metrics_path="expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
-        with open(config_metrics_path, 'r') as yaml_file:
+        config_metrics_path = "expts/hydra-configs/tasks/loss_metrics_datamodule/admet.yaml"
+        with open(config_metrics_path, "r") as yaml_file:
             config_tdc_task_metric = yaml.load(yaml_file, Loader=yaml.FullLoader)

-        metric = config_tdc_task_metric['predictor']['metrics_on_progress_bar'][task][0]
+        metric = config_tdc_task_metric["predictor"]["metrics_on_progress_bar"][task][0]

         metric_mode_map = {
-            'mae': 'min',
-            'auroc': 'max',
-            'auprc': 'max',
-            'spearman': 'max',
+            "mae": "min",
+            "auroc": "max",
+            "auprc": "max",
+            "spearman": "max",
         }

-        if output == 'name':
+        if output == "name":
             return metric
-        elif output == 'mode':
+        elif output == "mode":
             return metric_mode_map[metric]

-OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output='name'))
-OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output='mode'))
-OmegaConf.register_new_resolver("get_last_activation", lambda x: get_tdc_task_specific(x, output='last_activation'))
+
+OmegaConf.register_new_resolver("get_metric_name", lambda x: get_tdc_task_specific(x, output="name"))
+OmegaConf.register_new_resolver("get_metric_mode", lambda x: get_tdc_task_specific(x, output="mode"))
+OmegaConf.register_new_resolver(
+    "get_last_activation", lambda x: get_tdc_task_specific(x, output="last_activation")
+)
 OmegaConf.register_new_resolver("eval", lambda x: eval(x, {"np": np}))
9 changes: 6 additions & 3 deletions graphium/cli/parameters.py
@@ -21,6 +21,7 @@
 param_app = typer.Typer(help="Parameter counts.")
 app.add_typer(param_app, name="params")

+
 @param_app.command(name="infer", help="Infer parameter count.")
 def infer_parameter_count(overrides: List[str] = []) -> int:
     with initialize(version_base=None, config_path="../../expts/hydra-configs"):
@@ -52,8 +53,11 @@ def infer_parameter_count(overrides: List[str] = []) -> int:

     return num_params

+
 @param_app.command(name="balance", help="Balance parameter count.")
-def balance_parameter_count(overrides: List[str], target_param_count: int, max_rel_diff: float, rel_step: float, old_dim: int) -> None:
+def balance_parameter_count(
+    overrides: List[str], target_param_count: int, max_rel_diff: float, rel_step: float, old_dim: int
+) -> None:
     with initialize(version_base=None, config_path="../../expts/hydra-configs"):
         cfg = compose(
             config_name="main",
@@ -88,11 +92,10 @@ def balance_parameter_count(overrides: List[str], target_param_count: int, max_r
             logger.info(f"Hidden edge dim unchanged: {tmp_edge_dim}.")
             print(tmp_dim, tmp_edge_dim, rel_step, "true")
             return

         # Reduce step size when overshooting
         if np.sign(old_dim - tmp_dim) != np.sign(tmp_dim - new_dim) and old_dim > 0:
             rel_step /= 2
             logger.info(f"Relative step changed: {2*rel_step} -> {rel_step}.")

     print(new_dim, new_edge_dim, rel_step, "false")
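
For illustration (not from this commit), a standalone sketch of the overshoot rule visible in the last hunk: when the dimension update changes direction between iterations, the relative step is halved. The helper name and the example dimensions are assumptions, not graphium code.

    # Simplified sketch of the step-halving rule above; next_rel_step and the
    # example dimensions are illustrative assumptions.
    import numpy as np


    def next_rel_step(old_dim: int, tmp_dim: int, new_dim: int, rel_step: float) -> float:
        """Halve the relative step when the dimension update changes direction (overshoot)."""
        if np.sign(old_dim - tmp_dim) != np.sign(tmp_dim - new_dim) and old_dim > 0:
            return rel_step / 2
        return rel_step


    print(next_rel_step(old_dim=96, tmp_dim=128, new_dim=112, rel_step=0.2))  # 0.1: direction flipped
    print(next_rel_step(old_dim=96, tmp_dim=112, new_dim=128, rel_step=0.2))  # 0.2: still moving one way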

6 changes: 3 additions & 3 deletions graphium/cli/train_finetune_test.py
@@ -159,11 +159,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:

     # When checkpoints are logged during training, we can, e.g., use the best or last checkpoint for testing
     test_ckpt_path = None
-    test_ckpt_name = cfg['trainer'].pop('test_from_checkpoint', None)
-    test_ckpt_dir = cfg['trainer']['model_checkpoint'].pop('dirpath', None)
+    test_ckpt_name = cfg["trainer"].pop("test_from_checkpoint", None)
+    test_ckpt_dir = cfg["trainer"]["model_checkpoint"].pop("dirpath", None)
     if test_ckpt_name is not None and test_ckpt_dir is not None:
         test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name)

     # Run the model testing
     with SafeRun(name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
         trainer.test(model=predictor, datamodule=datamodule, ckpt_path=test_ckpt_path)
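
For illustration (not from this commit), the checkpoint-path logic above in isolation; the directory and file names are assumed examples, only the config keys mirror the hunk.

    # The keys mirror the ones popped from cfg["trainer"] above; "checkpoints/"
    # and "best.ckpt" are assumed example values.
    import os

    trainer_cfg = {
        "test_from_checkpoint": "best.ckpt",
        "model_checkpoint": {"dirpath": "checkpoints/"},
    }

    test_ckpt_name = trainer_cfg.pop("test_from_checkpoint", None)
    test_ckpt_dir = trainer_cfg["model_checkpoint"].pop("dirpath", None)
    test_ckpt_path = os.path.join(test_ckpt_dir, test_ckpt_name) if test_ckpt_name and test_ckpt_dir else None
    print(test_ckpt_path)  # checkpoints/best.ckpt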
4 changes: 2 additions & 2 deletions graphium/config/_loader.py
@@ -189,7 +189,7 @@ def load_architecture(
         architecture: The datamodule used to process and load the data
     """

-    if isinstance(config, dict) and 'finetuning' not in config:
+    if isinstance(config, dict) and "finetuning" not in config:
         config = omegaconf.OmegaConf.create(config)
     cfg_arch = config["architecture"]

@@ -249,7 +249,7 @@ def load_architecture(
         gnn_kwargs.setdefault("in_dim", edge_in_dim)

     # Set the parameters for the full network
-    if 'finetuning' not in config:
+    if "finetuning" not in config:
         task_heads_kwargs = omegaconf.OmegaConf.to_object(task_heads_kwargs)

     # Set all the input arguments for the model
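
For illustration (not from this commit): OmegaConf.to_object, used in the guarded branch above, converts a config container into plain Python objects; the example keys below are made up.

    # Illustrative only: shows what OmegaConf.to_object does to a config container.
    # The "task_a"/"out_dim"/"last_activation" values are made-up examples.
    from omegaconf import OmegaConf

    task_heads_kwargs = OmegaConf.create({"task_a": {"out_dim": 1, "last_activation": "none"}})
    plain = OmegaConf.to_object(task_heads_kwargs)
    print(type(plain), plain)  # <class 'dict'> {'task_a': {'out_dim': 1, 'last_activation': 'none'}}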
2 changes: 1 addition & 1 deletion graphium/finetuning/finetuning.py
@@ -66,7 +66,7 @@ def freeze_module(self, pl_module, module_name: str, module_map: Dict[str, Union
             module_map: Dictionary mapping from module_name to corresponding module(s)
         """
         modules = module_map[module_name]
-
+
         if module_name == "pe_encoders":
             for param in pl_module.model.pretrained_model.net.encoder_manager.parameters():
                 param.requires_grad = False
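
For illustration (not from this commit), a generic PyTorch sketch of what the freeze above achieves: parameters with requires_grad = False receive no gradients. The Linear layer is an assumed stand-in, not graphium's pe_encoders.

    # Generic freezing sketch; the Linear encoder is a made-up stand-in for a
    # pretrained submodule such as the pe_encoders referenced above.
    import torch
    import torch.nn as nn

    encoder = nn.Linear(16, 16)
    for param in encoder.parameters():
        param.requires_grad = False

    x = torch.randn(2, 16, requires_grad=True)
    encoder(x).sum().backward()
    print(all(p.grad is None for p in encoder.parameters()))  # True: frozen weights get no gradients
    print(x.grad is not None)  # True: gradients still flow through the frozen module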
2 changes: 1 addition & 1 deletion graphium/finetuning/finetuning_architecture.py
@@ -199,7 +199,7 @@ def __init__(
         super().__init__()

         # Load pretrained model
-        pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device='cpu').model
+        pretrained_model = PredictorModule.load_pretrained_model(pretrained_model, device="cpu").model
         pretrained_model.create_module_map()

         # Initialize new model with architecture after desired modifications to architecture.
