update datamodule.py to make it compatible with OGB-LSC
alip67 committed Jun 10, 2021
1 parent cb3bbce commit f92c56a
Showing 2 changed files with 241 additions and 17 deletions.
180 changes: 180 additions & 0 deletions expts/config_molPCQM4M.yaml
@@ -0,0 +1,180 @@
constants:
  seed: &seed 42
  raise_train_error: True   # Whether the code should raise an error if it crashes during training

datamodule:
  module_type: "DGLOGBDataModule"
  args:
    cache_data_path: null
    dataset_name: "ogbg-molpcqm4m"

    # Weights
    weights_type: null
    sample_size: 10000

    # Featurization
    featurization_n_jobs: 8
    featurization_progress: True
    featurization:
      atom_property_list_onehot: [atomic-number, valence]
      atom_property_list_float: [mass, electronegativity, in-ring, hybridization, chirality, aromatic, degree, formal-charge, single-bond, double-bond, radical-electron, vdw-radius, covalent-radius, metal]
      edge_property_list: [bond-type-onehot, bond-type-float, stereo, in-ring, conjugated, estimated-bond-length]
      add_self_loop: False
      explicit_H: False
      use_bonds_weights: False
      pos_encoding_as_features: &pos_enc
        pos_type: laplacian_eigvec
        num_pos: 3
        normalization: "none"
        disconnected_comp: True
      pos_encoding_as_directions: *pos_enc

    # Train, val, test parameters
    batch_size_train_val: 1500
    batch_size_test: 1500

    # Data loading
    num_workers: 0
    pin_memory: False
    persistent_workers: False  # Keep True on Windows if running multiple workers


architecture:
  model_type: fulldglnetwork
  pre_nn:   # Set as null to avoid a pre-nn network
    out_dim: &hidden_dim 420
    hidden_dims: *hidden_dim
    depth: 3
    activation: relu
    last_activation: none
    dropout: &dropout_mlp 0.2
    last_dropout: *dropout_mlp
    batch_norm: &batch_norm True
    last_batch_norm: *batch_norm
    residual_type: simple

  pre_nn_edges:   # Set as null to avoid a pre-nn network
    out_dim: 32
    hidden_dims: 32
    depth: 3
    activation: relu
    last_activation: none
    dropout: *dropout_mlp
    last_dropout: *dropout_mlp
    batch_norm: *batch_norm
    last_batch_norm: *batch_norm
    residual_type: simple

  gnn:   # Set as null to avoid a post-nn network
    out_dim: *hidden_dim
    hidden_dims: *hidden_dim
    depth: 5
    activation: none
    last_activation: none
    dropout: &dropout_gnn 0.2
    last_dropout: *dropout_gnn
    batch_norm: *batch_norm
    last_batch_norm: *batch_norm
    residual_type: simple
    pooling: ['sum', 'max', 'dir1']
    virtual_node: 'none'
    layer_type: 'dgn-msgpass'
    layer_kwargs:
      # num_heads: 3
      aggregators: [mean, max, sum, dir1/dx_abs]
      scalers: [identity]

  post_nn:
    out_dim: 1
    hidden_dims: *hidden_dim
    depth: 3
    activation: relu
    last_activation: none
    dropout: *dropout_mlp
    last_dropout: 0.
    batch_norm: *batch_norm
    last_batch_norm: False
    residual_type: simple

predictor:
  metrics_on_progress_bar: ["mae", "mse", "pearsonr"]
  loss_fun: bce
  random_seed: *seed
  optim_kwargs:
    lr: 5.e-3
    weight_decay: 0
  lr_reduce_on_plateau_kwargs:
    factor: 0.5
    patience: 20
    min_lr: 2.e-4
  scheduler_kwargs:
    monitor: &monitor mae/val
    mode: &mode max
    frequency: 1
  target_nan_mask: 0  # null: no mask, 0: 0 mask, ignore: ignore nan values from loss


metrics:
  - name: mae
    metric: mae
    threshold_kwargs: null

  - name: pearsonr
    metric: pearsonr
    threshold_kwargs: null

  - name: mse
    metric: mse
    threshold_kwargs: null

  - name: spearmanr
    metric: spearmanr
    threshold_kwargs: null

  - name: f1 > 5
    metric: f1
    num_classes: 2
    average: micro
    threshold_kwargs: &threshold_1
      operator: greater
      threshold: 5
      th_on_preds: True
      th_on_target: True
      target_to_int: True

  - name: f1 > 4
    metric: f1
    num_classes: 2
    average: micro
    threshold_kwargs: &threshold_2
      operator: greater
      threshold: 4
      th_on_preds: True
      th_on_target: True
      target_to_int: True


trainer:
  logger:
    save_dir: logs/ogb-molpcqm4m
  early_stopping:
    monitor: *monitor
    min_delta: 0
    patience: 80
    mode: *mode
  model_checkpoint:
    dirpath: models_checkpoints/ogb-molpcqm4m/
    filename: "model"
    monitor: *monitor
    mode: *mode
    save_top_k: 1
    period: 1
  trainer:
    max_epochs: 1000
    min_epochs: 100
    gpus: 1
    accumulate_grad_batches: 1
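
Since this config leans heavily on YAML anchors and aliases (&seed/*seed, &hidden_dim/*hidden_dim, *pos_enc, ...), here is a minimal sketch, not part of the commit, of how a standard parser resolves them; the key lookups below are illustrative, not an exhaustive check:

import yaml

# Anchors/aliases are resolved at parse time, so every *hidden_dim alias
# comes back as the literal 420, and *pos_enc as the full nested mapping.
with open("expts/config_molPCQM4M.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["architecture"]["gnn"]["out_dim"] == 420
assert cfg["datamodule"]["args"]["featurization"]["pos_encoding_as_directions"]["num_pos"] == 3
assert cfg["predictor"]["scheduler_kwargs"]["monitor"] == "mae/val"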


78 changes: 61 additions & 17 deletions goli/data/datamodule.py
@@ -632,6 +632,11 @@ def _extract_smiles_labels(
 
         smiles_col = smiles_col_all[0]
 
+        if label_cols is None:
+            label_cols = df.columns.drop(smiles_col)
+
+        label_cols = check_arg_iterator(label_cols, enforce_type=list)
+
         smiles = df[smiles_col].values
         labels = [pd.to_numeric(df[col], errors="coerce") for col in label_cols]
         labels = np.stack(labels, axis=1)
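
This hunk makes label_cols optional: when it is None, every non-SMILES column is treated as a label column. A small sketch of that default on a hypothetical two-column frame:

import pandas as pd

# Hypothetical frame: one SMILES column plus one target column.
df = pd.DataFrame({"smiles": ["CCO", "CCN"], "homolumogap": [3.05, 3.67]})

smiles_col = "smiles"
label_cols = df.columns.drop(smiles_col)  # Index(['homolumogap'], dtype='object')
labels = [pd.to_numeric(df[col], errors="coerce") for col in label_cols]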
@@ -709,9 +714,19 @@ def _get_split_indices(
         test_indices = splits["test"].dropna().astype("int").tolist()
 
         # Filter train, val and test indices
-        train_indices = [ii for ii, idx in enumerate(sample_idx) if idx in train_indices]
-        val_indices = [ii for ii, idx in enumerate(sample_idx) if idx in val_indices]
-        test_indices = [ii for ii, idx in enumerate(sample_idx) if idx in test_indices]
+        _, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True)
+        train_indices = train_idx.tolist()
+        train_indices.sort()
+        _, valid_idx, _ = np.intersect1d(sample_idx, val_indices, return_indices=True)
+        val_indices = valid_idx.tolist()
+        val_indices.sort()
+        _, test_idx, _ = np.intersect1d(sample_idx, test_indices, return_indices=True)
+        test_indices = test_idx.tolist()
+        test_indices.sort()
+
+        # train_indices = [ii for ii, idx in enumerate(sample_idx) if idx in train_indices]
+        # val_indices = [ii for ii, idx in enumerate(sample_idx) if idx in val_indices]
+        # test_indices = [ii for ii, idx in enumerate(sample_idx) if idx in test_indices]
 
         return train_indices, val_indices, test_indices
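
This hunk swaps the O(n*m) membership tests (kept above as comments) for a vectorized np.intersect1d, which returns the positions of the matching split indices within sample_idx. A minimal sketch with made-up indices, not part of the commit:

import numpy as np

# Hypothetical sub-sampled dataset indices and an OGB train split;
# index 99 is absent from the sample and is silently dropped.
sample_idx = np.array([10, 3, 7, 42])
train_indices = [3, 42, 99]

_, train_idx, _ = np.intersect1d(sample_idx, train_indices, return_indices=True)
train_indices = train_idx.tolist()
train_indices.sort()
print(train_indices)  # [1, 3] -- positions of 3 and 42 in sample_idx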

@@ -849,7 +864,10 @@ def _load_dataset(self, metadata: dict):
         """Download, extract and load an OGB dataset."""
 
         base_dir = fs.get_cache_dir("ogb")
-        dataset_dir = base_dir / metadata["download_name"]
+        if metadata['download_name'] == "pcqm4m":
+            dataset_dir = base_dir / (metadata["download_name"] + "_kddcup2021")
+        else:
+            dataset_dir = base_dir / metadata["download_name"]
 
         if not dataset_dir.exists():
@@ -866,27 +884,53 @@ def _load_dataset(self, metadata: dict):
                 zf.extractall(base_dir)
 
         # Load CSV file
-        df_path = dataset_dir / "mapping" / "mol.csv.gz"
+        if metadata['download_name'] == "pcqm4m":
+            df_path = dataset_dir / "raw" / "data.csv.gz"
+        else:
+            df_path = dataset_dir / "mapping" / "mol.csv.gz"
         logger.info(f"Loading {df_path} in memory.")
         df = pd.read_csv(df_path)
 
         # Load split from the OGB dataset and save them in a single CSV file
-        split_name = metadata["split"]
-        train_split = pd.read_csv(dataset_dir / "split" / split_name / "train.csv.gz", header=None)  # type: ignore
-        val_split = pd.read_csv(dataset_dir / "split" / split_name / "valid.csv.gz", header=None)  # type: ignore
-        test_split = pd.read_csv(dataset_dir / "split" / split_name / "test.csv.gz", header=None)  # type: ignore
+        if metadata['download_name'] == "pcqm4m":
+            split_name = metadata["split"]
+            split_dict = torch.load(dataset_dir / "split_dict.pt")
+            train_split = pd.DataFrame(split_dict['train'])
+            val_split = pd.DataFrame(split_dict['valid'])
+            test_split = pd.DataFrame(split_dict['test'])
+            splits = pd.concat([train_split, val_split, test_split], axis=1)  # type: ignore
+            splits.columns = ["train", "val", "test"]
+
+            splits_path = dataset_dir / "split"
+            if not splits_path.exists():
+                os.makedirs(splits_path)
+                splits_path = dataset_dir / f"{split_name}.csv.gz"
+            else:
+                splits_path = splits_path / f"{split_name}.csv.gz"
+            logger.info(f"Saving splits to {splits_path}")
+            splits.to_csv(splits_path, index=None)
+        else:
+            split_name = metadata["split"]
+            train_split = pd.read_csv(dataset_dir / "split" / split_name / "train.csv.gz", header=None)  # type: ignore
+            val_split = pd.read_csv(dataset_dir / "split" / split_name / "valid.csv.gz", header=None)  # type: ignore
+            test_split = pd.read_csv(dataset_dir / "split" / split_name / "test.csv.gz", header=None)  # type: ignore
 
-        splits = pd.concat([train_split, val_split, test_split], axis=1)  # type: ignore
-        splits.columns = ["train", "val", "test"]
+            splits = pd.concat([train_split, val_split, test_split], axis=1)  # type: ignore
+            splits.columns = ["train", "val", "test"]
 
-        splits_path = dataset_dir / "split" / f"{split_name}.csv.gz"
-        logger.info(f"Saving splits to {splits_path}")
-        splits.to_csv(splits_path, index=None)
+            splits_path = dataset_dir / "split" / f"{split_name}.csv.gz"
+            logger.info(f"Saving splits to {splits_path}")
+            splits.to_csv(splits_path, index=None)
 
         # Get column names: OGB columns are predictable
-        idx_col = df.columns[-1]
-        smiles_col = df.columns[-2]
-        label_cols = df.columns[:-2].to_list()
+        if metadata['download_name'] == "pcqm4m":
+            idx_col = df.columns[0]
+            smiles_col = df.columns[-2]
+            label_cols = df.columns[-1:].to_list()
+        else:
+            idx_col = df.columns[-1]
+            smiles_col = df.columns[-2]
+            label_cols = df.columns[:-2].to_list()
 
         return df, idx_col, smiles_col, label_cols, splits_path
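
In the PCQM4M branch, the split no longer ships as per-split CSV files: OGB-LSC distributes a pickled dict of index arrays (split_dict.pt), which the code converts back into the train/val/test CSV layout the rest of the datamodule expects. A sketch of that layout, with a hypothetical cache path and assuming the usual OGB-LSC keys:

import pandas as pd
import torch

# Assumed layout of OGB-LSC's split_dict.pt: a dict of integer index arrays.
split_dict = torch.load("datasets/ogb/pcqm4m_kddcup2021/split_dict.pt")
print(list(split_dict.keys()))  # ['train', 'valid', 'test']

# Each array becomes a one-column DataFrame; the arrays differ in length,
# so pd.concat NaN-pads the shorter columns, hence the .dropna() calls in
# _get_split_indices above.
splits = pd.concat([pd.DataFrame(split_dict[k]) for k in ("train", "valid", "test")], axis=1)
splits.columns = ["train", "val", "test"]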

