Commit
Started a notebook for analysis
cwognum committed Aug 22, 2023
1 parent 994d2d4 commit 5550d7b
Showing 4 changed files with 345 additions and 28 deletions.
2 changes: 1 addition & 1 deletion expts/hydra-configs/finetuning/admet.yaml
@@ -60,7 +60,7 @@ finetuning:
   pretrained_model: dummy-pretrained-model
   finetuning_module: task_heads # gnn
   sub_module_from_pretrained: zinc # optional
-  new_sub_module: lipophilicity_astrazeneca # optional
+  new_sub_module: ${constants.task} # optional
 
   # keep_modules_after_finetuning_module: # optional
   #   graph_output_nn/graph: {}
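The change above replaces a hard-coded task name with an OmegaConf interpolation, which Hydra resolves when the config is composed: new_sub_module now always mirrors whatever constants.task is set to. A minimal sketch of how that resolution behaves (the keys are simplified for illustration and are not the full graphium config):

    # Sketch only: simplified keys, assuming OmegaConf (Hydra's config backend).
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "constants": {"task": "lipophilicity_astrazeneca"},
            "finetuning": {"new_sub_module": "${constants.task}"},
        }
    )
    print(cfg.finetuning.new_sub_module)  # prints: lipophilicity_astrazeneca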
36 changes: 18 additions & 18 deletions graphium/cli/finetune_utils.py
@@ -1,51 +1,48 @@
 from typing import List, Optional
-import yaml
 
 import fsspec
 import typer
-
-from loguru import logger
-from hydra import compose, initialize
+import yaml
+from datamol.utils import fs
+from hydra import compose, initialize
 from hydra.core.hydra_config import HydraConfig
+from loguru import logger
 
 from .main import app
 from .train_finetune import run_training_finetuning
 
 
 finetune_app = typer.Typer(help="Utility CLI for extra fine-tuning utilities.")
 app.add_typer(finetune_app, name="finetune")
 
 
 @finetune_app.command(name="admet")
 def benchmark_tdc_admet_cli(
-    save_dir, wandb: bool = True, name: Optional[List[str]] = None, inclusive_filter: bool = True
+    overrides: list[str],
+    name: Optional[List[str]] = None,
+    inclusive_filter: bool = True,
 ):
     """
     Utility CLI to easily fine-tune a model on (a subset of) the benchmarks in the TDC ADMET group.
-    The results are saved to the SAVE_DIR.
+    A major limitation is that we cannot use all features of the Hydra CLI, such as multiruns.
     """
     try:
         from tdc.utils import retrieve_benchmark_names
     except ImportError:
         raise ImportError("TDC needs to be installed to use this CLI. Run `pip install PyTDC`.")
 
     # Get the benchmarks to run this for
-    if name is None:
+    if len(name) == 0:
         name = retrieve_benchmark_names("admet_group")
-    elif not inclusive_filter:
-        name = [n for n in name if n not in retrieve_benchmark_names("admet_group")]
+
+    if not inclusive_filter:
+        name = [n for n in retrieve_benchmark_names("admet_group") if n not in name]
 
     logger.info(f"Running fine-tuning for the following benchmarks: {name}")
     results = {}
 
     # Use the Compose API to construct the config
     for n in name:
-        overrides = [
-            "+finetuning=admet",
-            f"finetuning.task={n}",
-            f"finetuning.finetuning_head.task={n}",
-        ]
-
-        if not wandb:
-            overrides.append("~constants.wandb")
+        overrides += ["+finetuning=admet", f"constants.task={n}"]
 
         with initialize(version_base=None, config_path="../../expts/hydra-configs"):
             cfg = compose(
@@ -58,6 +55,9 @@ def benchmark_tdc_admet_cli(
         ret = {k: v.item() for k, v in ret.items()}
         results[n] = ret
 
+    # Save to the results_dir by default or to the Hydra output_dir if needed.
+    # This distinction is needed, because Hydra's output_dir cannot be remote.
+    save_dir = cfg["constants"].get("results_dir", HydraConfig.get()["runtime"]["output_dir"])
     fs.mkdir(save_dir, exist_ok=True)
     path = fs.join(save_dir, "results.yaml")
     logger.info(f"Saving results to {path}")
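For context, the loop above follows Hydra's Compose API: overrides passed on the command line are combined with the ADMET-specific ones before each per-benchmark config is composed, and a remote constants.results_dir can be supplied precisely because Hydra's local output_dir cannot be remote. A hedged sketch of that pattern (the config name "main", the entry-point name, and the override values are assumptions for illustration, not verified against the repo):

    # Sketch of the Compose-API pattern used by the command above.
    # Invoked from the shell roughly as:
    #   graphium finetune admet constants.results_dir=s3://my-bucket/admet
    # (entry-point name assumed).
    from hydra import compose, initialize

    user_overrides = ["constants.results_dir=s3://my-bucket/admet"]  # hypothetical value
    for n in ["caco2_wang", "half_life_obach"]:  # two TDC ADMET benchmarks, for illustration
        overrides = user_overrides + ["+finetuning=admet", f"constants.task={n}"]
        with initialize(version_base=None, config_path="../../expts/hydra-configs"):
            cfg = compose(config_name="main", overrides=overrides)

Building a fresh overrides list per iteration, as in the sketch, keeps one benchmark's overrides from leaking into the next.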
9 changes: 0 additions & 9 deletions graphium/hyper_param_search/results.py
@@ -1,12 +1,3 @@
-import enum
-
-import fsspec
-import hydra
-import torch
-import yaml
-from datamol.utils import fs
-from hydra.core.hydra_config import HydraConfig
-
 _OBJECTIVE_KEY = "objective"
 
 
326 changes: 326 additions & 0 deletions notebooks/compare-pretraining-finetuning-performance.ipynb

Large diffs are not rendered by default.
