# test.yaml
# Continuous-integration workflow: installs the conda environment, then runs a
# short METL pretrain -> finetune -> inference smoke test on each listed OS.
name: Tests

on:
  - push
  - pull_request

jobs:
  # Installs the conda environment and trains METL
  train:
    name: Test METL training
    runs-on: ${{ matrix.os }}
    defaults:
      run:
        # Applies to every `run` step below. A login shell (-l) is what
        # setup-miniconda documents for picking up the activated environment;
        # -e stops a multi-line script at the first failing command instead of
        # reporting only the exit status of its last command.
        shell: bash -el {0}
    strategy:
      # Let the remaining operating systems finish even if one of them fails.
      fail-fast: false
      matrix:
        # Other GitHub-hosted runner images that could be added are listed at
        # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
        os:
          - macos-latest
          - ubuntu-latest
          - windows-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Package caching can be set up later; see conda-incubator/setup-miniconda
      - name: Install conda environment
        uses: conda-incubator/setup-miniconda@v3
        with:
          activate-environment: metl
          environment-file: environment.yml
          auto-activate-base: false
          # No `miniforge-variant`: the default Miniforge3 installer bundles
          # mamba, and the separate Mambaforge installer is deprecated upstream.
          miniforge-version: 'latest'
          use-mamba: true

      # Installs latest commit from main branch
      - name: Install metl package from metl-pretrained repo
        run: pip install git+https://github.com/gitter-lab/metl-pretrained

      # Log conda environment contents
      - name: Log conda environment
        run: conda list

      # Pretrain source model on GFP Rosetta dataset.
      # The folded scalar (>-) joins the lines below into one command line.
      - name: Pretrain source METL model
        run: >-
          python code/train_source_model.py @args/pretrain_avgfp_local.txt
          --max_epochs 5
          --limit_train_batches 5
          --limit_val_batches 5
          --limit_test_batches 5

      # Finetune target model on GFP DMS dataset
      - name: Finetune target METL model
        run: >-
          python code/train_target_model.py @args/finetune_avgfp_local.txt
          --enable_progress_bar false
          --enable_simple_progress_messages
          --max_epochs 50
          --unfreeze_backbone_at_epoch 25

      # Load target model checkpoint and run inference on example variants
      - name: Load and test target METL model
        # Log directory name is different on every run, hence the `*` globs.
        # Bash does not expand a glob in a plain assignment; it is expanded
        # where the variable is used unquoted, so the variables must stay
        # unquoted on the python command lines.
        run: |
          ckpt=output/training_logs/*/checkpoints/epoch=49-step=50.ckpt
          python code/convert_ckpt.py --ckpt_path $ckpt
          # Presumably the .pt file is written by convert_ckpt.py next to the
          # .ckpt file -- TODO confirm against code/convert_ckpt.py.
          converted=output/training_logs/*/checkpoints/*.pt
          python code/tests.py --ckpt_path $converted \
            --variants E3K,G102S_T36P,S203T,K207R_V10A,D19G,F25S,E113V \
            --dataset avgfp