diff --git a/docs/tutorials/mouse_biccn.ipynb b/docs/tutorials/mouse_biccn.ipynb index 786a41d3..986d5c8e 100644 --- a/docs/tutorials/mouse_biccn.ipynb +++ b/docs/tutorials/mouse_biccn.ipynb @@ -18,11 +18,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-06-25 17:33:01.205298: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-06-25 17:33:01.244866: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-06-26 14:46:47.842156: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-26 14:46:47.881644: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2024-06-25 17:33:02.721837: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", - "2024-06-25 17:33:04.704132: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78790 MB memory: -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:55:00.0, compute capability: 9.0\n" + "2024-06-26 14:46:49.415214: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + 
"2024-06-26 14:46:51.642221: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78790 MB memory: -> device: 0, name: NVIDIA H100 80GB HBM3, pci bus id: 0000:55:00.0, compute capability: 9.0\n" ] } ], @@ -30,13 +30,16 @@ "import sys\n", "sys.path.insert(0, '/home/VIB.LOCAL/niklas.kempynck/.conda/envs/crested/lib/python3.11/site-packages')\n", "sys.path.insert(0,'/data/projects/c04/cbd-saerts/nkemp/software/CREsted/src')\n", - "#sys.path.remove('/mnt/modules/easybuild/software/SciPy-bundle/2023.07-gfbf-2023a/lib/python3.11/site-packages')\n", - "import crested\n" + "sys.path.remove('/mnt/modules/easybuild/software/SciPy-bundle/2023.07-gfbf-2023a/lib/python3.11/site-packages')\n", + "import crested\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "%matplotlib inline" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -45,7 +48,7 @@ "" ] }, - "execution_count": 19, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -64,7 +67,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "2024-06-25T17:33:11.288334+0200 INFO Extracting values from 19 bigWig files...\n" + "2024-06-26T14:46:57.920142+0200 INFO Extracting values from 19 bigWig files...\n" ] }, { @@ -99,9 +102,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "2024-06-25T17:33:38.278601+0200 INFO Filtering on top k Gini scores...\n", - "2024-06-25T17:33:44.150386+0200 INFO Added normalization weights to adata.obsm['weights']...\n", - "2024-06-25T17:33:52.507340+0200 INFO After specificity filtering, kept 86887 out of 546993 regions.\n" + "2024-06-26T14:47:26.453591+0200 INFO Filtering on top k Gini scores...\n", + "2024-06-26T14:47:32.314144+0200 INFO Added normalization weights to adata.obsm['weights']...\n", + "2024-06-26T14:47:40.624157+0200 INFO After specificity filtering, kept 86887 out of 546993 regions.\n" ] } ], @@ 
-119,7 +122,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Mon Jun 24 17:36:30 2024 \n", + "Wed Jun 26 11:35:03 2024 \n", "+-----------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 555.42.02 Driver Version: 555.42.02 CUDA Version: 12.5 |\n", "|-----------------------------------------+------------------------+----------------------+\n", @@ -137,8 +140,7 @@ "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=========================================================================================|\n", - "| 0 N/A N/A 539601 C ...ynck/.conda/envs/crested/bin/python 72556MiB |\n", - "| 0 N/A N/A 542471 C ...ynck/.conda/envs/crested/bin/python 6746MiB |\n", + "| 0 N/A N/A 559712 C ...ynck/.conda/envs/crested/bin/python 79310MiB |\n", "+-----------------------------------------------------------------------------------------+\n" ] } @@ -167,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -308,7 +310,7 @@ "[86887 rows x 4 columns]" ] }, - "execution_count": 6, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -330,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -460,7 +462,7 @@ "[7705 rows x 4 columns]" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -501,14 +503,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2024-06-25T17:33:52.563735+0200 WARNING Chromsizes file not provided when shifting. Will not check if shifted regions are within chromosomes\n" + "2024-06-26T14:47:40.728963+0200 WARNING Chromsizes file not provided when shifting. 
Will not check if shifted regions are within chromosomes\n" ] } ], @@ -526,27 +528,7 @@ }, { "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AnnDataLoader(dataset=AnnDataset(anndata_shape=(19, 440993), n_samples=881986, num_outputs=19, split=train, in_memory=False), batch_size=256, shuffle=True, one_hot_encode=True, drop_remainder=True)" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "datamodule.train_dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -568,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -577,7 +559,7 @@ "" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -603,14 +585,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TaskConfig(optimizer=, loss=, metrics=[, , , , , , ])\n" + "TaskConfig(optimizer=, loss=, metrics=[, , , , , , , ])\n" ] } ], @@ -618,7 +600,7 @@ "# Load the default configuration for training a topic classication model\n", "from crested.tl import default_configs, TaskConfig\n", "\n", - "config = default_configs(\"peak_regression\")\n", + "config = default_configs(\"peak_regression\", num_classes=19)\n", "print(config)\n", "\n", "# If you want to change some small parameters to an existing config, you can do it like this\n", @@ -647,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -658,7 +640,7 @@ " config=config,\n", " project_name=\"deeppeak_benchmarking\",\n", " logger=\"wandb\",\n", - " run_name='crested_norm_TL'\n", + " run_name='crested_norm_TL_spearman'\n", ")\n" ] }, @@ -671,7 
+653,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Fri Jun 21 11:05:50 2024 \n", + "Wed Jun 26 11:18:17 2024 \n", "+-----------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 555.42.02 Driver Version: 555.42.02 CUDA Version: 12.5 |\n", "|-----------------------------------------+------------------------+----------------------+\n", @@ -680,7 +662,7 @@ "| | | MIG M. |\n", "|=========================================+========================+======================|\n", "| 0 NVIDIA H100 80GB HBM3 On | 00000000:55:00.0 Off | 0 |\n", - "| N/A 48C P0 123W / 700W | 79323MiB / 81559MiB | 0% Default |\n", + "| N/A 42C P0 118W / 700W | 79323MiB / 81559MiB | 0% Default |\n", "| | | Disabled |\n", "+-----------------------------------------+------------------------+----------------------+\n", " \n", @@ -689,7 +671,7 @@ "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=========================================================================================|\n", - "| 0 N/A N/A 516332 C ...ynck/.conda/envs/crested/bin/python 70340MiB |\n", + "| 0 N/A N/A 558471 C ...ynck/.conda/envs/crested/bin/python 79312MiB |\n", "+-----------------------------------------------------------------------------------------+\n" ] } @@ -700,9 +682,545 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mkemp\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.17.3 is available! 
To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /data/projects/c04/cbd-saerts/nkemp/software/CREsted/docs/tutorials/wandb/run-20240626_113547-70xblei4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run crested_norm_TL_spearman to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/kemp/deeppeak_benchmarking" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/kemp/deeppeak_benchmarking/runs/70xblei4" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"functional_1\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"functional_1\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)         Output Shape          Param #  Connected to      ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ sequence            │ (None, 2114, 4)   │          0 │ -                 │\n",
+       "│ (InputLayer)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ conv1d (Conv1D)     │ (None, 2114, 512) │     10,240 │ sequence[0][0]    │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ batch_normalization │ (None, 2114, 512) │      2,048 │ conv1d[0][0]      │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ activation          │ (None, 2114, 512) │          0 │ batch_normalizat… │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ dropout (Dropout)   │ (None, 2114, 512) │          0 │ activation[0][0]  │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_1conv         │ (None, 2110, 512) │    786,432 │ dropout[0][0]     │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_1bn           │ (None, 2110, 512) │      2,048 │ bpnet_1conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_1activation   │ (None, 2110, 512) │          0 │ bpnet_1bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_1crop         │ (None, 2110, 512) │          0 │ dropout[0][0]     │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add (Add)           │ (None, 2110, 512) │          0 │ bpnet_1activatio… │\n",
+       "│                     │                   │            │ bpnet_1crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_1dropout      │ (None, 2110, 512) │          0 │ add[0][0]         │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_2conv         │ (None, 2102, 512) │    786,432 │ bpnet_1dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_2bn           │ (None, 2102, 512) │      2,048 │ bpnet_2conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_2activation   │ (None, 2102, 512) │          0 │ bpnet_2bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_2crop         │ (None, 2102, 512) │          0 │ bpnet_1dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_1 (Add)         │ (None, 2102, 512) │          0 │ bpnet_2activatio… │\n",
+       "│                     │                   │            │ bpnet_2crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_2dropout      │ (None, 2102, 512) │          0 │ add_1[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_3conv         │ (None, 2086, 512) │    786,432 │ bpnet_2dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_3bn           │ (None, 2086, 512) │      2,048 │ bpnet_3conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_3activation   │ (None, 2086, 512) │          0 │ bpnet_3bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_3crop         │ (None, 2086, 512) │          0 │ bpnet_2dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_2 (Add)         │ (None, 2086, 512) │          0 │ bpnet_3activatio… │\n",
+       "│                     │                   │            │ bpnet_3crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_3dropout      │ (None, 2086, 512) │          0 │ add_2[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_4conv         │ (None, 2054, 512) │    786,432 │ bpnet_3dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_4bn           │ (None, 2054, 512) │      2,048 │ bpnet_4conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_4activation   │ (None, 2054, 512) │          0 │ bpnet_4bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_4crop         │ (None, 2054, 512) │          0 │ bpnet_3dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_3 (Add)         │ (None, 2054, 512) │          0 │ bpnet_4activatio… │\n",
+       "│                     │                   │            │ bpnet_4crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_4dropout      │ (None, 2054, 512) │          0 │ add_3[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_5conv         │ (None, 1990, 512) │    786,432 │ bpnet_4dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_5bn           │ (None, 1990, 512) │      2,048 │ bpnet_5conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_5activation   │ (None, 1990, 512) │          0 │ bpnet_5bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_5crop         │ (None, 1990, 512) │          0 │ bpnet_4dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_4 (Add)         │ (None, 1990, 512) │          0 │ bpnet_5activatio… │\n",
+       "│                     │                   │            │ bpnet_5crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_5dropout      │ (None, 1990, 512) │          0 │ add_4[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_6conv         │ (None, 1862, 512) │    786,432 │ bpnet_5dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_6bn           │ (None, 1862, 512) │      2,048 │ bpnet_6conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_6activation   │ (None, 1862, 512) │          0 │ bpnet_6bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_6crop         │ (None, 1862, 512) │          0 │ bpnet_5dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_5 (Add)         │ (None, 1862, 512) │          0 │ bpnet_6activatio… │\n",
+       "│                     │                   │            │ bpnet_6crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_6dropout      │ (None, 1862, 512) │          0 │ add_5[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_7conv         │ (None, 1606, 512) │    786,432 │ bpnet_6dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_7bn           │ (None, 1606, 512) │      2,048 │ bpnet_7conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_7activation   │ (None, 1606, 512) │          0 │ bpnet_7bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_7crop         │ (None, 1606, 512) │          0 │ bpnet_6dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_6 (Add)         │ (None, 1606, 512) │          0 │ bpnet_7activatio… │\n",
+       "│                     │                   │            │ bpnet_7crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_7dropout      │ (None, 1606, 512) │          0 │ add_6[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_8conv         │ (None, 1094, 512) │    786,432 │ bpnet_7dropout[0… │\n",
+       "│ (Conv1D)            │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_8bn           │ (None, 1094, 512) │      2,048 │ bpnet_8conv[0][0] │\n",
+       "│ (BatchNormalizatio… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_8activation   │ (None, 1094, 512) │          0 │ bpnet_8bn[0][0]   │\n",
+       "│ (Activation)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_8crop         │ (None, 1094, 512) │          0 │ bpnet_7dropout[0… │\n",
+       "│ (Cropping1D)        │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ add_7 (Add)         │ (None, 1094, 512) │          0 │ bpnet_8activatio… │\n",
+       "│                     │                   │            │ bpnet_8crop[0][0] │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ bpnet_8dropout      │ (None, 1094, 512) │          0 │ add_7[0][0]       │\n",
+       "│ (Dropout)           │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ global_average_poo… │ (None, 512)       │          0 │ bpnet_8dropout[0… │\n",
+       "│ (GlobalAveragePool… │                   │            │                   │\n",
+       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
+       "│ dense (Dense)       │ (None, 19)        │      9,747 │ global_average_p… │\n",
+       "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", + "│ sequence │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2114\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2114\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m10,240\u001b[0m │ sequence[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ batch_normalization │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2114\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2114\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ batch_normalizat… │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2114\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ activation[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + 
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_1conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ dropout[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_1bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_1conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_1activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_1bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_1crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dropout[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_1activatio… │\n", + "│ │ │ │ bpnet_1crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_1dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2110\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ 
add[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_2conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_1dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_2bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_2conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_2activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_2bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_2crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_1dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_1 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_2activatio… │\n", + "│ │ │ │ bpnet_2crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_2dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2102\u001b[0m, 
\u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_3conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_2dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_3bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_3conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_3activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_3bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_3crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_2dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_2 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_3activatio… │\n", + "│ │ │ │ bpnet_3crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_3dropout │ 
(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2086\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_4conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_3dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_4bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_4conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_4activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_4bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_4crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_3dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_3 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_4activatio… │\n", + "│ │ │ │ bpnet_4crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + 
"├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_4dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2054\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_5conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_4dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_5bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_5conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_5activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_5bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_5crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_4dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_4 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_5activatio… │\n", + "│ │ │ │ 
bpnet_5crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_5dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1990\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_6conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_5dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_6bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_6conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_6activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_6bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_6crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_5dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_5 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ 
\u001b[38;5;34m0\u001b[0m │ bpnet_6activatio… │\n", + "│ │ │ │ bpnet_6crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_6dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1862\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_7conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_6dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_7bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_7conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_7activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_7bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_7crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_6dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_6 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_7activatio… │\n", + "│ │ │ │ bpnet_7crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_7dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1606\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_8conv │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m786,432\u001b[0m │ bpnet_7dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mConv1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_8bn │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m2,048\u001b[0m │ bpnet_8conv[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_8activation │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_8bn[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_8crop │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_7dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mCropping1D\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ add_7 
(\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_8activatio… │\n", + "│ │ │ │ bpnet_8crop[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ bpnet_8dropout │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1094\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mDropout\u001b[0m) │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ global_average_poo… │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bpnet_8dropout[\u001b[38;5;34m0\u001b[0m… │\n", + "│ (\u001b[38;5;33mGlobalAveragePool…\u001b[0m │ │ │ │\n", + "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m19\u001b[0m) │ \u001b[38;5;34m9,747\u001b[0m │ global_average_p… │\n", + "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 6,329,875 (24.15 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m6,329,875\u001b[0m (24.15 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 6,320,659 (24.11 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m6,320,659\u001b[0m (24.11 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 9,216 (36.00 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m9,216\u001b[0m (36.00 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n", + "2024-06-26T11:35:49.840921+0200 INFO Number of GPUs in use: 1\n", + "Epoch 1/50\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.030 MB of 0.030 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run crested_norm_TL_spearman at: https://wandb.ai/kemp/deeppeak_benchmarking/runs/70xblei4
View project at: https://wandb.ai/kemp/deeppeak_benchmarking
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240626_113547-70xblei4/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "AttributeError", + "evalue": "'SymbolicTensor' object has no attribute 'assign_add'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# Optionally, configure Python's logging module to suppress INFO logs\u001b[39;00m\n\u001b[1;32m 11\u001b[0m logging\u001b[38;5;241m.\u001b[39mgetLogger(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtensorflow\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39msetLevel(logging\u001b[38;5;241m.\u001b[39mWARNING)\n\u001b[0;32m---> 13\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mearly_stopping_patience\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m6\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmixed_precision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate_reduce\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate_reduce_patience\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File 
\u001b[0;32m/data/projects/c04/cbd-saerts/nkemp/software/CREsted/src/crested/tl/_crested.py:314\u001b[0m, in \u001b[0;36mCrested.fit\u001b[0;34m(self, epochs, mixed_precision, model_checkpointing, model_checkpointing_best_only, early_stopping, early_stopping_patience, learning_rate_reduce, learning_rate_reduce_patience, custom_callbacks)\u001b[0m\n\u001b[1;32m 290\u001b[0m run\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 291\u001b[0m {\n\u001b[1;32m 292\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mepochs\u001b[39m\u001b[38;5;124m\"\u001b[39m: epochs,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 310\u001b[0m }\n\u001b[1;32m 311\u001b[0m )\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 314\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mval_loader\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 317\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_train_steps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_val_steps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m:\n\u001b[1;32m 323\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining interrupted by user.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.conda/envs/crested/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n", + "File \u001b[0;32m/data/projects/c04/cbd-saerts/nkemp/software/CREsted/src/crested/tl/metrics/_spearmancorr.py:34\u001b[0m, in \u001b[0;36mSpearmanCorrelationPerClass.update_state\u001b[0;34m(self, y_true, y_pred, sample_weight)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;241m0.0\u001b[39m\n\u001b[1;32m 33\u001b[0m correlation \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mcond(proceed, compute, skip)\n\u001b[0;32m---> 34\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcorrelation_sums\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massign_add\u001b[49m(correlation)\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_counts[i]\u001b[38;5;241m.\u001b[39massign_add(tf\u001b[38;5;241m.\u001b[39mcast(proceed, tf\u001b[38;5;241m.\u001b[39mfloat32))\n", + "\u001b[0;31mAttributeError\u001b[0m: 'SymbolicTensor' object has no attribute 'assign_add'" + ] + } + ], "source": [ "# train the model\n", "\n", @@ -734,27 +1252,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [], - "source": [ - "classes = ['Astro','Endo','L2_3IT','L5ET','L5IT','L5_6NP','L6CT','L6IT','L6b','Lamp5','Micro_PVM','OPC','Oligo','Pvalb','Sncg','Sst','Sst-Chodl','VLMC','Vip']" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. 
To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -779,7 +1279,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -787,7 +1287,7 @@ "\n", "# load an existing model\n", "evaluator.load_model(\n", - " \"deeppeak_benchmarking/crested_norm_TL/checkpoints/05.keras\", compile=True\n", + " \"deeppeak_benchmarking/crested_norm_TL/checkpoints/05.keras\", compile=False\n", ")" ] }, @@ -849,6 +1349,36 @@ "evaluator.test()" ] }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['chr1:3093998-3096112', 'chr1:3094663-3096777', 'chr1:3111367-3113481',\n", + " 'chr1:3112727-3114841', 'chr1:3133779-3135893', 'chr1:3164901-3167015',\n", + " 'chr1:3166116-3168230', 'chr1:3180288-3182402', 'chr1:3209410-3211524',\n", + " 'chr1:3210016-3212130',\n", + " ...\n", + " 'chrX:169716892-169719006', 'chrX:169812832-169814946',\n", + " 'chrX:169824148-169826262', 'chrX:169826845-169828959',\n", + " 'chrX:169837427-169839541', 'chrX:169838199-169840313',\n", + " 'chrX:169843232-169845346', 'chrX:169862011-169864125',\n", + " 'chrX:169924670-169926784', 'chrX:169947743-169949857'],\n", + " dtype='object', length=86887)" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adata.var_names" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -864,127 +1394,137 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 55, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1719406062.105265 563997 service.cc:145] XLA service 0x7f6b8400c560 
initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", + "I0000 00:00:1719406062.105293 563997 service.cc:153] StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0\n", + "2024-06-26 14:47:42.120878: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", + "2024-06-26 14:47:42.210696: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Converted call: .flat_map_fn at 0x7fdc3f2b0220>\n", - " args: (,)\n", - " kwargs: {}\n", - "\n", - "Converted call: .get_iterator_id_fn at 0x7fdc3f2b3060>\n", - " args: (,)\n", - " kwargs: {}\n", - "\n", - "Converted call: .generator_next_fn at 0x7fdc3f2b0180>\n", - " args: ( dtype=int64>,)\n", - " kwargs: {}\n", - "\n", - "Converted call: .finalize_fn at 0x7fdc3f2b02c0>\n", - " args: ( dtype=int64>,)\n", - " kwargs: {}\n", - "\n", - "Converted call: . 
at 0x7fdc3f2b0360>\n", - " args: (, )\n", - " kwargs: {}\n", - "\n", - "Converted call: \n", - " args: (, )\n", - " kwargs: None\n", - "\n", - "Converted call: >\n", - " args: (, )\n", - " kwargs: {}\n", - "\n", - "Converted call: \n", - " args: (, )\n", - " kwargs: None\n", - "\n", - "Converted call: \n", - " args: (, )\n", - " kwargs: None\n", - "\n", - "Converted call: \n", - " args: (, 0)\n", - " kwargs: None\n", - "\n", - "Converted call: \n", - " args: ()\n", - " kwargs: {'shape': (2114, 4), 'dtype': tf.float32}\n", - "\n", - "Converted call: \n", - " args: (.inner_factory..tf___map_one_hot_encode..one_hot_encode at 0x7fdc714444a0>, )\n", - " kwargs: {'fn_output_signature': TensorSpec(shape=(2114, 4), dtype=tf.float32, name=None)}\n", - "\n", - "Converted call: \n", - " args: (, 'UTF-8')\n", - " kwargs: None\n", - "\n", - "Converted call: >\n", - " args: (,)\n", - " kwargs: None\n", - "\n", - "Converted call: \n", - " args: (,)\n", - " kwargs: {'depth': 4}\n", - "\n", - "Converted call: \n", - " args: (,)\n", - " kwargs: {'axis': 0}\n", - "\n", - "Converted call: .one_step_on_data_distributed at 0x7fdc71447740>\n", - " args: ([(, )],)\n", - " kwargs: {}\n", - "\n", - "Converted call: .one_step_on_data at 0x7fdc71447560>\n", - " args: ((, ),)\n", - " kwargs: {}\n", - "\n", - "\u001b[1m1385/1389\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 14ms/stepConverted call: .one_step_on_data_distributed at 0x7fdc71447740>\n", - " args: ([(, )],)\n", - " kwargs: {}\n", - "\n", - "Converted call: .one_step_on_data at 0x7fdc71447560>\n", - " args: ((, ),)\n", - " kwargs: {}\n", - "\n", - "\u001b[1m1389/1389\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 14ms/step\n", - "2024-06-21T12:54:06.955339+0200 INFO Adding predictions to anndata.layers[crested1].\n" + "\u001b[1m 13/1358\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m18s\u001b[0m 14ms/step" ] }, { - "ename": 
"AttributeError", - "evalue": "'numpy.ndarray' object has no attribute 'layers'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[55], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# add predictions for model 1 to the adata\u001b[39;00m\n\u001b[1;32m 2\u001b[0m evaluator\u001b[38;5;241m.\u001b[39mload_model(\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdeeppeak_benchmarking/crested_norm_TL/checkpoints/05.keras\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4\u001b[0m )\n\u001b[0;32m----> 5\u001b[0m adata \u001b[38;5;241m=\u001b[39m \u001b[43mevaluator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcrested1\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 7\u001b[0m \u001b[43m)\u001b[49m \u001b[38;5;66;03m# adds the predictions to the adata.layers[\"model_1\"]\u001b[39;00m\n", - "File \u001b[0;32m/data/projects/c04/cbd-saerts/nkemp/software/CREsted/src/crested/tl/_crested.py:384\u001b[0m, in \u001b[0;36mCrested.predict\u001b[0;34m(self, anndata, model_name)\u001b[0m\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m anndata \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m model_name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 383\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAdding predictions to 
anndata.layers[\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m].\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 384\u001b[0m \u001b[43manndata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m[model_name] \u001b[38;5;241m=\u001b[39m predictions\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 386\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m predictions\n", - "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'layers'" + "name": "stderr", + "output_type": "stream", + "text": [ + "I0000 00:00:1719406065.390897 563997 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1357/1358\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 14ms/step" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1719406084.884494 564227 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_371', 64 bytes spill stores, 64 bytes spill loads\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1358/1358\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 16ms/step\n", + "2024-06-26T14:48:07.117265+0200 INFO Adding predictions to anndata.layers[crested1].\n" ] } ], "source": [ "# add predictions for model 1 to the adata\n", "evaluator.load_model(\n", - " \"deeppeak_benchmarking/crested_norm_TL/checkpoints/05.keras\"\n", + " \"deeppeak_benchmarking/crested_norm_TL/checkpoints/05.keras\", compile=False\n", ")\n", - "adata = evaluator.predict(\n", + "adata_p = evaluator.predict(\n", " adata, model_name=\"crested1\"\n", ") # adds the predictions to the 
adata.layers[\"model_1\"]" ] }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 16ms/step\n", + "2024-06-26T14:58:27.262228+0200 INFO Plotting bar plots for region: chr18:4369383-4371497, models: ['crested1']\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "idx = 32\n", + "chrom=test_df.iloc[idx]['chr']\n", + "start=test_df.iloc[idx]['start']\n", + "end=test_df.iloc[idx]['end']\n", + "region = chrom+':'+str(start)+'-'+str(end)\n", + "region\n", + "\n", + "pred = evaluator.predict_regions(region)\n", + "crested.pl.bar.region_predictions(evaluator.anndatamodule.adata, region)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 17ms/step\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sequence_loader = crested.tl.data.SequenceLoader(genome_file='/home/VIB.LOCAL/niklas.kempynck/nkemp/software/dev_DeepPeak/DeepPeak/data/raw/genome.fa', chromsizes='/home/VIB.LOCAL/niklas.kempynck/nkemp/mouse/biccn/mm.chrom.sizes')\n", + "seq = sequence_loader.get_sequence(\"chr18:61107668-61109782\")\n", + "plt.figure(figsize=(25,3))\n", + "plt.bar(list(adata.obs_names),evaluator.predict_sequence(seq)[0])\n", + "plt.show()\n", + "#crested.pl.bar.region(adata_p, [\"chr18:61107668-61109782\"])" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1003,7 +1543,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -1012,37 +1552,50 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 35, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['L5ET']\n", + "2024-06-26T14:58:31.495910+0200 INFO Calculating contribution scores for 1 class(es) and 1 region(s).\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Region: 100%|██████████| 1/1 [00:02<00:00, 2.27s/it]\n" + ] + } + ], "source": [ - "# focus on two topics of interest\n", - "scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores(\n", - " [\"chr18:61107668-61109782\"], class_indices=[10], method='expected_integrated_grad'\n", - ")\n", - "\n", - "# calculate the contribution scores for two regions for all topics\n", - "# scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores(['chr1:1000-1500', 'chr1:2000-2500'], class_indices=range(len(adata.obs)))" + "cts =['L5ET']\n", + "scores, one_hot_encoded_sequences = evaluator.calculate_contribution_scores_regions(\n", + " region_idx = region, method='mutagenesis', class_names=cts\n", + ")" ] }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 
20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Crested(data=True, model=True, config=False)" + "0.650580644607544" ] }, - "execution_count": 112, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "evaluator" + "scores.max()" ] }, { @@ -1054,24 +1607,30 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 38, "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'numpy.ndarray' object has no attribute 'obs_names'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[113], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m crested\u001b[38;5;241m.\u001b[39mpl\u001b[38;5;241m.\u001b[39mcontribution_scores(\n\u001b[0;32m----> 2\u001b[0m scores, one_hot_encoded_sequences, class_indices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m10\u001b[39m],class_names\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlist\u001b[39m(\u001b[43madata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs_names\u001b[49m)[\u001b[38;5;241m10\u001b[39m], zoom_n_bases\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m500\u001b[39m\n\u001b[1;32m 3\u001b[0m )\n", - "\u001b[0;31mAttributeError\u001b[0m: 'numpy.ndarray' object has no attribute 'obs_names'" + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-06-26T14:59:13.091642+0200 INFO Plotting contribution scores for 1 sequence(s)\n" ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "crested.pl.contribution_scores(\n", - " scores, one_hot_encoded_sequences, class_indices=[10],class_names=list(adata.obs_names)[10], zoom_n_bases=500\n", + "crested.pl.patterns.contribution_scores(\n", + " scores, one_hot_encoded_sequences, labels=cts, zoom_n_bases=500, method='mutagenesis'\n", ")" ] }, @@ -2516,621 +3075,43 @@ "name": "stdout", "output_type": "stream", "text": [ - "2024-06-25T17:22:30.603439+0200 INFO After sorting and filtering, kept 950 regions.\n" + "2024-06-25T17:40:53.041985+0200 INFO After sorting and filtering, kept 9500 regions.\n" ] } ], "source": [ - "adata_spec = crested.pp.sort_and_filter_regions_on_specificity(adata, top_k=50, class_names=list(adata.obs_names), method='gini')" + "adata_spec = crested.pp.sort_and_filter_regions_on_specificity(adata, top_k=500, class_names=list(adata.obs_names), method='gini')" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartendClass namerankgini_score
chr10:96093683-96095797chr109609368396095797Astro10.756436
chr3:159593453-159595567chr3159593453159595567Astro20.755814
chr3:115410072-115412186chr3115410072115412186Astro30.736226
chr10:56898762-56900876chr105689876256900876Astro40.732570
chr1:143052747-143054861chr1143052747143054861Astro50.727982
chr18:66022269-66024383chr186602226966024383Astro60.727601
chrX:102587693-102589807chrX102587693102589807Astro70.725532
chrX:14301286-14303400chrX1430128614303400Astro80.724276
chrX:6452334-6454448chrX64523346454448Astro90.720401
chr6:141530067-141532181chr6141530067141532181Astro100.720374
chr18:8043916-8046030chr1880439168046030Astro110.718077
chrX:11157519-11159633chrX1115751911159633Astro120.717722
chr1:143052021-143054135chr1143052021143054135Astro130.714811
chr8:54791458-54793572chr85479145854793572Astro140.711594
chr1:5032334-5034448chr150323345034448Astro150.708968
chr3:115843758-115845872chr3115843758115845872Astro160.707991
chr6:141529548-141531662chr6141529548141531662Astro170.707742
chr6:49020327-49022441chr64902032749022441Astro180.707586
chr2:110134069-110136183chr2110134069110136183Astro190.705950
chrX:163660440-163662554chrX163660440163662554Astro200.705680
chr1:127159287-127161401chr1127159287127161401Astro210.705280
chr8:30201885-30203999chr83020188530203999Astro220.703659
chr18:8043408-8045522chr1880434088045522Astro230.703530
chr14:111252324-111254438chr14111252324111254438Astro240.703077
chr3:6427088-6429202chr364270886429202Astro250.701328
chr1:161272641-161274755chr1161272641161274755Astro260.701141
chr2:57521912-57524026chr25752191257524026Astro270.698021
chr4:154312378-154314492chr4154312378154314492Astro280.696996
chr2:65274604-65276718chr26527460465276718Astro290.696961
chr18:56440643-56442757chr185644064356442757Astro300.696608
chr15:18236980-18239094chr151823698018239094Astro310.695703
chr10:29896923-29899037chr102989692329899037Astro320.695475
chr3:47667963-47670077chr34766796347670077Astro330.694782
chr15:8814404-8816518chr1588144048816518Astro340.694436
chrX:130761597-130763711chrX130761597130763711Astro350.694203
chr5:9868290-9870404chr598682909870404Astro360.692700
chr12:90914962-90917076chr129091496290917076Astro370.692215
chr3:133051169-133053283chr3133051169133053283Astro380.691771
chr4:71483394-71485508chr47148339471485508Astro390.690659
chrX:110801591-110803705chrX110801591110803705Astro400.689756
chr14:78015493-78017607chr147801549378017607Astro410.689670
chr10:92247689-92249803chr109224768992249803Astro420.689451
chr10:56525742-56527856chr105652574256527856Astro430.689304
chr4:11761293-11763407chr41176129311763407Astro440.688151
chr1:149401421-149403535chr1149401421149403535Astro450.687990
chr2:137431719-137433833chr2137431719137433833Astro460.687783
chr1:161273161-161275275chr1161273161161275275Astro470.687692
chr15:8864466-8866580chr1588644668866580Astro480.687259
chr3:50421064-50423178chr35042106450423178Astro490.686798
chr13:15543014-15545128chr131554301415545128Astro500.684049
\n", - "
" - ], "text/plain": [ - " chr start end Class name rank \\\n", - "chr10:96093683-96095797 chr10 96093683 96095797 Astro 1 \n", - "chr3:159593453-159595567 chr3 159593453 159595567 Astro 2 \n", - "chr3:115410072-115412186 chr3 115410072 115412186 Astro 3 \n", - "chr10:56898762-56900876 chr10 56898762 56900876 Astro 4 \n", - "chr1:143052747-143054861 chr1 143052747 143054861 Astro 5 \n", - "chr18:66022269-66024383 chr18 66022269 66024383 Astro 6 \n", - "chrX:102587693-102589807 chrX 102587693 102589807 Astro 7 \n", - "chrX:14301286-14303400 chrX 14301286 14303400 Astro 8 \n", - "chrX:6452334-6454448 chrX 6452334 6454448 Astro 9 \n", - "chr6:141530067-141532181 chr6 141530067 141532181 Astro 10 \n", - "chr18:8043916-8046030 chr18 8043916 8046030 Astro 11 \n", - "chrX:11157519-11159633 chrX 11157519 11159633 Astro 12 \n", - "chr1:143052021-143054135 chr1 143052021 143054135 Astro 13 \n", - "chr8:54791458-54793572 chr8 54791458 54793572 Astro 14 \n", - "chr1:5032334-5034448 chr1 5032334 5034448 Astro 15 \n", - "chr3:115843758-115845872 chr3 115843758 115845872 Astro 16 \n", - "chr6:141529548-141531662 chr6 141529548 141531662 Astro 17 \n", - "chr6:49020327-49022441 chr6 49020327 49022441 Astro 18 \n", - "chr2:110134069-110136183 chr2 110134069 110136183 Astro 19 \n", - "chrX:163660440-163662554 chrX 163660440 163662554 Astro 20 \n", - "chr1:127159287-127161401 chr1 127159287 127161401 Astro 21 \n", - "chr8:30201885-30203999 chr8 30201885 30203999 Astro 22 \n", - "chr18:8043408-8045522 chr18 8043408 8045522 Astro 23 \n", - "chr14:111252324-111254438 chr14 111252324 111254438 Astro 24 \n", - "chr3:6427088-6429202 chr3 6427088 6429202 Astro 25 \n", - "chr1:161272641-161274755 chr1 161272641 161274755 Astro 26 \n", - "chr2:57521912-57524026 chr2 57521912 57524026 Astro 27 \n", - "chr4:154312378-154314492 chr4 154312378 154314492 Astro 28 \n", - "chr2:65274604-65276718 chr2 65274604 65276718 Astro 29 \n", - "chr18:56440643-56442757 chr18 56440643 56442757 Astro 30 \n", - 
"chr15:18236980-18239094 chr15 18236980 18239094 Astro 31 \n", - "chr10:29896923-29899037 chr10 29896923 29899037 Astro 32 \n", - "chr3:47667963-47670077 chr3 47667963 47670077 Astro 33 \n", - "chr15:8814404-8816518 chr15 8814404 8816518 Astro 34 \n", - "chrX:130761597-130763711 chrX 130761597 130763711 Astro 35 \n", - "chr5:9868290-9870404 chr5 9868290 9870404 Astro 36 \n", - "chr12:90914962-90917076 chr12 90914962 90917076 Astro 37 \n", - "chr3:133051169-133053283 chr3 133051169 133053283 Astro 38 \n", - "chr4:71483394-71485508 chr4 71483394 71485508 Astro 39 \n", - "chrX:110801591-110803705 chrX 110801591 110803705 Astro 40 \n", - "chr14:78015493-78017607 chr14 78015493 78017607 Astro 41 \n", - "chr10:92247689-92249803 chr10 92247689 92249803 Astro 42 \n", - "chr10:56525742-56527856 chr10 56525742 56527856 Astro 43 \n", - "chr4:11761293-11763407 chr4 11761293 11763407 Astro 44 \n", - "chr1:149401421-149403535 chr1 149401421 149403535 Astro 45 \n", - "chr2:137431719-137433833 chr2 137431719 137433833 Astro 46 \n", - "chr1:161273161-161275275 chr1 161273161 161275275 Astro 47 \n", - "chr15:8864466-8866580 chr15 8864466 8866580 Astro 48 \n", - "chr3:50421064-50423178 chr3 50421064 50423178 Astro 49 \n", - "chr13:15543014-15545128 chr13 15543014 15545128 Astro 50 \n", - "\n", - " gini_score \n", - "chr10:96093683-96095797 0.756436 \n", - "chr3:159593453-159595567 0.755814 \n", - "chr3:115410072-115412186 0.736226 \n", - "chr10:56898762-56900876 0.732570 \n", - "chr1:143052747-143054861 0.727982 \n", - "chr18:66022269-66024383 0.727601 \n", - "chrX:102587693-102589807 0.725532 \n", - "chrX:14301286-14303400 0.724276 \n", - "chrX:6452334-6454448 0.720401 \n", - "chr6:141530067-141532181 0.720374 \n", - "chr18:8043916-8046030 0.718077 \n", - "chrX:11157519-11159633 0.717722 \n", - "chr1:143052021-143054135 0.714811 \n", - "chr8:54791458-54793572 0.711594 \n", - "chr1:5032334-5034448 0.708968 \n", - "chr3:115843758-115845872 0.707991 \n", - "chr6:141529548-141531662 
0.707742 \n", - "chr6:49020327-49022441 0.707586 \n", - "chr2:110134069-110136183 0.705950 \n", - "chrX:163660440-163662554 0.705680 \n", - "chr1:127159287-127161401 0.705280 \n", - "chr8:30201885-30203999 0.703659 \n", - "chr18:8043408-8045522 0.703530 \n", - "chr14:111252324-111254438 0.703077 \n", - "chr3:6427088-6429202 0.701328 \n", - "chr1:161272641-161274755 0.701141 \n", - "chr2:57521912-57524026 0.698021 \n", - "chr4:154312378-154314492 0.696996 \n", - "chr2:65274604-65276718 0.696961 \n", - "chr18:56440643-56442757 0.696608 \n", - "chr15:18236980-18239094 0.695703 \n", - "chr10:29896923-29899037 0.695475 \n", - "chr3:47667963-47670077 0.694782 \n", - "chr15:8814404-8816518 0.694436 \n", - "chrX:130761597-130763711 0.694203 \n", - "chr5:9868290-9870404 0.692700 \n", - "chr12:90914962-90917076 0.692215 \n", - "chr3:133051169-133053283 0.691771 \n", - "chr4:71483394-71485508 0.690659 \n", - "chrX:110801591-110803705 0.689756 \n", - "chr14:78015493-78017607 0.689670 \n", - "chr10:92247689-92249803 0.689451 \n", - "chr10:56525742-56527856 0.689304 \n", - "chr4:11761293-11763407 0.688151 \n", - "chr1:149401421-149403535 0.687990 \n", - "chr2:137431719-137433833 0.687783 \n", - "chr1:161273161-161275275 0.687692 \n", - "chr15:8864466-8866580 0.687259 \n", - "chr3:50421064-50423178 0.686798 \n", - "chr13:15543014-15545128 0.684049 " + "chr chr3\n", + "start 38667902\n", + "end 38670016\n", + "Class name Endo\n", + "rank 500\n", + "gini_score 0.537468\n", + "Name: chr3:38667902-38670016, dtype: object" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "adata_spec.var.head(n=50)" + "adata_spec.var.iloc[999]" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -3145,33 +3126,32 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ - "2024-06-25T16:46:35.169387+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:47:37.441391+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:48:39.528406+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:49:41.675013+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:50:44.036741+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:51:46.093464+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:52:48.344770+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:53:50.088466+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:54:51.501521+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:55:53.001062+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:56:54.033865+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:57:55.182021+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:58:56.624829+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T16:59:57.935584+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T17:00:58.899692+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T17:02:00.035373+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T17:03:01.079742+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "2024-06-25T17:04:02.123088+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", 
- "2024-06-25T17:05:02.980093+0200 INFO Calculating contribution scores for 1 class(es) and 50 region(s).\n", - "Contribution scores and one-hot encoded sequences saved to modisco_results\n" + "2024-06-25T17:42:59.294645+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-25 17:42:59.782263: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-06-25T17:53:24.646278+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n", + "2024-06-25T18:03:38.499302+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n", + "2024-06-25T18:13:50.544979+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n", + "2024-06-25T18:24:02.370729+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n", + "2024-06-25T18:34:14.363745+0200 INFO Calculating contribution scores for 1 class(es) and 500 region(s).\n" ] } ], @@ -3181,59 +3161,65 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-06-25T22:00:20.552977+0200 INFO Running modisco for class: L6CT\n", + "Using 1493 positive seqlets\n" + ] + } + ], + "source": [ + "crested.tl.tfmodisco(window=1000, output_dir = 'modisco_results2')" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2024-06-25T17:33:52.595025+0200 INFO Running modisco for class: L6CT\n", - "Using 142 positive seqlets\n", - "2024-06-25T17:34:00.034431+0200 INFO Running modisco for class: L6IT\n", - 
"Using 158 positive seqlets\n", - "2024-06-25T17:34:01.274526+0200 INFO Running modisco for class: Astro\n", - "Using 302 positive seqlets\n", - "2024-06-25T17:34:04.937461+0200 INFO Running modisco for class: L5IT\n", - "Using 128 positive seqlets\n", - "2024-06-25T17:34:05.687151+0200 INFO Running modisco for class: Lamp5\n", - "Using 158 positive seqlets\n", - "2024-06-25T17:34:06.992156+0200 INFO Running modisco for class: VLMC\n", - "Using 228 positive seqlets\n", - "2024-06-25T17:34:09.344291+0200 INFO Running modisco for class: Pvalb\n", - "Using 167 positive seqlets\n", - "2024-06-25T17:34:10.579415+0200 INFO Running modisco for class: Micro_PVM\n", - "Using 270 positive seqlets\n", - "2024-06-25T17:34:12.717802+0200 INFO Running modisco for class: Sncg\n", - "Using 145 positive seqlets\n", - "2024-06-25T17:34:13.627478+0200 INFO Running modisco for class: L2_3IT\n", - "Using 165 positive seqlets\n", - "2024-06-25T17:34:14.704486+0200 INFO Running modisco for class: L5ET\n", - "Using 125 positive seqlets\n", - "2024-06-25T17:34:15.472059+0200 INFO Running modisco for class: L6b\n", - "Using 148 positive seqlets\n", - "2024-06-25T17:34:16.704746+0200 INFO Running modisco for class: Vip\n", - "Using 131 positive seqlets\n", - "2024-06-25T17:34:17.608582+0200 INFO Running modisco for class: SstChodl\n", - "Using 271 positive seqlets\n", - "Extracted 167 negative seqlets\n", - "2024-06-25T17:34:21.328252+0200 INFO Running modisco for class: Oligo\n", - "Using 305 positive seqlets\n", - "Extracted 125 negative seqlets\n", - "2024-06-25T17:34:25.336564+0200 INFO Running modisco for class: Sst\n", - "Using 139 positive seqlets\n", - "2024-06-25T17:34:26.232381+0200 INFO Running modisco for class: L5_6NP\n", - "Using 125 positive seqlets\n", - "2024-06-25T17:34:27.018244+0200 INFO Running modisco for class: Endo\n", - "Extracted 105 negative seqlets\n", - "2024-06-25T17:34:27.667426+0200 INFO Running modisco for class: OPC\n", - "Using 246 positive seqlets\n", - 
"Extracted 152 negative seqlets\n" + "2024-06-26T11:51:51.428622+0200 INFO Starting genomic contributions plot for classes: ['L5ET']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/projects/c04/cbd-saerts/nkemp/software/CREsted/src/crested/pl/_utils.py:52: UserWarning: The figure layout has changed to tight\n", + " plt.tight_layout()\n" ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "crested.tl.tfmodisco(window=1000)" + "%matplotlib inline\n", + "crested.pl.patterns.modisco_results(classes=['L5ET'], contribution='postive', contribution_dir='modisco_results2', num_seq=500, y_max=0.07, viz='contrib')" ] }, { diff --git a/src/crested/pl/patterns/_contribution_scores.py b/src/crested/pl/patterns/_contribution_scores.py index 4209994a..798269ea 100644 --- a/src/crested/pl/patterns/_contribution_scores.py +++ b/src/crested/pl/patterns/_contribution_scores.py @@ -9,7 +9,7 @@ from crested._logging import log_and_raise from crested.pl._utils import render_plot -from ._utils import _plot_attribution_map, grad_times_input_to_df +from ._utils import _plot_attribution_map, _plot_mutagenesis_map, grad_times_input_to_df, grad_times_input_to_df_mutagenesis @log_and_raise(ValueError) @@ -37,6 +37,7 @@ def contribution_scores( zoom_n_bases: int | None = None, highlight_positions: list[tuple[int, int]] | None = None, ylim: tuple | None = None, + method: str | None = None, **kwargs, ): """ @@ -58,6 +59,8 @@ def contribution_scores( List of tuples with start and end positions to highlight. Default is None. ylim Y-axis limits. Default is None. + method + Method used for calculating contribution scores. If mutagenesis, specify. 
Examples -------- @@ -80,20 +83,36 @@ def contribution_scores( start_idx = center - int(zoom_n_bases / 2) scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :] - global_min = scores.min() - global_max = scores.max() # Plot logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)") for seq in range(seqs_one_hot.shape[0]): fig_height_per_class = 2 fig = plt.figure(figsize=(50, fig_height_per_class * scores.shape[1])) + seq_class_x = seqs_one_hot[seq, start_idx : start_idx + zoom_n_bases, :] + + if method == 'mutagenesis': + global_max = scores[seq].max()+0.25*np.abs(scores[seq].max()) + global_min = scores[seq].min()-0.25*np.abs(scores[seq].min()) + else: + mins = [] + maxs = [] + for i in range(scores.shape[1]): + seq_class_scores = scores[seq, i, :, :] + mins.append(np.min(seq_class_scores*seq_class_x)) + maxs.append(np.max(seq_class_scores*seq_class_x)) + global_max = np.array(maxs).max()+0.25*np.abs(np.array(maxs).max()) + global_min = np.array(mins).min()-0.25*np.abs(np.array(mins).min()) + for i in range(scores.shape[1]): seq_class_scores = scores[seq, i, :, :] - seq_class_x = seqs_one_hot[seq, :, :] - intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores) ax = plt.subplot(scores.shape[1], 1, i + 1) - _plot_attribution_map(intgrad_df, ax=ax, return_ax=False) + if (method =='mutagenesis'): + mutagenesis_df = grad_times_input_to_df_mutagenesis(seq_class_x, seq_class_scores) + _plot_mutagenesis_map(mutagenesis_df, ax=ax) + else: + intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores) + _plot_attribution_map(intgrad_df, ax=ax, return_ax=False) if labels: class_name = labels[i] else: @@ -102,11 +121,11 @@ def contribution_scores( if ylim: ax.set_ylim(ylim[0], ylim[1]) x_pos = 5 - y_pos = 0.75 * ylim[1] + y_pos = 0.5 * ylim[1] else: ax.set_ylim([global_min, global_max]) x_pos = 5 - y_pos = 0.75 * global_max + y_pos = 0.5 * global_max ax.text(x_pos, y_pos, text_to_add, fontsize=16, ha="left", va="center") # 
Draw rectangles to highlight positions diff --git a/src/crested/pl/patterns/_modisco_results.py b/src/crested/pl/patterns/_modisco_results.py index e050ee87..8d385616 100644 --- a/src/crested/pl/patterns/_modisco_results.py +++ b/src/crested/pl/patterns/_modisco_results.py @@ -45,7 +45,7 @@ def _trim_pattern_by_ic( contrib_scores = np.array(pattern["contrib_scores"]) if not pos_pattern: contrib_scores = -contrib_scores - contrib_scores[contrib_scores < 0] = 0 + contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero ic = modisco.util.compute_per_position_ic( ppm=np.array(contrib_scores), background=background, pseudocount=pseudocount diff --git a/src/crested/pl/patterns/_utils.py b/src/crested/pl/patterns/_utils.py index ede20c84..8a57596f 100644 --- a/src/crested/pl/patterns/_utils.py +++ b/src/crested/pl/patterns/_utils.py @@ -71,3 +71,28 @@ def _plot_attribution_map( ax.spines["top"].set_visible(False) if return_ax: return ax + +def _plot_mutagenesis_map(mutagenesis_df, ax=None): + """Plot an attribution map for mutagenesis using different colored dots, with adjusted x-axis limits.""" + colors = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'} + if ax is None: + ax = plt.gca() + + # Add horizontal line at y=0 + ax.axhline(0, color='gray', linewidth=1, linestyle='--') + + # Scatter plot for each nucleotide type + for nuc, color in colors.items(): + # Filter out dots where the variant is the same as the original nucleotide + subset = mutagenesis_df[(mutagenesis_df['Nucleotide'] == nuc) & (mutagenesis_df['Nucleotide'] != mutagenesis_df['Original'])] + ax.scatter(subset['Position'], subset['Effect'], color=color, label=nuc, s=10) # s is the size of the dot + + # Set the limits of the x-axis to match exactly the first and last position + if not mutagenesis_df.empty: + ax.set_xlim(mutagenesis_df['Position'].min() - 0.5, mutagenesis_df['Position'].max() + 0.5) + + ax.legend(title="Nucleotide", loc='upper right') + ax.spines["right"].set_visible(False) 
+ ax.spines["top"].set_visible(False) + ax.xaxis.set_ticks_position("none") + plt.xticks([]) # Optionally, hide x-axis ticks for a cleaner look diff --git a/src/crested/tl/_configs.py b/src/crested/tl/_configs.py index f5568386..09775bec 100644 --- a/src/crested/tl/_configs.py +++ b/src/crested/tl/_configs.py @@ -13,6 +13,7 @@ PearsonCorrelation, PearsonCorrelationLog, ZeroPenaltyMetric, + SpearmanCorrelationPerClass ) @@ -74,6 +75,9 @@ def metrics(self) -> list[tf.keras.metrics.Metric]: class PeakRegressionConfig(BaseConfig): """Default configuration for peak regression task.""" + def __init__(self, num_classes=None): + self.num_classes = num_classes + @property def loss(self) -> tf.keras.losses.Loss: return CosineMSELoss() @@ -84,15 +88,18 @@ def optimizer(self) -> tf.keras.optimizers.Optimizer: @property def metrics(self) -> list[tf.keras.metrics.Metric]: - return [ + metrics = [ tf.keras.metrics.MeanAbsoluteError(), tf.keras.metrics.MeanSquaredError(), tf.keras.metrics.CosineSimilarity(axis=1), PearsonCorrelation(), ConcordanceCorrelationCoefficient(), PearsonCorrelationLog(), - ZeroPenaltyMetric(), + ZeroPenaltyMetric() ] + #if self.num_classes is not None: + # metrics.append(SpearmanCorrelationPerClass(num_classes=self.num_classes)) + return metrics class TaskConfig(NamedTuple): @@ -156,7 +163,7 @@ def to_dict(self) -> dict: def default_configs( - task: str, + task: str, num_classes: int = None ) -> TaskConfig: """ Get default loss, optimizer, and metrics for an existing task. @@ -177,6 +184,8 @@ def default_configs( ---------- tasks Task for which to get default components. + num_classes + Number of output classes of model. Required for Spearman correlation metric. Returns ------- @@ -196,7 +205,7 @@ def default_configs( f"Task '{task}' not supported. Only {list(task_classes.keys())} are supported." 
) - task_class = task_classes[task]() + task_class = task_classes[task](num_classes=num_classes) if task =='peak_regression' else task_classes[task]() loss = task_class.loss optimizer = task_class.optimizer metrics = task_class.metrics diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 8c0207ad..c7d7cf8f 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -430,6 +430,31 @@ def predict_regions( return np.concatenate(all_predictions, axis=0) + def predict_sequence( + self, + sequence: str) -> np.ndarray: + """ + Make predictions using the model on the provided DNA sequence. + + Parameters + ---------- + model : a trained TensorFlow/Keras model + sequence : str + A string containing a DNA sequence (A, C, G, T). + + Returns + ------- + np.ndarray + Predictions for the provided sequence. + """ + # One-hot encode the sequence + x = one_hot_encode_sequence(sequence) + + # Make prediction + predictions = self.model.predict(x) + + return predictions + def calculate_contribution_scores( self, anndata: AnnData | None = None, @@ -584,12 +609,16 @@ def calculate_contribution_scores_regions( if isinstance(region_idx, str): region_idx = [region_idx] + if isinstance(class_names, str): + class_names = [class_names] + all_scores = [] all_one_hot_sequences = [] all_class_names = list(self.anndatamodule.adata.obs_names) if class_names is not None: + print(class_names) n_classes = len(class_names) class_indices = [ all_class_names.index(class_name) for class_name in class_names diff --git a/src/crested/tl/data/__init__.py b/src/crested/tl/data/__init__.py index a49456aa..95714fe8 100644 --- a/src/crested/tl/data/__init__.py +++ b/src/crested/tl/data/__init__.py @@ -1,3 +1,3 @@ from ._anndatamodule import AnnDataModule from ._dataloader import AnnDataLoader -from ._dataset import AnnDataset +from ._dataset import AnnDataset, SequenceLoader diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 
49002666..f220b64c 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -27,9 +27,9 @@ def __init__( self, genome_file: PathLike, chromsizes: dict | None, - in_memory: bool, - always_reverse_complement: bool, - max_stochastic_shift: int, + in_memory: bool = False, + always_reverse_complement: bool = False, + max_stochastic_shift: int = 0, regions: list[str] = None, ): self.genome = FastaFile(genome_file) @@ -82,7 +82,6 @@ def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str: sequence = self.sequences[key] else: sequence = self._get_extended_sequence(region) - chrom, start_end = region.split(":") start, end = map(int, start_end.split("-")) start_idx = self.max_stochastic_shift + shift diff --git a/src/crested/tl/losses/_cosinemse.py b/src/crested/tl/losses/_cosinemse.py index c47ec019..3a7bd35a 100644 --- a/src/crested/tl/losses/_cosinemse.py +++ b/src/crested/tl/losses/_cosinemse.py @@ -6,10 +6,10 @@ class CosineMSELoss(tf.keras.losses.Loss): """Custom loss function that combines cosine similarity and mean squared error.""" - def __init__(self, max_weight=1.0, name="CustomMSELoss", reduction=None): + def __init__(self, max_weight=1.0, name="CustomMSELoss"): super().__init__(name=name) self.max_weight = max_weight - self.reduction=reduction + #self.reduction=reduction @tf.function def call(self, y_true, y_pred): @@ -39,8 +39,8 @@ def call(self, y_true, y_pred): def get_config(self): config = super().get_config() config.update({ - "max_weight": self.max_weight,#}) - "reduction":self.reduction}) + "max_weight": self.max_weight}) + #"reduction":self.reduction}) return config @classmethod diff --git a/src/crested/tl/metrics/__init__.py b/src/crested/tl/metrics/__init__.py index 38ad2875..deb4096b 100644 --- a/src/crested/tl/metrics/__init__.py +++ b/src/crested/tl/metrics/__init__.py @@ -2,3 +2,4 @@ from ._pearsoncorr import PearsonCorrelation from ._pearsoncorrlog import PearsonCorrelationLog from ._zeropenalty 
"""Spearman correlation metric."""

from __future__ import annotations

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="Metrics")
class SpearmanCorrelationPerClass(tf.keras.metrics.Metric):
    """Spearman rank correlation, computed per class and averaged over classes.

    For every batch and every class (column of ``y_true`` / ``y_pred``), the
    entries whose target is exactly zero are dropped and the Spearman
    correlation of the remaining entries is computed. Per-class correlations
    are accumulated over batches; ``result`` returns the mean of the per-class
    averages.

    Notes
    -----
    Ranks are ordinal (argsort of argsort), so tied values receive distinct
    ranks and the closed-form ``1 - 6 * sum(d^2) / (n * (n^2 - 1))`` is only
    exact when there are no ties.

    Parameters
    ----------
    num_classes
        Number of columns in ``y_true`` / ``y_pred``.
    name
        Name of the metric instance.
    **kwargs
        Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, num_classes, name="spearman_correlation_per_class", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Running sum of per-batch correlations, one slot per class.
        self.correlation_sums = self.add_weight(
            name="correlation_sums", shape=(num_classes,), initializer="zeros"
        )
        # Number of batches that contributed a valid correlation, per class.
        self.update_counts = self.add_weight(
            name="update_counts", shape=(num_classes,), initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the per-class correlation of one batch.

        ``y_true`` and ``y_pred`` are indexed as ``[:, i]`` for each class.
        ``sample_weight`` is accepted for API compatibility but is not used.
        A class only contributes when more than one of its targets in the
        batch is non-zero.
        """
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        batch_correlations = []
        batch_valid = []
        for i in range(self.num_classes):
            y_true_class = y_true[:, i]
            y_pred_class = y_pred[:, i]

            # boolean_mask keeps the selection 1-D. Gathering with the (k, 1)
            # indices returned by tf.where would add a trailing axis of size
            # one, and argsort over that axis yields all-zero ranks.
            non_zero_mask = tf.not_equal(y_true_class, 0)
            y_true_non_zero = tf.boolean_mask(y_true_class, non_zero_mask)
            y_pred_non_zero = tf.boolean_mask(y_pred_class, non_zero_mask)

            # A correlation needs at least two points.
            proceed = tf.size(y_true_non_zero) > 1
            correlation = tf.cond(
                proceed,
                lambda t=y_true_non_zero, p=y_pred_non_zero: self.compute_correlation(t, p),
                lambda: tf.constant(0.0, dtype=tf.float32),
            )
            batch_correlations.append(correlation)
            batch_valid.append(tf.cast(proceed, tf.float32))

        # Update the whole variable at once: a slice of a variable is a
        # tensor, so per-element ``assign_add`` on ``var[i]`` is not available.
        self.correlation_sums.assign_add(tf.stack(batch_correlations))
        self.update_counts.assign_add(tf.stack(batch_valid))

    def compute_correlation(self, y_true_non_zero, y_pred_non_zero):
        """Return the Spearman correlation of two equally sized value sets.

        Inputs are flattened first so ranking always runs over all elements.
        Returns 0.0 instead of NaN when the formula is undefined (n <= 1).
        """
        y_true_flat = tf.reshape(y_true_non_zero, [-1])
        y_pred_flat = tf.reshape(y_pred_non_zero, [-1])

        # argsort of argsort converts values to 0-based ordinal ranks.
        ranks_true = tf.argsort(tf.argsort(y_true_flat))
        ranks_pred = tf.argsort(tf.argsort(y_pred_flat))

        rank_diffs = tf.cast(ranks_true, tf.float32) - tf.cast(ranks_pred, tf.float32)
        rank_diffs_squared_sum = tf.reduce_sum(tf.square(rank_diffs))
        n = tf.cast(tf.size(y_true_flat), tf.float32)

        correlation = 1 - (6 * rank_diffs_squared_sum) / (n * (n * n - 1))
        return tf.where(tf.math.is_nan(correlation), 0.0, correlation)

    def result(self):
        """Return the mean over classes of the average per-batch correlation.

        Classes that never received a valid update are excluded from the mean
        rather than turning the whole result into NaN; if no class has been
        updated yet the result is 0.0.
        """
        avg_correlations = tf.math.divide_no_nan(
            self.correlation_sums, self.update_counts
        )
        num_valid_classes = tf.reduce_sum(
            tf.cast(self.update_counts > 0, tf.float32)
        )
        return tf.math.divide_no_nan(
            tf.reduce_sum(avg_correlations), num_valid_classes
        )

    def reset_state(self):
        """Zero both accumulators."""
        self.correlation_sums.assign(tf.zeros_like(self.correlation_sums))
        self.update_counts.assign(tf.zeros_like(self.update_counts))

    def get_config(self):
        """Include ``num_classes`` so the metric can be re-created from config."""
        config = super().get_config()
        config.update({"num_classes": self.num_classes})
        return config