diff --git a/notebooks/data_splitting_random.ipynb b/notebooks/data_splitting_random.ipynb
new file mode 100644
index 0000000..c565094
--- /dev/null
+++ b/notebooks/data_splitting_random.ipynb
@@ -0,0 +1,171 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Data Split Creation\n",
+    "\n",
+    "This notebook creates data splits used to evaluate gRNAde on randomly split RNAs.\n",
+    "\n",
+    "**Workflow:**\n",
+    "1. Order the samples based on some metric:\n",
+    "    - Avg. RMSD among available structures\n",
+    "    - Total structures available\n",
+    "2. Training, validation, and test splits become progressively harder:\n",
+    "    - Top 100 samples with highest metric -- test set.\n",
+    "    - Next 100 samples with highest metric -- validation set.\n",
+    "    - All remaining samples -- training set.\n",
+    "    - Very large (> 1000 nts) or very small (< 10 nts) RNAs -- training set.\n",
+    "\n",
+    "Note that we exclude very large RNA samples (> 1000 nts) from the ordered split and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (< 10 nts)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "import sys\n",
+    "sys.path.append('../')\n",
+    "\n",
+    "import os\n",
+    "import subprocess\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset\n",
+    "import seaborn as sns\n",
+    "\n",
+    "from Bio import SeqIO\n",
+    "from Bio.Seq import Seq\n",
+    "from Bio.SeqRecord import SeqRecord\n",
+    "\n",
+    "from src.data_utils import get_avg_rmsds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load data list\n",
+    "data_list = torch.load(os.path.join(\"../data/\", \"processed.pt\"))\n",
+    "print(len(data_list))\n",
+    "\n",
+    "# List of sample sequences (used to create .fasta input file)\n",
+    "seq_list = []\n",
+    "for idx, data in enumerate(data_list):\n",
+    "    seq = data[\"seq\"]\n",
+    "    seq_list.append(SeqRecord(Seq(seq), id=str(idx))) # the ID for each sequence is its index in data_list\n",
+    "\n",
+    "# List of intra-sequence avg. RMSDs\n",
+    "rmsd_list = get_avg_rmsds(data_list)\n",
+    "\n",
+    "# List of number of structures per sequence\n",
+    "count_list = [len(data[\"coords_list\"]) for data in data_list]\n",
+    "\n",
+    "assert len(data_list) == len(seq_list) == len(rmsd_list) == len(count_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RMSD Split\n",
+    "\n",
+    "# Pair each sample index with its avg. RMSD\n",
+    "zipped = zip(list(range(len(data_list))), rmsd_list)\n",
+    "# Sort the zipped list based on the values (descending order, highest first)\n",
+    "sorted_zipped = sorted(zipped, key=lambda x: x[1], reverse=True)\n",
+    "# Unzip the sorted list back into two separate lists\n",
+    "sorted_data_list_idx, sorted_rmsd_list = zip(*sorted_zipped)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_idx_list = []\n",
+    "val_idx_list = []\n",
+    "train_idx_list = []\n",
+    "\n",
+    "for idx, avg_rmsd in sorted_zipped:\n",
+    "\n",
+    "    num_structs = count_list[idx]  # len(data_list[idx]['coords_list']); not used for the RMSD split\n",
+    "\n",
+    "    seq_len = len(seq_list[idx])\n",
+    "\n",
+    "    if 10 <= seq_len <= 1000:\n",
+    "\n",
+    "        # Test set\n",
+    "        if len(test_idx_list) < 100:\n",
+    "            test_idx_list.append(idx)\n",
+    "\n",
+    "        # Validation set\n",
+    "        elif len(val_idx_list) < 100:\n",
+    "            val_idx_list.append(idx)\n",
+    "\n",
+    "        # Training set\n",
+    "        else:\n",
+    "            train_idx_list.append(idx)\n",
+    "\n",
+    "    # Training set (very large or very small RNAs)\n",
+    "    else:\n",
+    "        train_idx_list.append(idx)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert len(test_idx_list) + len(val_idx_list) + len(train_idx_list) == len(data_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/random_rmsd_split.pt\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.16"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
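The cell-by-cell logic in `data_splitting_random.ipynb` above condenses to a single rank-and-assign routine. Below is a minimal, self-contained sketch of that procedure; the function name `greedy_split` and the toy inputs are illustrative, not part of the notebook:

```python
# Minimal sketch of the split logic above: rank samples by a difficulty
# metric (e.g. avg. RMSD), assign the top ranks to test, the next block
# to validation, and everything else (plus length outliers) to train.
def greedy_split(metric_list, seq_len_list, split_size=100,
                 min_len=10, max_len=1000):
    # Indices sorted by metric, highest (hardest) first
    order = sorted(range(len(metric_list)),
                   key=lambda i: metric_list[i], reverse=True)
    test, val, train = [], [], []
    for idx in order:
        if min_len <= seq_len_list[idx] <= max_len:
            if len(test) < split_size:
                test.append(idx)
            elif len(val) < split_size:
                val.append(idx)
            else:
                train.append(idx)
        else:
            # Very large / very small RNAs always go to training
            train.append(idx)
    return train, val, test

# Toy usage with made-up avg. RMSDs and sequence lengths
train, val, test = greedy_split([2.1, 0.3, 5.6, 1.0, 0.8],
                                [50, 8, 120, 2000, 30], split_size=2)
assert sorted(train + val + test) == list(range(5))
```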
diff --git a/notebooks/cluster_seq_identity.ipynb b/notebooks/data_splitting_seqid.ipynb
similarity index 54%
rename from notebooks/cluster_seq_identity.ipynb
rename to notebooks/data_splitting_seqid.ipynb
index 6e4e9de..fa90059 100644
--- a/notebooks/cluster_seq_identity.ipynb
+++ b/notebooks/data_splitting_seqid.ipynb
@@ -5,10 +5,26 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# Sequence Identity Split Creation\n",
+    "# Data Split Creation\n",
     "\n",
-    "This notebook creates the Sequence Identity split used to evaluate gRNAde on biologically dissimilar clusters of RNAs.\n",
-    "We cluster the sequences based on nucleotide similarity using CD-HIT (Fu et al., 2012) with an identity threshold of 80% to create training, validation and test sets."
+    "This notebook creates data splits used to evaluate gRNAde on biologically dissimilar clusters of RNAs.\n",
+    "\n",
+    "**Workflow:**\n",
+    "1. Cluster RNA sample sequences into groups based on:\n",
+    "    - Sequence identity -- CD-HIT (Fu et al., 2012) with identity threshold of 80% (the lowest currently possible).\n",
+    "    - Structural similarity -- US-align with similarity threshold 0.45 (TODO).\n",
+    "2. Order the clusters based on some metric:\n",
+    "    - Avg. of intra-sequence avg. RMSD among available structures\n",
+    "    - Avg. of intra-sequence number of structures available\n",
+    "3. Training, validation, and test splits become progressively harder:\n",
+    "    - Top 100 samples from clusters with highest metric -- test set.\n",
+    "    - Next 100 samples from clusters with highest metric -- validation set.\n",
+    "    - All remaining samples -- training set.\n",
+    "    - Clusters with 25 or more samples within them -- training set.\n",
+    "    - Very large (> 1000 nts) or very small (< 10 nts) RNAs -- training set.\n",
+    "4. If any samples were not assigned clusters, append them to the training set.\n",
+    "\n",
+    "Note that we separate very large RNA samples (> 1000 nts) from clustering and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (< 10 nts)."
    ]
   },
@@ -25,11 +41,19 @@
     "import os\n",
     "import subprocess\n",
     "import numpy as np\n",
+    "import pandas as pd\n",
     "import torch\n",
     "from tqdm import tqdm\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset\n",
+    "import seaborn as sns\n",
+    "\n",
     "from Bio import SeqIO\n",
     "from Bio.Seq import Seq\n",
-    "from Bio.SeqRecord import SeqRecord"
+    "from Bio.SeqRecord import SeqRecord\n",
+    "\n",
+    "from src.data_utils import get_avg_rmsds"
    ]
   },
@@ -38,7 +62,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def run_cd_hit_est(\n",
+    "def create_clusters_sequence_identity(\n",
     "    input_sequences, \n",
     "    identity_threshold = 0.9,\n",
     "    word_size = 2,\n",
@@ -75,9 +99,9 @@
     "            seq_idx_to_cluster[sequence_id] = current_cluster\n",
     "\n",
     "    # Delete temporary files\n",
-    "    # os.remove(input_file)\n",
-    "    # os.remove(output_file)\n",
-    "    # os.remove(output_file + \".clstr\")\n",
+    "    os.remove(input_file)\n",
+    "    os.remove(output_file)\n",
+    "    os.remove(output_file + \".clstr\")\n",
     "\n",
     "    return clustered_sequences, seq_idx_to_cluster"
    ]
   },
@@ -88,13 +112,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Load data list\n",
     "data_list = torch.load(os.path.join(\"../data/\", \"processed.pt\"))\n",
-    "seq_list = []\n",
+    "print(len(data_list))\n",
     "\n",
+    "# List of sample sequences (used to create .fasta input file)\n",
+    "seq_list = []\n",
     "for idx, data in enumerate(data_list):\n",
     "    seq = data[\"seq\"]\n",
     "    seq_list.append(SeqRecord(Seq(seq), id=str(idx))) # the ID for each sequence is its index in data_list\n",
-    "print(len(seq_list))"
+    "\n",
+    "# List of intra-sequence avg. RMSDs\n",
+    "rmsd_list = get_avg_rmsds(data_list)\n",
+    "\n",
+    "# List of number of structures per sequence\n",
+    "count_list = [len(data[\"coords_list\"]) for data in data_list]\n",
+    "\n",
+    "assert len(data_list) == len(seq_list) == len(rmsd_list) == len(count_list)"
    ]
   },
@@ -104,7 +138,7 @@
    "outputs": [],
    "source": [
     "# Cluster at 80% sequence identity (lowest currently possible)\n",
-    "clustered_sequences, seq_idx_to_cluster = run_cd_hit_est(seq_list, identity_threshold=0.8, word_size=3)"
+    "clustered_sequences, seq_idx_to_cluster = create_clusters_sequence_identity(seq_list, identity_threshold=0.8, word_size=3)"
    ]
   },
@@ -123,7 +157,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Sanity check: it seems short sequences are not being clustered\n",
+    "# Sanity check: it seems very short sequences (< 10 nts) are not being clustered.\n",
+    "# These will be added to the training set after initial splitting.\n",
     "try:\n",
     "    # Why does this fail? Guess: sequences are too short?\n",
     "    assert len(seq_idx_to_cluster.keys()) == len(seq_list)\n",
@@ -146,11 +181,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Cluster sizes: number of sequences in each cluster\n",
-    "cluster_ids, cluster_sizes = np.unique(list(seq_idx_to_cluster.values()), return_counts=True)\n",
-    "for id, size in zip(cluster_ids[:10], cluster_sizes[:10]):\n",
-    "    print(id, size)\n",
-    "# Print some examples"
+    "# seq_idx_to_cluster: (index in data_list: cluster ID)\n",
+    "# (NEW) cluster_to_seq_idx_list: (cluster ID: list of indices in data_list)\n",
+    "cluster_to_seq_idx_list = {}\n",
+    "for seq_idx, cluster in seq_idx_to_cluster.items():\n",
+    "    # Sanity check to filter very large or very small RNAs\n",
+    "    if (len(seq_list[seq_idx]) > 1000 or len(seq_list[seq_idx]) < 10) and seq_idx not in idx_not_clustered:\n",
+    "        idx_not_clustered.append(seq_idx)\n",
+    "        # print(f\"Pruned idx {seq_idx} of length {len(seq_list[seq_idx])}.\")\n",
+    "    else:\n",
+    "        if cluster in cluster_to_seq_idx_list:\n",
+    "            cluster_to_seq_idx_list[cluster].append(seq_idx)\n",
+    "        else:\n",
+    "            cluster_to_seq_idx_list[cluster] = [seq_idx]\n",
+    "print(\"Number of unassigned indices (not clustered + too large + too small): \", len(idx_not_clustered))"
    ]
   },
@@ -159,14 +203,37 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# seq_idx_to_cluster: (index in data_list: cluster ID)\n",
-    "# (NEW) cluster_to_seq_idx_list: (cluster ID: list of indices in data_list)\n",
-    "cluster_to_seq_idx_list = {}\n",
-    "for seq_idx, cluster in seq_idx_to_cluster.items():\n",
-    "    if cluster in cluster_to_seq_idx_list.keys():\n",
-    "        cluster_to_seq_idx_list[cluster].append(seq_idx)\n",
-    "    else:\n",
-    "        cluster_to_seq_idx_list[cluster] = [seq_idx]"
+    "# Cluster sizes: number of sequences in each cluster\n",
+    "cluster_ids = list(cluster_to_seq_idx_list.keys())\n",
+    "cluster_sizes = [len(idx_list) for idx_list in cluster_to_seq_idx_list.values()]\n",
+    "\n",
+    "# Number of structures in each cluster (total and intra-sequence avg.)\n",
+    "total_structs_list = []\n",
+    "avg_structs_list = []\n",
+    "avg_rmsds_list = []\n",
+    "avg_seq_len_list = []\n",
+    "for cluster, seq_idx_list in cluster_to_seq_idx_list.items():\n",
+    "    count = []\n",
+    "    rmsds = []\n",
+    "    lens = []\n",
+    "    for seq_idx in seq_idx_list:\n",
+    "        count.append(count_list[seq_idx])\n",
+    "        rmsds.append(rmsd_list[seq_idx])\n",
+    "        lens.append(len(seq_list[seq_idx]))\n",
+    "    total_structs_list.append(np.sum(count))\n",
+    "    avg_structs_list.append(np.mean(count))\n",
+    "    avg_rmsds_list.append(np.mean(rmsds))\n",
+    "    avg_seq_len_list.append(np.mean(lens))\n",
+    "\n",
+    "df = pd.DataFrame({\n",
+    "    'Cluster ID': cluster_ids,\n",
+    "    'Cluster size': cluster_sizes,\n",
+    "    'Total no. structures': total_structs_list,\n",
+    "    'Avg. sequence length': avg_seq_len_list,\n",
+    "    'Avg. intra-sequence no. structures': avg_structs_list,\n",
+    "    'Avg. intra-sequence avg. RMSD': avg_rmsds_list,\n",
+    "})\n",
+    "df"
    ]
   },
@@ -175,19 +242,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Cluster sizes: number of structures (total) in each cluster\n",
-    "cluster_sizes_structs = []\n",
-    "for cluster, seq_idx_list in cluster_to_seq_idx_list.items():\n",
-    "    count = 0\n",
-    "    for seq_idx in seq_idx_list:\n",
-    "        count += len(data_list[seq_idx]['coords_list'])\n",
-    "    cluster_sizes_structs.append(count)\n",
     "\n",
-    "# Cluster sequence size and structure size\n",
-    "print(\"cluster ID, # sequences, total # structures\")\n",
-    "for id, size, size_structs in zip(cluster_ids[:10], cluster_sizes[:10], cluster_sizes_structs[:10]):\n",
-    "    print(id, size, size_structs)\n",
-    "# Print some examples"
+    "# RMSD Split\n",
+    "\n",
+    "# Zip the two lists together\n",
+    "zipped = zip(cluster_ids, avg_rmsds_list)\n",
+    "# Sort the zipped list based on the values (descending order, highest first)\n",
+    "sorted_zipped = sorted(zipped, key=lambda x: x[1], reverse=True)\n",
+    "# Unzip the sorted list back into two separate lists\n",
+    "sorted_cluster_ids, sorted_avg_rmsds_list = zip(*sorted_zipped)"
    ]
   },
@@ -200,17 +262,19 @@
    "source": [
     "test_idx_list = []\n",
     "val_idx_list = []\n",
     "train_idx_list = []\n",
     "\n",
-    "# Some heuristics\n",
-    "# * Add samples to validation and test sets till their sizes are filled (200 samples), after which add everything to the train set\n",
-    "# * Do not add very large seqeuence clusters (sizes > 100) to validation or test set\n",
-    "# \n",
+    "for cluster in sorted_cluster_ids:\n",
+    "    seq_idx_list = cluster_to_seq_idx_list[cluster]\n",
+    "    cluster_size = len(seq_idx_list)\n",
     "\n",
-    "for cluster, seq_idx_list in cluster_to_seq_idx_list.items():\n",
-    "    \n",
-    "    if len(test_idx_list) < 200 and cluster_sizes[cluster] < 100:\n",
+    "    # Test set\n",
+    "    if len(test_idx_list) < 100 and cluster_size < 25:\n",
     "        test_idx_list += seq_idx_list\n",
-    "    elif len(val_idx_list) < 200 and cluster_sizes[cluster] < 100:\n",
+    "\n",
+    "    # Validation set\n",
+    "    elif len(val_idx_list) < 100 and cluster_size < 25:\n",
     "        val_idx_list += seq_idx_list\n",
+    "\n",
+    "    # Training set\n",
     "    else:\n",
     "        train_idx_list += seq_idx_list"
    ]
   },
@@ -235,7 +299,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/seq_identity_split.pt\")"
+    "torch.save((train_idx_list, val_idx_list, test_idx_list), \"../data/seqid_rmsd_split.pt\")"
    ]
   }
  ],
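The diff above only shows the rename of `run_cd_hit_est` to `create_clusters_sequence_identity` and the temp-file cleanup, not the helper's full body. For context, here is a rough sketch of the FASTA → `cd-hit-est` → `.clstr` round trip such a helper typically performs; everything beyond the standard `-i/-o/-c/-n` flags and the `.clstr` line format is an assumption, not the notebook's actual implementation:

```python
import subprocess
from Bio import SeqIO

def cluster_sequences(records, identity_threshold=0.8, word_size=3,
                      input_file="temp.fasta", output_file="temp_cdhit"):
    # Write SeqRecords to FASTA; record IDs are data_list indices
    SeqIO.write(records, input_file, "fasta")
    # -c = identity threshold, -n = word size (low thresholds need small -n)
    subprocess.run(
        ["cd-hit-est", "-i", input_file, "-o", output_file,
         "-c", str(identity_threshold), "-n", str(word_size)],
        check=True,
    )
    # The .clstr report lists ">Cluster N" headers followed by member lines,
    # e.g. "0\t141nt, >37... *" where 37 is the record ID
    seq_idx_to_cluster = {}
    current_cluster = None
    with open(output_file + ".clstr") as f:
        for line in f:
            if line.startswith(">Cluster"):
                current_cluster = int(line.split()[-1])
            elif line.strip():
                seq_id = int(line.split(">")[1].split("...")[0])
                seq_idx_to_cluster[seq_id] = current_cluster
    return seq_idx_to_cluster
```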
"plt.savefig('hist_rmsd_per_sequence.pdf', dpi=300)\n", + "# plt.savefig('hist_rmsd_per_sequence.pdf', dpi=300)\n", "plt.show()" ] }, @@ -529,7 +529,7 @@ "# )\n", "\n", "# Display the plot\n", - "plt.savefig('bivariate_seq_vs_rmsd.pdf', dpi=300)\n", + "# plt.savefig('bivariate_seq_vs_rmsd.pdf', dpi=300)\n", "plt.show()\n" ] }, @@ -545,6 +545,13 @@ "# show the plot\n", "plt.show()\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {