diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 88f3986..914eda7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -34,19 +34,19 @@ jobs: - name: Install metl package from metl-pretrained repo shell: bash --login {0} run: pip install git+https://github.com/gitter-lab/metl-pretrained - # Log conda environment contents + # Log conda environment contents - name: Log conda environment shell: bash --login {0} run: conda list - # Pretrain source model on GFP Rosetta dataset + # Pretrain source model on GFP Rosetta dataset - name: Pretrain source METL model shell: bash --login {0} run: python code/train_source_model.py @args/pretrain_avgfp_local.txt --max_epochs 5 --limit_train_batches 5 --limit_val_batches 5 --limit_test_batches 5 - # Finetune target model on GFP DMS dataset + # Finetune target model on GFP DMS dataset - name: Finetune target METL model shell: bash --login {0} run: python code/train_target_model.py @args/finetune_avgfp_local.txt --enable_progress_bar false --enable_simple_progress_messages --max_epochs 50 --unfreeze_backbone_at_epoch 25 - # Load target model checkpoint and run inference on example variants + # Load target model checkpoint and run inference on example variants - name: Load and test target METL model shell: bash --login {0} run: python code/tests.py --checkpoint_path output/training_logs/DgLkMZxu/checkpoints/epoch=49-step=50.ckpt --variants E3K,G102S;T36P,S203T,K207R;V10A,D19G,F25S,E113V --dataset avgfp diff --git a/code/tests.py b/code/tests.py index 3746f56..f0f26d6 100644 --- a/code/tests.py +++ b/code/tests.py @@ -4,6 +4,8 @@ import torch import utils +from argparse import ArgumentParser + def load_checkpoint_run_inference(checkpoint_path, variants, dataset): """ loads a finetuned 3D model from a checkpoint and scores variants with the model """