This repository includes all the code for our study: Substrate Scope Contrastive Learning: Repurposing Human Bias to Learn Atomic Representations. Additionally, we provide a pre-trained model here and validation data, which encompasses the 500 most frequently used aryl halides from the CAS Content Collection
Executing create_env.sh
sets up the virtual environment, addressing most dependencies for this repository. Note that we used PyTorch 2.0 and CUDA 11.7 for our experiments. For different versions, adjust the correponding lines as needed, referring to the official PyG site for guidance.
To train the model, execute the following command from this directory:
scipts/pre_train.sh
For hyper-parameter tuning, run:
scipts/ray_tune.sh
For encoding new aryl halides, refer to this notebook. The necessary code includes (have not tested yet):
from substrate_metric_learning.networks import Net
from substrate_metric_learning.features import smiles_to_graph_substrate
config_path = os.path.join(HOME_DIR, "configs/hparams_default.yaml")
config = Objdict(yaml.safe_load(open(config_path)))
input_dim = 133
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_pretrained = Net(input_dim, config.hidden_channels, 1, config.num_layers, config.pool).to(device)
model_pretrained.load_state_dict(torch.load("results/default/gin_epoch_56_sum_r2_1.538.pth"))
@torch.no_grad()
def get_embedding_from_smi(smi_list, c_index_list, model, device):
assert len(smi_list) == len(c_index_list)
assert model.pool_method == 'c'
train_dataset = [smiles_to_graph_substrate(smiles=smi_list[i], s=0, y=0, atm_idx=[c_index_list[i]]) for i in range(len(smi_list))]
loader = DataLoader(train_dataset, 128, shuffle=False)
embeddings = []
for data in loader:
data = data.to(device)
_, emb = model(data.x, data.edge_index, data.batch, data.atm_idx)
embeddings.append(emb.cpu().detach().numpy())
return np.concatenate(embeddings, axis=0)
I use wandb for experiment monitoring. If you want to use wandb to log your results, please login with your wandb account first (see https://docs.wandb.ai/quickstart). If you don't want to use wandb, you can turn it off by using argument --wandb disabled
.
This project is licensed under the terms of the Apache 2.0 License.
Please contact [email protected] for help or submit an issue.