Commit

Update notebooks

chaitjo committed Sep 11, 2023
1 parent e64a781 commit 66019bb

Showing 3 changed files with 292 additions and 50 deletions.
171 changes: 171 additions & 0 deletions notebooks/data_splitting_random.ipynb
@@ -0,0 +1,171 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Split Creation\n",
"\n",
"This notebook creates data splits used to evaluate gRNAde on randomly split RNAs.\n",
"\n",
"**Workflow:**\n",
"1. Order the samples based on some metric:\n",
" - Avg. RMSD among available structures\n",
" - Total structures available\n",
"2. Training, validation, and test splits become progressively harder.\n",
" - Top 100 samples with highest metric -- test set.\n",
" - Next 100 samples with highest metric -- validation set.\n",
" - All remaining samples -- training set.\n",
" - Very large (> 1000 nts) or very small (< 10nts) RNAs -- training set.\n",
"\n",
"Note that we separate very large RNA samples (> 1000 nts) from clustering and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (< 10 nts)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import sys\n",
"sys.path.append('../')\n",
"\n",
"import os\n",
"import subprocess\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset\n",
"import seaborn as sns\n",
"\n",
"from Bio import SeqIO\n",
"from Bio.Seq import Seq\n",
"from Bio.SeqRecord import SeqRecord\n",
"\n",
"from src.data_utils import get_avg_rmsds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load data list\n",
"data_list = torch.load(os.path.join(\"../data/\", \"processed.pt\"))\n",
"print(len(data_list))\n",
"\n",
"# List of sample sequences (used to create .fasta input file)\n",
"seq_list = []\n",
"for idx, data in enumerate(data_list):\n",
" seq = data[\"seq\"]\n",
" seq_list.append(SeqRecord(Seq(seq), id=str(idx))) # the ID for each sequence is its index in data_list\n",
"\n",
"# List of intra-sequence avg. RMSDs\n",
"rmsd_list = get_avg_rmsds(data_list)\n",
"\n",
"# List of number of structures per sequence\n",
"count_list = [len(data[\"coords_list\"]) for data in data_list]\n",
"\n",
"assert len(data_list) == len(seq_list) == len(rmsd_list) == len(count_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# RMSD Split\n",
"\n",
"# Zip the two lists together\n",
"zipped = zip(list(range(len(data_list))), rmsd_list)\n",
"# Sort the zipped list based on the values (descending order, highest first)\n",
"sorted_zipped = sorted(zipped, key=lambda x: x[1], reverse=True)\n",
"# Unzip the sorted list back into two separate lists\n",
"sorted_data_list_idx, sorted_rmsd_list = zip(*sorted_zipped)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_idx_list = []\n",
"val_idx_list = []\n",
"train_idx_list = []\n",
"\n",
"for idx, avg_rmsd in sorted_zipped:\n",
" \n",
" num_structs = count_list[idx] # len(data_list[idx]['coords_list'])\n",
" \n",
" seq_len = len(seq_list[idx])\n",
"\n",
" if seq_len < 1000 and seq_len > 10:\n",
"\n",
" # Test set\n",
" if len(test_idx_list) < 100:\n",
" test_idx_list.append(idx)\n",
" \n",
" # Validation set\n",
" elif len(val_idx_list) < 100:\n",
" val_idx_list.append(idx)\n",
" \n",
" # Training set\n",
" else:\n",
" train_idx_list.append(idx)\n",
" \n",
" # Training set\n",
" else:\n",
" train_idx_list.append(idx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert len(test_idx_list) + len(val_idx_list) + len(train_idx_list) == len(data_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/random_rmsd_split.pt\")"
]
}
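{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"**Sketch: split ordered by number of available structures.** The workflow above lists a second possible ordering metric (total structures available per sample). The cell below is a minimal sketch of that variant, reusing `count_list` and the same 100/100/remainder scheme as the RMSD-based split; the output filename `random_count_split.pt` is an assumption and may not match the filename used elsewhere in the codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumption): analogous split ordered by the number of structures per sequence,\n",
"# mirroring the RMSD-based split above. Variable and file names here are illustrative.\n",
"sorted_by_count = sorted(zip(range(len(data_list)), count_list), key=lambda x: x[1], reverse=True)\n",
"\n",
"test_idx_list, val_idx_list, train_idx_list = [], [], []\n",
"\n",
"for idx, num_structs in sorted_by_count:\n",
"\n",
"    seq_len = len(seq_list[idx])\n",
"\n",
"    if seq_len < 1000 and seq_len > 10:\n",
"\n",
"        # Test set: top 100 samples with the most structures\n",
"        if len(test_idx_list) < 100:\n",
"            test_idx_list.append(idx)\n",
"\n",
"        # Validation set: next 100 samples\n",
"        elif len(val_idx_list) < 100:\n",
"            val_idx_list.append(idx)\n",
"\n",
"        # Training set: all remaining samples\n",
"        else:\n",
"            train_idx_list.append(idx)\n",
"\n",
"    # Very large (> 1000 nts) or very small (< 10 nts) RNAs go straight to training\n",
"    else:\n",
"        train_idx_list.append(idx)\n",
"\n",
"assert len(test_idx_list) + len(val_idx_list) + len(train_idx_list) == len(data_list)\n",
"\n",
"# Assumed output path, by analogy with random_rmsd_split.pt above\n",
"torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/random_count_split.pt\")"
]
}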
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}