{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Split Creation\n", | ||
"\n", | ||
"This notebook creates data splits used to evaluate gRNAde on randomly split RNAs.\n", | ||
"\n", | ||
"**Workflow:**\n", | ||
"1. Order the samples based on some metric:\n", | ||
" - Avg. RMSD among available structures\n", | ||
" - Total structures available\n", | ||
"2. Training, validation, and test splits become progressively harder.\n", | ||
" - Top 100 samples with highest metric -- test set.\n", | ||
" - Next 100 samples with highest metric -- validation set.\n", | ||
" - All remaining samples -- training set.\n", | ||
" - Very large (> 1000 nts) or very small (< 10nts) RNAs -- training set.\n", | ||
"\n", | ||
"Note that we separate very large RNA samples (> 1000 nts) from clustering and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (< 10 nts)." | ||
] | ||
}, | ||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import sys\n",
"sys.path.append('../')\n",
"\n",
"import os\n",
"import subprocess\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset\n",
"import seaborn as sns\n",
"\n",
"from Bio import SeqIO\n",
"from Bio.Seq import Seq\n",
"from Bio.SeqRecord import SeqRecord\n",
"\n",
"from src.data_utils import get_avg_rmsds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load data list\n",
"data_list = torch.load(os.path.join(\"../data/\", \"processed.pt\"))\n",
"print(len(data_list))\n",
"\n",
"# List of sample sequences (used to create .fasta input file)\n",
"seq_list = []\n",
"for idx, data in enumerate(data_list):\n",
"    seq = data[\"seq\"]\n",
"    seq_list.append(SeqRecord(Seq(seq), id=str(idx)))  # the ID for each sequence is its index in data_list\n",
"\n",
"# List of intra-sequence avg. RMSDs\n",
"rmsd_list = get_avg_rmsds(data_list)\n",
"\n",
"# List of number of structures per sequence\n",
"count_list = [len(data[\"coords_list\"]) for data in data_list]\n",
"\n",
"assert len(data_list) == len(seq_list) == len(rmsd_list) == len(count_list)"
]
},
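{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The avg. RMSD metric above comes from `get_avg_rmsds` in `src.data_utils`, whose implementation is not shown in this notebook. The cell below is a minimal illustrative sketch (not the actual implementation) of one way such a metric could be computed: the average pairwise RMSD after Kabsch superposition, assuming each entry of `coords_list` can be reduced to an `(L, 3)` numpy array of representative atom coordinates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE: illustrative sketch only -- the actual metric comes from src.data_utils.get_avg_rmsds.\n",
"# Assumes each structure can be reduced to an (L, 3) numpy array of representative atom coordinates.\n",
"\n",
"def kabsch_rmsd(P, Q):\n",
"    # RMSD between two (L, 3) coordinate arrays after optimal superposition (Kabsch algorithm)\n",
"    P = P - P.mean(axis=0)\n",
"    Q = Q - Q.mean(axis=0)\n",
"    H = P.T @ Q  # 3x3 covariance matrix\n",
"    U, S, Vt = np.linalg.svd(H)\n",
"    d = np.sign(np.linalg.det(Vt.T @ U.T))  # correct for a possible reflection\n",
"    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T  # optimal rotation\n",
"    return np.sqrt((((P @ R.T) - Q) ** 2).sum(axis=1).mean())\n",
"\n",
"def avg_pairwise_rmsd(coords_list):\n",
"    # Average RMSD over all pairs of structures for one RNA (0.0 if only a single structure)\n",
"    n = len(coords_list)\n",
"    if n < 2:\n",
"        return 0.0\n",
"    rmsds = [\n",
"        kabsch_rmsd(coords_list[i], coords_list[j])\n",
"        for i in range(n) for j in range(i + 1, n)\n",
"    ]\n",
"    return float(np.mean(rmsds))"
]
},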
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# RMSD Split\n",
"\n",
"# Pair each data_list index with its avg. RMSD\n",
"zipped = zip(list(range(len(data_list))), rmsd_list)\n",
"# Sort the pairs by avg. RMSD (descending order, highest first)\n",
"sorted_zipped = sorted(zipped, key=lambda x: x[1], reverse=True)\n",
"# Unzip the sorted pairs back into two separate lists\n",
"sorted_data_list_idx, sorted_rmsd_list = zip(*sorted_zipped)"
]
},
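{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The workflow above also lists the total number of available structures as a possible ordering metric. The split saved at the end of this notebook uses the RMSD ordering; the cell below is an illustrative sketch of the analogous ordering by structure count (these variables are not used further in this notebook)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count ordering (illustrative alternative -- not used for the split saved below)\n",
"\n",
"# Pair each data_list index with its number of available structures\n",
"zipped_by_count = zip(list(range(len(data_list))), count_list)\n",
"# Sort the pairs by structure count (descending order, highest first)\n",
"sorted_zipped_by_count = sorted(zipped_by_count, key=lambda x: x[1], reverse=True)\n",
"# Unzip the sorted pairs back into two separate lists\n",
"sorted_data_list_idx_by_count, sorted_count_list = zip(*sorted_zipped_by_count)"
]
},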
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_idx_list = []\n",
"val_idx_list = []\n",
"train_idx_list = []\n",
"\n",
"for idx, avg_rmsd in sorted_zipped:\n",
"    \n",
"    num_structs = count_list[idx]  # len(data_list[idx]['coords_list'])\n",
"    \n",
"    seq_len = len(seq_list[idx])\n",
"\n",
"    # Only medium-length RNAs (10 < length < 1000 nts) are eligible for the test/validation sets\n",
"    if seq_len < 1000 and seq_len > 10:\n",
"\n",
"        # Test set\n",
"        if len(test_idx_list) < 100:\n",
"            test_idx_list.append(idx)\n",
"        \n",
"        # Validation set\n",
"        elif len(val_idx_list) < 100:\n",
"            val_idx_list.append(idx)\n",
"        \n",
"        # Training set\n",
"        else:\n",
"            train_idx_list.append(idx)\n",
"    \n",
"    # Training set (very large or very small RNAs)\n",
"    else:\n",
"        train_idx_list.append(idx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert len(test_idx_list) + len(val_idx_list) + len(train_idx_list) == len(data_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/random_rmsd_split.pt\")"
]
},
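{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal usage sketch (not part of the original split-creation workflow), the saved file can be reloaded to recover the three index lists in the order they were saved:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: reload the saved split -- each entry is a list of indices into data_list\n",
"train_idx_list, val_idx_list, test_idx_list = torch.load(\"../data/random_rmsd_split.pt\")\n",
"print(len(train_idx_list), len(val_idx_list), len(test_idx_list))"
]
}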
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}