Automated balancing of parameter count
WenkelF committed Sep 12, 2023
1 parent c249267 commit 3800525
Showing 11 changed files with 136 additions and 11 deletions.
4 changes: 2 additions & 2 deletions expts/hydra-configs/model/gated_gcn.yaml
@@ -13,9 +13,9 @@ architecture:
residual_type: none

gnn:
- out_dim: 704 # &gnn_dim 704
+ out_dim: ${constants.gnn_dim} # &gnn_dim 704
  hidden_dims: ${architecture.gnn.out_dim} # *gnn_dim
- hidden_dims_edges: 256
+ hidden_dims_edges: ${constants.gnn_edge_dim}
layer_type: 'pyg:gated-gcn'

graph_output_nn:
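These interpolations let a training config set the model size through two constants, constants.gnn_dim and constants.gnn_edge_dim. A minimal sketch, not part of the commit, showing how the interpolations resolve; plain OmegaConf is used to mimic Hydra's config composition, with the values this commit sets for the gated-GCN model:

from omegaconf import OmegaConf

model_cfg = OmegaConf.create({
    "architecture": {
        "gnn": {
            "out_dim": "${constants.gnn_dim}",
            "hidden_dims": "${architecture.gnn.out_dim}",
            "hidden_dims_edges": "${constants.gnn_edge_dim}",
        }
    }
})
training_cfg = OmegaConf.create({"constants": {"gnn_dim": 512, "gnn_edge_dim": 128}})

# Merging mimics Hydra composing the model and training configs;
# resolve=True substitutes the interpolations.
cfg = OmegaConf.merge(model_cfg, training_cfg)
print(OmegaConf.to_container(cfg, resolve=True)["architecture"]["gnn"])
# {'out_dim': 512, 'hidden_dims': 512, 'hidden_dims_edges': 128}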
4 changes: 2 additions & 2 deletions expts/hydra-configs/model/gine.yaml
@@ -2,7 +2,7 @@

architecture:
pre_nn_edges: # Set as null to avoid a pre-nn network
- out_dim: 32
+ out_dim: ${constants.gnn_edge_dim}
hidden_dims: 128
depth: 2
activation: relu
@@ -13,7 +13,7 @@ architecture:
residual_type: none

gnn:
- out_dim: 704
+ out_dim: ${constants.gnn_dim}
hidden_dims: ${architecture.gnn.out_dim}
layer_type: 'pyg:gine'

4 changes: 2 additions & 2 deletions expts/hydra-configs/model/mpnn.yaml
@@ -13,9 +13,9 @@ architecture:
residual_type: none

gnn:
- out_dim: 704
+ out_dim: ${constants.gnn_dim}
hidden_dims: ${architecture.gnn.out_dim}
- hidden_dims_edges: 128
+ hidden_dims_edges: ${constants.gnn_edge_dim}
layer_type: 'pyg:mpnnplus'

graph_output_nn:
8 changes: 4 additions & 4 deletions expts/hydra-configs/training/accelerator/largemix_gpu.yaml
@@ -5,10 +5,10 @@ accelerator:

datamodule:
args:
- batch_size_training: 2000
- batch_size_inference: 2000
- featurization_n_jobs: 4
- num_workers: 4
+ batch_size_training: 2048
+ batch_size_inference: 2048
+ featurization_n_jobs: 6
+ num_workers: 6

predictor:
metrics_every_n_train_steps: 1000
2 changes: 1 addition & 1 deletion expts/hydra-configs/training/largemix.yaml
@@ -8,7 +8,7 @@ predictor:
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
- warmup_epochs: 10
+ warmup_epochs: 5
verbose: False
scheduler_kwargs:
target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
2 changes: 2 additions & 0 deletions expts/hydra-configs/training/model/largemix_gated_gcn.yaml
@@ -12,6 +12,8 @@ constants:
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
+ gnn_dim: 512
+ gnn_edge_dim: 128
norm: "layer_norm"

trainer:
2 changes: 2 additions & 0 deletions expts/hydra-configs/training/model/largemix_gine.yaml
@@ -12,6 +12,8 @@ constants:
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
+ gnn_dim: 512
+ gnn_edge_dim: 32
norm: "layer_norm"

trainer:
2 changes: 2 additions & 0 deletions expts/hydra-configs/training/model/largemix_mpnn.yaml
@@ -12,6 +12,8 @@ constants:
data_dir: ../data/graphium/large-dataset/
raise_train_error: true
datacache_path: ../datacache/large-dataset/
+ gnn_dim: 512
+ gnn_edge_dim: 256
norm: "layer_norm"

trainer:
1 change: 1 addition & 0 deletions graphium/cli/__init__.py
@@ -1,3 +1,4 @@
from .data import data_app
+ from .parameters import param_app
from .finetune_utils import finetune_app
from .main import app
97 changes: 97 additions & 0 deletions graphium/cli/parameters.py
@@ -0,0 +1,97 @@
import timeit
from typing import List
from omegaconf import DictConfig, OmegaConf
import typer

import numpy as np

from loguru import logger
from hydra import initialize, compose

from .main import app
from graphium.config._loader import (
    load_accelerator,
    load_architecture,
    load_datamodule,
)

from graphium.trainer.predictor_options import ModelOptions

param_app = typer.Typer(help="Parameter counts.")
app.add_typer(param_app, name="params")

@param_app.command(name="infer", help="Infer parameter count.")
def infer_parameter_count(overrides: List[str] = []) -> int:
    with initialize(version_base=None, config_path="../../expts/hydra-configs"):
        cfg = compose(
            config_name="main",
            overrides=overrides,
        )

    cfg = OmegaConf.to_container(cfg, resolve=True)

    ## Accelerator
    cfg, accelerator_type = load_accelerator(cfg)

    ## Datamodule
    datamodule = load_datamodule(cfg, accelerator_type)

    ## Architecture
    model_class, model_kwargs = load_architecture(cfg, in_dims=datamodule.in_dims)
    model_options = ModelOptions(
        model_class=model_class,
        model_kwargs=model_kwargs,
    )
    model = model_options.model_class(**model_options.model_kwargs)

    # Count parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Number of parameters: {num_params}.")

    return num_params

@param_app.command(name="balance", help="Balance parameter count.")
def balance_parameter_count(overrides: List[str], target_param_count: int, max_rel_diff: float, rel_step: float, old_dim: int) -> None:
    with initialize(version_base=None, config_path="../../expts/hydra-configs"):
        cfg = compose(
            config_name="main",
            overrides=overrides,
        )

    cfg = OmegaConf.to_container(cfg, resolve=True)

    # Infer parameter count
    num_params = infer_parameter_count(overrides=overrides)

    # Get current hidden node and edge dim
    tmp_dim = cfg["constants"]["gnn_dim"]
    tmp_edge_dim = cfg["constants"]["gnn_edge_dim"]

    rel_diff = (num_params - target_param_count) / target_param_count

    # Balance parameter count when difference is too large
    if np.abs(rel_diff) > max_rel_diff:
        if rel_diff > 0:
            new_dim = int(tmp_dim * (1 - rel_step))
            new_edge_dim = int(tmp_edge_dim * (1 - rel_step))
        else:
            new_dim = int(tmp_dim * (1 + rel_step))
            new_edge_dim = int(tmp_edge_dim * (1 + rel_step))

        logger.info(f"Hidden node dim changed: {tmp_dim} -> {new_dim}.")
        logger.info(f"Hidden edge dim changed: {tmp_edge_dim} -> {new_edge_dim}.")

    else:
        logger.info(f"Hidden node dim unchanged: {tmp_dim}.")
        logger.info(f"Hidden edge dim unchanged: {tmp_edge_dim}.")
        print(tmp_dim, tmp_edge_dim, rel_step, "true")
        return

    # Reduce step size when overshooting
    if np.sign(old_dim - tmp_dim) != np.sign(tmp_dim - new_dim) and old_dim > 0:
        rel_step /= 2
        logger.info(f"Relative step changed: {2*rel_step} -> {rel_step}.")

    print(new_dim, new_edge_dim, rel_step, "false")

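The balance command prints a single line of the form "new_dim new_edge_dim rel_step stop" on stdout, which the shell script below parses to decide whether to keep iterating. For reference, the commands can also be called directly from Python; the override strings below are only illustrative Hydra-style examples, not taken from the commit:

from graphium.cli.parameters import infer_parameter_count

# Hypothetical call: count parameters with the new constants overridden.
num_params = infer_parameter_count(overrides=["constants.gnn_dim=512", "constants.gnn_edge_dim=128"])
print(num_params)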
21 changes: 21 additions & 0 deletions scripts/balance_params_and_train.sh
@@ -0,0 +1,21 @@
#!/bin/bash

set -e

# Balancing hyper-parameters: dim from the previous iteration (0 = none yet),
# target parameter count, tolerated relative error, and initial relative step size.
old_dim=0
num_params=10000000
rel_error=0.005
rel_step=0.5

# First pass: `graphium params balance` prints "new_dim new_edge_dim rel_step stop".
out=$(graphium params balance "${@}" "$num_params" "$rel_error" "$rel_step" "$old_dim")
read -r new_dim new_edge_dim rel_step stop <<< "$out"

# Keep proposing new hidden dims until the parameter count is within rel_error of the target.
while true; do
  tmp_dim=$new_dim
  out=$(graphium params balance "${@}" constants.gnn_dim="$new_dim" constants.gnn_edge_dim="$new_edge_dim" "$num_params" "$rel_error" "$rel_step" "$old_dim")
  read -r new_dim new_edge_dim rel_step stop <<< "$out"
  old_dim=$tmp_dim
  [[ $stop == "true" ]] && break
done

# Launch training with the balanced dimensions.
graphium-train "${@}" constants.gnn_dim="$new_dim" constants.gnn_edge_dim="$new_edge_dim"
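To illustrate how the balancing behaves, here is a self-contained simulation of the same update rule with a stand-in parameter-count function. The quadratic growth and the factor 40 are assumptions for illustration only; in the real workflow the count comes from infer_parameter_count, and the edge dimension is updated with the same factor (omitted here):

import numpy as np

def fake_param_count(dim: int) -> int:
    # Stand-in for "graphium params infer"; real counts depend on the full architecture.
    return 40 * dim**2

target, max_rel_diff, rel_step = 10_000_000, 0.005, 0.5
old_dim, dim = 0, 512
while True:
    rel_diff = (fake_param_count(dim) - target) / target
    if abs(rel_diff) <= max_rel_diff:
        print(f"converged: dim={dim}")
        break
    # Shrink the dim when over the target, grow it when under (same rule as the balance command).
    new_dim = int(dim * (1 - rel_step)) if rel_diff > 0 else int(dim * (1 + rel_step))
    # Halve the step whenever the search direction flips, i.e. after an overshoot.
    if old_dim > 0 and np.sign(old_dim - dim) != np.sign(dim - new_dim):
        rel_step /= 2
    old_dim, dim = dim, new_dim
    print(dim, rel_step)

With these numbers the dimension oscillates around the target and settles at 499 after roughly fifteen iterations, with rel_step halving each time the search overshoots.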
