updating test matrix
Andrewq11 committed Nov 7, 2024
1 parent 73ffd37 commit 3050f60
Showing 89 changed files with 512 additions and 392 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/test.yml
@@ -16,13 +16,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        include:
-          - python-version: "3.10"
-            pytorch-version: "2.0"
-          - python-version: "3.11"
-            pytorch-version: "2.0"
-          - python-version: "3.12"
-            pytorch-version: "2.3"
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: ["3.9", "3.10", "3.11"] # -> Will re-enable support for py312 once pyg is released
+        pytorch-version: ["2.0"]
 
     runs-on: "ubuntu-latest"
     timeout-minutes: 30
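For context on the hunk above: the removed include: block pinned three exact Python/PyTorch pairs, while the new list-form matrix expands to the cross-product of all entries. A minimal sketch of that expansion (illustrative Python, not part of the commit):

    from itertools import product

    os_list = ["ubuntu-latest", "macos-latest", "windows-latest"]
    python_versions = ["3.9", "3.10", "3.11"]
    pytorch_versions = ["2.0"]

    # GitHub Actions generates one job per combination of the matrix lists
    jobs = list(product(os_list, python_versions, pytorch_versions))
    print(len(jobs))  # 9 jobs: 3 OS x 3 Python x 1 PyTorch

Note that the visible context still pins runs-on: "ubuntu-latest"; consuming the new os axis would also require runs-on: ${{ matrix.os }}, unless that is updated elsewhere in the commit.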
4 changes: 2 additions & 2 deletions env.yml
@@ -3,7 +3,7 @@ channels:
   # - pyg # Add for Windows
 
 dependencies:
-  - python>=3.9, <3.12
+  - python >=3.9,<3.12
   - pip
   - typer
   - loguru

@@ -30,7 +30,7 @@ dependencies:
 
   # ML packages
   - cuda-version # works also with CPU-only system.
-  - pytorch <2.5
+  - pytorch >=1.12,<2.5
   - lightning >=2.0
   - torchmetrics
   - ogb
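The env.yml hunks normalize the python pin's spacing and add a lower bound to the pytorch pin. A rough illustration of what >=1.12,<2.5 admits, using packaging's PEP 440 specifiers as a stand-in for conda match specs (which behave similarly for simple ranges):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=1.12,<2.5")
    print("2.4.1" in spec)   # True: 2.4.x releases satisfy the pin
    print("2.5.0" in spec)   # False: 2.5 is excluded by the upper bound
    print("1.11.0" in spec)  # False: below the new lower bound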
11 changes: 7 additions & 4 deletions graphium/cli/fingerprint.py
@@ -16,6 +16,7 @@
 fp_app = typer.Typer(help="Automated fingerprinting from pretrained models.")
 app.add_typer(fp_app, name="fps")
 
+
 @fp_app.command(name="create", help="Create fingerprints for pretrained model.")
 def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]:
     with initialize(version_base=None, config_path="../../expts/hydra-configs/fingerprinting"):

@@ -33,12 +34,14 @@ def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]:
 
     # Allow alternative definition of `pretrained_models` with the single model specifier and desired layers
     if "layers" in pretrained_models.keys():
-        assert "model" in pretrained_models.keys(), "this workflow allows easier definition of fingerprinting sweeps"
+        assert (
+            "model" in pretrained_models.keys()
+        ), "this workflow allows easier definition of fingerprinting sweeps"
         model, layers = pretrained_models.pop("model"), pretrained_models.pop("layers")
         pretrained_models[model] = layers
 
     data_kwargs = cfg.get("datamodule")
 
     datamodule = FingerprintDatamodule(
         pretrained_models=pretrained_models,
         **data_kwargs,

@@ -54,4 +57,4 @@ def smiles_to_fps(cfg_name: str, overrides: List[str]) -> Dict[str, Any]:
 
 
 if __name__ == "__main__":
-    smiles_to_fps(cfg_name="example-tdc", overrides=[])
\ No newline at end of file
+    smiles_to_fps(cfg_name="example-tdc", overrides=[])
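The widened assert above guards a shorthand: a config that gives a single "model" plus "layers" is normalized into the {model_name: layers} mapping that FingerprintDatamodule expects, which eases sweep definitions. A standalone sketch of that rewrite (values hypothetical):

    pretrained_models = {"model": "my-pretrained-model", "layers": ["gnn.layers.0"]}

    if "layers" in pretrained_models:
        assert "model" in pretrained_models, "shorthand needs both `model` and `layers`"
        model, layers = pretrained_models.pop("model"), pretrained_models.pop("layers")
        pretrained_models[model] = layers

    print(pretrained_models)  # {'my-pretrained-model': ['gnn.layers.0']}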
5 changes: 4 additions & 1 deletion graphium/cli/train_finetune_test.py
@@ -44,6 +44,7 @@
 
 TESTING_ONLY_CONFIG_KEY = "testing_only"
 
+
 @hydra.main(version_base=None, config_path="../../expts/hydra-configs", config_name="main")
 def cli(cfg: DictConfig) -> None:
     """

@@ -145,7 +146,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
 
     ## Trainer
     date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
-    trainer = load_trainer(cfg, accelerator_type, date_time_suffix, metrics_on_progress_bar=metrics_on_progress_bar)
+    trainer = load_trainer(
+        cfg, accelerator_type, date_time_suffix, metrics_on_progress_bar=metrics_on_progress_bar
+    )
 
     if not testing_only:
         # Add the fine-tuning callback to trainer
1 change: 0 additions & 1 deletion graphium/config/_load.py
@@ -10,7 +10,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 import importlib.resources
 
 import omegaconf
10 changes: 5 additions & 5 deletions graphium/config/_loader.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 # Misc
 import os
 from copy import deepcopy

@@ -80,7 +79,7 @@ def load_datamodule(
         datamodule: The datamodule used to process and load the data
     """
 
-    from graphium.utils.spaces import DATAMODULE_DICT # Avoid circular imports with `spaces.py`
+    from graphium.utils.spaces import DATAMODULE_DICT  # Avoid circular imports with `spaces.py`
 
     cfg_data = config["datamodule"]["args"]
 
@@ -93,7 +92,6 @@ def load_datamodule(
     return datamodule
 
 
-
 def load_metrics(config: Union[omegaconf.DictConfig, Dict[str, Any]]) -> Dict[str, MetricWrapper]:
     """
     Loading the metrics to be tracked.

@@ -338,7 +336,7 @@ def load_trainer(
         name += f"_{date_time_suffix}"
         trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg)
 
-    progress_bar_callback = ProgressBarMetrics(metrics_on_progress_bar = metrics_on_progress_bar)
+    progress_bar_callback = ProgressBarMetrics(metrics_on_progress_bar=metrics_on_progress_bar)
     callbacks.append(progress_bar_callback)
 
     trainer = Trainer(

@@ -516,7 +514,9 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->
     Otherwise, assume it refers to a file in the checkpointing dir.
     """
 
-    from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT # Avoid circular imports with `spaces.py`
+    from graphium.utils.spaces import (
+        GRAPHIUM_PRETRAINED_MODELS_DICT,
+    )  # Avoid circular imports with `spaces.py`
 
     cfg_trainer = config["trainer"]
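Both import hunks above keep graphium's deferred-import pattern: pulling DATAMODULE_DICT and GRAPHIUM_PRETRAINED_MODELS_DICT inside the function body means the import runs at call time rather than at module-import time, which is what breaks the cycle with spaces.py. A stripped-down sketch of the pattern:

    def load_datamodule(config):
        # Resolved only when the function is called, so _loader.py can be
        # imported before graphium.utils.spaces is ready.
        from graphium.utils.spaces import DATAMODULE_DICT

        cfg_data = config["datamodule"]["args"]
        ...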
1 change: 0 additions & 1 deletion graphium/config/config_convert.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 import omegaconf
 
 
4 changes: 2 additions & 2 deletions graphium/data/collate.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from collections.abc import Mapping, Sequence
 
 # from pprint import pprint

@@ -127,7 +126,8 @@ def graphium_collate_fn(
 
 
 def collage_pyg_graph(
-    pyg_graphs: List[Data], num_nodes: List[int],
+    pyg_graphs: List[Data],
+    num_nodes: List[int],
 ):
     """
     Function to collate pytorch geometric graphs.
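Both collate hunks are formatter reflows; behavior is unchanged. For context, graphium_collate_fn is bound with functools.partial in the datamodule hunk below, and because partial objects carry no __name__, one is copied over by hand. A small sketch of that idiom (simplified signature):

    from functools import partial

    def graphium_collate_fn(batch, mask_nan=None):
        ...  # full implementation lives in graphium/data/collate.py

    collate_fn = partial(graphium_collate_fn, mask_nan=0)
    collate_fn.__name__ = graphium_collate_fn.__name__  # partial has no __name__ by default
    print(collate_fn.__name__)  # graphium_collate_fn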
28 changes: 14 additions & 14 deletions graphium/data/datamodule.py
@@ -192,7 +192,8 @@ def get_collate_fn(self, collate_fn):
         if collate_fn is None:
             # Some values become `inf` when changing data type. `mask_nan` deals with that
             collate_fn = partial(
-                graphium_collate_fn, mask_nan=0,
+                graphium_collate_fn,
+                mask_nan=0,
             )
             collate_fn.__name__ = graphium_collate_fn.__name__
 

@@ -871,14 +872,14 @@ def _get_len_from_cached_file(self):
         if self._ready_to_load_all_from_file():
             self._data_is_prepared = True
             train_metadata = graphium_cpp.load_metadata_tensors(
-                self.processed_graph_data_path, "train", self.data_hash
-            )
+                self.processed_graph_data_path, "train", self.data_hash
+            )
             val_metadata = graphium_cpp.load_metadata_tensors(
-                self.processed_graph_data_path, "val", self.data_hash
-            )
+                self.processed_graph_data_path, "val", self.data_hash
+            )
             test_metadata = graphium_cpp.load_metadata_tensors(
-                self.processed_graph_data_path, "test", self.data_hash
-            )
+                self.processed_graph_data_path, "test", self.data_hash
+            )
             length = 0
             if len(train_metadata) > 0:
                 length += len(train_metadata[2])

@@ -1241,7 +1242,6 @@ def get_folder_size(self, path):
         # check if the data items are actually saved into the folders
         return sum(os.path.getsize(osp.join(path, f)) for f in os.listdir(path))
 
-
     def get_dataloader(
         self, dataset: Dataset, shuffle: bool, stage: RunningStage
     ) -> Union[DataLoader, "poptorch.DataLoader"]:

@@ -1556,7 +1556,7 @@ def _get_split_indices(
             test_indices = np.asarray(splits[test]).astype("int")
             test_indices = test_indices[~np.isnan(test_indices)].tolist()
 
-        elif split_type == "scaffold" and split_test != 1.:
+        elif split_type == "scaffold" and split_test != 1.0:
             # Scaffold splitting
             try:
                 import splito

@@ -1565,7 +1565,7 @@ def _get_split_indices(
                     f"To do the splitting, `splito` needs to be installed. "
                     f"Please install it with `pip install splito`"
                 ) from error
-
+
             # Split data into scaffolds
             splitter = splito.ScaffoldSplit(
                 smiles=self.smiles,

@@ -1576,7 +1576,7 @@ def _get_split_indices(
             train_val_smiles = [self.smiles[i] for i in train_val_indices]
 
             sub_split_val = split_val / (1 - split_test)
-
+
             splitter = splito.ScaffoldSplit(
                 smiles=train_val_smiles,
                 test_size=sub_split_val,

@@ -1590,10 +1590,10 @@ def _get_split_indices(
 
         # Random splitting
         if split_test + split_val > 0:
-            if split_test == 1.:
+            if split_test == 1.0:
                 train_indices = np.array([])
                 val_test_indices = sample_idx
-                sub_split_test = 1.
+                sub_split_test = 1.0
             else:
                 train_indices, val_test_indices = train_test_split(
                     sample_idx,

@@ -1607,7 +1607,7 @@ def _get_split_indices(
                 sub_split_test = 0
 
         if split_test > 0:
-            if split_test == 1.:
+            if split_test == 1.0:
                 val_indices = np.array([])
                 test_indices = val_test_indices
             else:
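The float-literal fixes above (1. -> 1.0) sit inside a two-stage random split: one train_test_split carves out the held-out pool, a second divides that pool into validation and test. A self-contained sketch of the idea (fractions and seed illustrative; the real method also handles the split_test == 1.0 edge case shown in the diff):

    import numpy as np
    from sklearn.model_selection import train_test_split

    sample_idx = np.arange(100)
    split_val, split_test = 0.1, 0.1

    train_indices, val_test_indices = train_test_split(
        sample_idx, test_size=split_val + split_test, random_state=42
    )
    # Share of the held-out pool that becomes the test set
    sub_split_test = split_test / (split_val + split_test)
    val_indices, test_indices = train_test_split(
        val_test_indices, test_size=sub_split_test, random_state=42
    )
    print(len(train_indices), len(val_indices), len(test_indices))  # 80 10 10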
1 change: 0 additions & 1 deletion graphium/data/dataset.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 import os
 from copy import deepcopy
 from functools import lru_cache
1 change: 0 additions & 1 deletion graphium/data/multilevel_utils.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 import pandas as pd
 import ast
 import numpy as np
1 change: 0 additions & 1 deletion graphium/data/normalization.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Optional
 from loguru import logger
 import numpy as np
1 change: 0 additions & 1 deletion graphium/data/sampler.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Dict, Optional
 from torch.utils.data.dataloader import Dataset
 
1 change: 0 additions & 1 deletion graphium/data/smiles_transform.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Type, List, Dict, Union, Any, Callable, Optional, Tuple, Iterable
 
 import os
1 change: 0 additions & 1 deletion graphium/data/utils.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Union, List, Callable, Dict, Tuple, Any, Optional
 
 import importlib.resources
1 change: 0 additions & 1 deletion graphium/features/featurizer.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Union, List, Callable, Dict, Tuple, Any, Optional
 
 import inspect
1 change: 0 additions & 1 deletion graphium/features/nmp.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Tuple, Optional, Dict, Union
 import importlib.resources
 from copy import deepcopy
7 changes: 4 additions & 3 deletions graphium/finetuning/finetuning.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Iterable, List, Dict, Tuple, Union, Callable, Any, Optional, Type
 
 from collections import OrderedDict

@@ -54,7 +53,7 @@ def __init__(
         self.training_depth += unfreeze_pretrained_depth
         self.epoch_unfreeze_all = epoch_unfreeze_all
         self.always_freeze_modules = always_freeze_modules
-        if self.always_freeze_modules == 'none':
+        if self.always_freeze_modules == "none":
             self.always_freeze_modules = None
         if isinstance(self.always_freeze_modules, str):
             self.always_freeze_modules = [self.always_freeze_modules]

@@ -115,4 +114,6 @@ def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer
 
         if self.always_freeze_modules is not None:
             for module_name in self.always_freeze_modules:
-                self.freeze_module(pl_module, module_name, pl_module.model.pretrained_model.net._module_map)
+                self.freeze_module(
+                    pl_module, module_name, pl_module.model.pretrained_model.net._module_map
+                )
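The reflowed freeze_module call above re-applies the freeze on every epoch, so the listed modules stay frozen even after epoch_unfreeze_all fires. A hedged, stand-alone sketch of the same idea on top of Lightning's BaseFinetuning (module names hypothetical; the real callback resolves names through the pretrained model's _module_map):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import BaseFinetuning

    class AlwaysFreezeSketch(BaseFinetuning):
        def __init__(self, always_freeze_modules=("pe_encoders",)):
            super().__init__()
            self.always_freeze_modules = list(always_freeze_modules)

        def _freeze_listed(self, pl_module: pl.LightningModule) -> None:
            for name, module in pl_module.named_modules():
                if any(name.startswith(p) for p in self.always_freeze_modules):
                    self.freeze(module)  # sets requires_grad=False (BatchNorm stays trainable by default)

        def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
            self._freeze_listed(pl_module)

        def finetune_function(self, pl_module, epoch, optimizer) -> None:
            self._freeze_listed(pl_module)  # re-apply each epoch, as in the diff above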
3 changes: 1 addition & 2 deletions graphium/finetuning/finetuning_architecture.py
@@ -11,7 +11,6 @@
 --------------------------------------------------------------------------------
 """
 
-
 from typing import Any, Dict, Optional, Union
 
 import torch

@@ -345,4 +344,4 @@ def make_mup_base_kwargs(self, divide_factor: float = 2.0, factor_in_dim: bool =
         """
         # For the post-nn network, all the dimension are divided
 
-        return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
\ No newline at end of file
+        return self.net.make_mup_base_kwargs(divide_factor=divide_factor, factor_in_dim=factor_in_dim)
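For context on the trailing hunk: make_mup_base_kwargs produces constructor kwargs for a width-divided "base" model, which muP then compares against the full model to rescale initialization. A hedged sketch of where that fits (toy model; graphium's real networks and kwargs differ):

    import torch.nn as nn
    from mup import MuReadout, set_base_shapes

    def build_model(hidden_dim: int) -> nn.Module:
        # Toy stand-in for a graphium network; muP wants a MuReadout output layer
        return nn.Sequential(nn.Linear(16, hidden_dim), nn.ReLU(), MuReadout(hidden_dim, 1))

    model = build_model(hidden_dim=128)
    base_model = build_model(hidden_dim=64)  # widths divided by divide_factor=2.0
    set_base_shapes(model, base_model)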