diff --git a/docs/_static/img/examples/genomic_contributions.png b/docs/_static/img/examples/genomic_contributions.png new file mode 100644 index 00000000..c9b35654 Binary files /dev/null and b/docs/_static/img/examples/genomic_contributions.png differ diff --git a/docs/api/index.md b/docs/api/index.md index 98043ebe..2103c541 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -1,7 +1,5 @@ # API -Import CREsted as: - ``` import crested ``` @@ -11,7 +9,7 @@ import crested io preprocessing -tools -plotting +tools/index.md +plotting/index.md logging ``` diff --git a/docs/api/io.md b/docs/api/io.md index a87fba8a..977f416f 100644 --- a/docs/api/io.md +++ b/docs/api/io.md @@ -4,15 +4,13 @@ Importing of topics bed files (outputs from [pycistopic](https://pycistopic.read Importing of bigwigs files and consensus region bed files into Anndata format for peak regression or classification. ```{eval-rst} -.. module:: crested +.. currentmodule:: crested ``` ```{eval-rst} -.. currentmodule:: crested - .. autosummary:: :toctree: _autosummary import_bigwigs - import_topics + import_beds ``` diff --git a/docs/api/logging.md b/docs/api/logging.md index de41daf5..6173784f 100644 --- a/docs/api/logging.md +++ b/docs/api/logging.md @@ -3,12 +3,10 @@ Helper functions for logging during use of package. ```{eval-rst} -.. module:: crested +.. currentmodule:: crested ``` ```{eval-rst} -.. currentmodule:: crested - .. autosummary:: :toctree: _autosummary diff --git a/docs/api/plotting.md b/docs/api/plotting.md deleted file mode 100644 index 09b53b3a..00000000 --- a/docs/api/plotting.md +++ /dev/null @@ -1,71 +0,0 @@ -# Plotting: `pl` - -Plotting description - -```{eval-rst} -.. module:: crested.pl -``` - -```{eval-rst} -.. currentmodule:: crested -``` - -## Contribution scores - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.contribution_scores -``` - -## Bar plots - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.bar.region - pl.bar.region_predictions - pl.bar.normalization_weights -``` - -## Distribution plots - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.hist.distribution -``` - -## Heatmap - -Correlations - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.heatmap.correlations_self - pl.heatmap.correlations_predictions -``` - -## Scatter plots - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.scatter.class_density -``` - -## Utility functions - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - pl.render_plot -``` - diff --git a/docs/api/plotting/bar.md b/docs/api/plotting/bar.md new file mode 100644 index 00000000..ed48a8f8 --- /dev/null +++ b/docs/api/plotting/bar.md @@ -0,0 +1,16 @@ +# Bar `pl.bar` + +Bar plots to inspect per region ground truths and compare them to their predictions. + +```{eval-rst} +.. currentmodule:: crested.pl.bar +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + region + region_predictions + normalization_weights +``` diff --git a/docs/api/plotting/heatmap.md b/docs/api/plotting/heatmap.md new file mode 100644 index 00000000..784b1f47 --- /dev/null +++ b/docs/api/plotting/heatmap.md @@ -0,0 +1,15 @@ +# Heatmap `pl.heatmap` + +Investigate autocorrelation between classes and between ground truths and predictions using heatmaps. + +```{eval-rst} +.. currentmodule:: crested.pl.heatmap +``` + +```{eval-rst} +.. 
autosummary::
+    :toctree: _autosummary
+
+    correlations_self
+    correlations_predictions
+```
diff --git a/docs/api/plotting/hist.md b/docs/api/plotting/hist.md
new file mode 100644
index 00000000..eca43fff
--- /dev/null
+++ b/docs/api/plotting/hist.md
@@ -0,0 +1,14 @@
+# Hist `pl.hist`
+
+Plots for inspecting distributions of ground truths and predictions.
+
+```{eval-rst}
+.. currentmodule:: crested.pl.hist
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: _autosummary
+
+    distribution
+```
diff --git a/docs/api/plotting/index.md b/docs/api/plotting/index.md
new file mode 100644
index 00000000..028febbc
--- /dev/null
+++ b/docs/api/plotting/index.md
@@ -0,0 +1,66 @@
+# Plotting: `pl`
+
+Plotting functions for visualizing ground truths, predictions, contribution scores, and tfmodisco results.
+
+```{eval-rst}
+.. currentmodule:: crested.pl
+```
+
+```{toctree}
+:maxdepth: 2
+:hidden:
+
+patterns
+bar
+hist
+heatmap
+scatter
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: _autosummary
+
+    render_plot
+```
+
+## Patterns: Contribution scores and Modisco results
+
+```{eval-rst}
+.. autosummary::
+    patterns.contribution_scores
+    patterns.modisco_results
+```
+
+## Bar plots
+
+```{eval-rst}
+.. autosummary::
+    bar.region
+    bar.region_predictions
+    bar.normalization_weights
+```
+
+## Distribution plots
+
+```{eval-rst}
+.. autosummary::
+    hist.distribution
+```
+
+## Heatmap
+
+Correlations
+
+```{eval-rst}
+.. autosummary::
+    heatmap.correlations_self
+    heatmap.correlations_predictions
+```
+
+## Scatter plots
+
+```{eval-rst}
+.. autosummary::
+    scatter.class_density
+```
\ No newline at end of file
diff --git a/docs/api/plotting/patterns.md b/docs/api/plotting/patterns.md
new file mode 100644
index 00000000..51ed776d
--- /dev/null
+++ b/docs/api/plotting/patterns.md
@@ -0,0 +1,15 @@
+# Patterns `pl.patterns`
+
+Plot contribution scores and analyze them using tfmodisco.
+
+```{eval-rst}
+.. currentmodule:: crested.pl.patterns
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: _autosummary
+
+    contribution_scores
+    modisco_results
+```
diff --git a/docs/api/plotting/scatter.md b/docs/api/plotting/scatter.md
new file mode 100644
index 00000000..1b32fb92
--- /dev/null
+++ b/docs/api/plotting/scatter.md
@@ -0,0 +1,14 @@
+# Scatter `pl.scatter`
+
+Useful scatter plots.
+
+```{eval-rst}
+.. currentmodule:: crested.pl.scatter
+```
+
+```{eval-rst}
+.. autosummary::
+    :toctree: _autosummary
+
+    class_density
+```
diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md
index e80b7b37..9f9fee57 100644
--- a/docs/api/preprocessing.md
+++ b/docs/api/preprocessing.md
@@ -3,16 +3,12 @@
 Preparing Anndata object for CREsted training/evaluations.
 
 ```{eval-rst}
-.. module:: crested.pp
-```
-
-```{eval-rst}
-.. currentmodule:: crested
+.. currentmodule:: crested.pp
 
 .. autosummary::
     :toctree: _autosummary
 
-    pp.train_val_test_split
-    pp.filter_regions_on_specificity
-    pp.normalize_peaks
+    train_val_test_split
+    filter_regions_on_specificity
+    normalize_peaks
 ```
diff --git a/docs/api/tools.md b/docs/api/tools.md
deleted file mode 100644
index 32745fb2..00000000
--- a/docs/api/tools.md
+++ /dev/null
@@ -1,81 +0,0 @@
-# Tools `tl`
-
-Training and testing of models.
-Explanation of models.
-Sequence design.
-
-```{eval-rst}
-.. module:: crested.tl
-```
-
-```{eval-rst}
-.. currentmodule:: crested
-```
-
-## Basic
-
-```{eval-rst}
-.. autosummary::
-    :toctree: _autosummary
-
-    tl.Crested
-    tl.TaskConfig
-    tl.default_configs
-```
-
-
-## Data
-
-Utility functions to prepare data for training and evaluation.
-Generally, `tl.data.AnnDataModule` is the only one that should be called directly by the user. - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - tl.data.AnnDataModule - tl.data.AnnDataLoader - tl.data.AnnDataset -``` - -## Losses - -Custom `tf.Keras.losses.Loss` functions for specific use cases. -Supply these (or your own) to a `tl.TaskConfig` to be able to use them for training. - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - tl.losses.CosineMSELoss -``` - -## Metrics - -Custom `tf.keras.metrics.Metric` metrics for specific use cases. -Supply these (or your own) to a `tl.TaskConfig` to be able to use them for training. - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - tl.metrics.ConcordanceCorrelationCoefficient - tl.metrics.PearsonCorrelation - tl.metrics.PearsonCorrelationLog - tl.metrics.ZeroPenaltyMetric -``` - -## Model Zoo - -Custom `tf.keras.Model` definitions that have shown to work well in specific use cases. -Supply these (or your own) to `tl.Crested(...)` to use them in training. - -```{eval-rst} -.. autosummary:: - :toctree: _autosummary - - tl.zoo.basenji - tl.zoo.chrombpnet - tl.zoo.deeptopic_cnn - tl.zoo.simple_convnet -``` \ No newline at end of file diff --git a/docs/api/tools/data.md b/docs/api/tools/data.md new file mode 100644 index 00000000..cf9bac35 --- /dev/null +++ b/docs/api/tools/data.md @@ -0,0 +1,17 @@ +# Data `tl.data` + +Utility functions to prepare data for training and evaluation. +Generally, `tl.data.AnnDataModule` is the only one that should be called directly by the user. + +```{eval-rst} +.. currentmodule:: crested.tl.data +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + AnnDataModule + AnnDataLoader + AnnDataset +``` \ No newline at end of file diff --git a/docs/api/tools/index.md b/docs/api/tools/index.md new file mode 100644 index 00000000..fedad79e --- /dev/null +++ b/docs/api/tools/index.md @@ -0,0 +1,62 @@ +# Tools `tl` + +```{eval-rst} +.. currentmodule:: crested.tl +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + Crested + TaskConfig + tfmodisco + default_configs +``` + + +```{toctree} +:maxdepth: 2 +:hidden: + +data +zoo +losses +metrics +``` + +## Data + +```{eval-rst} +.. autosummary:: + data.AnnDataModule + data.AnnDataLoader + data.AnnDataset +``` + +## Model Zoo + +```{eval-rst} +.. autosummary:: + zoo.basenji + zoo.chrombpnet + zoo.deeptopic_cnn + zoo.simple_convnet +``` + +## Losses + +```{eval-rst} +.. autosummary:: + losses.CosineMSELoss +``` + +## Metrics + +```{eval-rst} +.. autosummary:: + metrics.ConcordanceCorrelationCoefficient + metrics.PearsonCorrelation + metrics.PearsonCorrelationLog + metrics.ZeroPenaltyMetric +``` \ No newline at end of file diff --git a/docs/api/tools/losses.md b/docs/api/tools/losses.md new file mode 100644 index 00000000..16f7ccc8 --- /dev/null +++ b/docs/api/tools/losses.md @@ -0,0 +1,15 @@ +# Losses `tl.losses` + +Custom `tf.Keras.losses.Loss` functions for specific use cases. +Supply these (or your own) to a `tl.TaskConfig` to be able to use them for training. + +```{eval-rst} +.. currentmodule:: crested.tl.losses +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + CosineMSELoss +``` \ No newline at end of file diff --git a/docs/api/tools/metrics.md b/docs/api/tools/metrics.md new file mode 100644 index 00000000..38838b08 --- /dev/null +++ b/docs/api/tools/metrics.md @@ -0,0 +1,18 @@ +# Metrics `tl.metrics` + +Custom `tf.keras.metrics.Metric` metrics for specific use cases. 
+Supply these (or your own) to a `tl.TaskConfig` to be able to use them for training. + +```{eval-rst} +.. currentmodule:: crested.tl.metrics +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + ConcordanceCorrelationCoefficient + PearsonCorrelation + PearsonCorrelationLog + ZeroPenaltyMetric +``` \ No newline at end of file diff --git a/docs/api/tools/zoo.md b/docs/api/tools/zoo.md new file mode 100644 index 00000000..b92ab6e3 --- /dev/null +++ b/docs/api/tools/zoo.md @@ -0,0 +1,18 @@ +# Model Zoo `tl.zoo` + +Custom `tf.keras.Model` definitions that have shown to work well in specific use cases. +Supply these (or your own) to `tl.Crested(...)` to use them in training. + +```{eval-rst} +.. currentmodule:: crested.tl.zoo +``` + +```{eval-rst} +.. autosummary:: + :toctree: _autosummary + + basenji + chrombpnet + deeptopic_cnn + simple_convnet +``` \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 1f726ba2..c58fae18 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -94,6 +94,10 @@ "python": ("https://docs.python.org/3", None), "anndata": ("https://anndata.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), + "tensorflow": ( + "https://www.tensorflow.org/api_docs/python", + "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", + ), } # List of patterns, relative to source directory, that match files and diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index de7ff149..d9186246 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -4,4 +4,5 @@ :maxdepth: 1 introduction +mouse_biccn ``` \ No newline at end of file diff --git a/docs/tutorials/introduction.ipynb b/docs/tutorials/introduction.ipynb index 5977c866..a5ef6267 100644 --- a/docs/tutorials/introduction.ipynb +++ b/docs/tutorials/introduction.ipynb @@ -35,7 +35,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can use the {func}`~crested.import_topics` function to import regions per topic BED files and a consensus regions BED file (output from running [pycistopic](https://pycistopic.readthedocs.io/en/latest/) into an {class}`anndata.AnnData` object,\n", + "We can use the {func}`~crested.import_beds` function to import regions per topic BED files and a consensus regions BED file (output from running [pycistopic](https://pycistopic.readthedocs.io/en/latest/) into an {class}`anndata.AnnData` object,\n", "with the imported topics as the `AnnData.obs` and the consensus peak regions as the `AnnData.var`. \n", "\n", "Optionally, provide a chromsizes file to filter out regions that are not within the chromsizes. 
" @@ -74,7 +74,7 @@ } ], "source": [ - "adata = crested.import_topics(\n", + "adata = crested.import_beds(\n", " topics_folder=\"/staging/leuven/stg_00002/lcb/lmahieu/projects/DeepTopic/biccn_test/otsu\",\n", " regions_file=\"/staging/leuven/stg_00002/lcb/lmahieu/projects/DeepTopic/biccn_test/consensus_peaks_bicnn.bed\",\n", " # topics_subset=[\"topic_1\", \"topic_2\"], # optional subset of topics to import\n", @@ -282,13 +282,13 @@ "\n", "The entire CREsted workflow is built around the {func}`~crested.tl.Crested` class.\n", "This class has a couple of required arguments:\n", - "- `data`: the {class}`~crested.tl.AnnDataModule` object containing all the data (anndata, genome) and dataloaders that specify how to load the data.\n", + "- `data`: the {class}`~crested.tl.data.AnnDataModule` object containing all the data (anndata, genome) and dataloaders that specify how to load the data.\n", "- `model`: the {class}`~tf.keras.Model` object containing the model architecture.\n", "- `config`: the {class}`~crested.tl.TaskConfig` object containing the optimizer, loss function, and metrics to use in training.\n", "\n", - "#### Data\n", + "### Data\n", "\n", - "We'll start by initializing the {class}`~crested.tl.AnnDataModule` object with our data. \n", + "We'll start by initializing the {class}`~crested.tl.data.AnnDataModule` object with our data. \n", "This will tell our model how to load the data and what data to load during fitting/evaluation.\n", "The main arguments to suuply are the `adata` object, the `genome` object, and the `batch_size`. \n", "Other optional arguments are related to the training data loading (e.g. shuffling, whether to load the sequences into memory, ...)" @@ -313,7 +313,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Model definition \n", + "### Model definition \n", "\n", "Next, we'll define the model architecture. This is a standard Keras model definition, so you can provide any model you like.\n", "Altneratively, there are a couple of ready-to-use models available in the `crested.tl.zoo` module.\n", @@ -338,7 +338,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### TaskConfig\n", + "### TaskConfig\n", "\n", "The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our 'task'). \n", "Some default configurations are available for some common tasks such as 'topic_classification' and 'peak_regression',\n", @@ -552,7 +552,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "These can then be plotted using the {func}`~crested.pl.contribution_scores` function." + "These can then be plotted using the {func}`~crested.pl.patterns.contribution_scores` function." ] }, { @@ -579,7 +579,7 @@ } ], "source": [ - "crested.pl.contribution_scores(\n", + "crested.pl.patterns.contribution_scores(\n", " scores, one_hot_encoded_sequences, class_names=list(adata.obs_names)[0:2]\n", ")" ] diff --git a/docs/tutorials/mouse_biccn.ipynb b/docs/tutorials/mouse_biccn.ipynb index 0b66b159..55e865a0 100644 --- a/docs/tutorials/mouse_biccn.ipynb +++ b/docs/tutorials/mouse_biccn.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Introduction to CREsted\n", + "# Mouse BICCN\n", "\n", "In this introductory notebook, we will train a topic classification model and inspect the results." 
] @@ -85,7 +85,7 @@ " bigwigs_folder='/home/VIB.LOCAL/niklas.kempynck/nkemp/mouse/biccn/bigwigs/bws',\n", " regions_file='/home/VIB.LOCAL/niklas.kempynck/nkemp/mouse/biccn/consensus_peaks_inputs.bed',\n", " chromsizes_file='/home/VIB.LOCAL/niklas.kempynck/nkemp/mouse/biccn/mm.chrom.sizes',\n", - " target_region_width=1000 \n", + " target_region_width=1000\n", ")\n", "adata" ] @@ -151,7 +151,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can use the {func}`~crested.import_topics` function to import regions per topic BED files and a consensus regions BED file (output from running [pycistopic](https://pycistopic.readthedocs.io/en/latest/) into an {class}`anndata.AnnData` object,\n", + "We can use the {func}`~crested.import_beds` function to import regions per topic BED files and a consensus regions BED file (output from running [pycistopic](https://pycistopic.readthedocs.io/en/latest/) into an {class}`anndata.AnnData` object,\n", "with the imported topics as the `AnnData.obs` and the consensus peak regions as the `AnnData.var`. \n", "\n", "Optionally, provide a chromsizes file to filter out regions that are not within the chromsizes. " @@ -487,13 +487,13 @@ "\n", "The entire CREsted workflow is built around the {func}`~crested.tl.Crested` class.\n", "This class has a couple of required arguments:\n", - "- `data`: the {class}`~crested.tl.AnnDataModule` object containing all the data (anndata, genome) and dataloaders that specify how to load the data.\n", + "- `data`: the {class}`~crested.tl.data.AnnDataModule` object containing all the data (anndata, genome) and dataloaders that specify how to load the data.\n", "- `model`: the {class}`~tf.keras.Model` object containing the model architecture.\n", "- `config`: the {class}`~crested.tl.TaskConfig` object containing the optimizer, loss function, and metrics to use in training.\n", "\n", - "#### Data\n", + "### Data\n", "\n", - "We'll start by initializing the {class}`~crested.tl.AnnDataModule` object with our data. \n", + "We'll start by initializing the {class}`~crested.tl.data.AnnDataModule` object with our data. \n", "This will tell our model how to load the data and what data to load during fitting/evaluation.\n", "The main arguments to suuply are the `adata` object, the `genome` object, and the `batch_size`. \n", "Other optional arguments are related to the training data loading (e.g. shuffling, whether to load the sequences into memory, ...)" @@ -521,7 +521,6 @@ " batch_size=64,\n", " in_memory=False,\n", " max_stochastic_shift=3\n", - " \n", ")" ] }, @@ -558,7 +557,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Model definition \n", + "### Model definition \n", "\n", "Next, we'll define the model architecture. This is a standard Keras model definition, so you can provide any model you like.\n", "Altneratively, there are a couple of ready-to-use models available in the `crested.tl.zoo` module.\n", @@ -595,7 +594,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### TaskConfig\n", + "### TaskConfig\n", "\n", "The TaskConfig object specifies the optimizer, loss function, and metrics to use in training (we call this our 'task'). \n", "Some default configurations are available for some common tasks such as 'topic_classification' and 'peak_regression',\n", @@ -1050,7 +1049,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "These can then be plotted using the {func}`~crested.pl.contribution_scores` function." 
+ "These can then be plotted using the {func}`~crested.pl.patterns.contribution_scores` function." ] }, { @@ -1570,7 +1569,7 @@ " y = predicted_values[model_name]\n", " # Initialize a matrix to store the correlations\n", " correlation_matrix = np.zeros((n_features, n_features))\n", - " \n", + "\n", " # Calculate the correlation for each pair of prediction-target\n", " for i in range(n_features):\n", " for j in range(n_features):\n", @@ -1992,7 +1991,6 @@ } ], "source": [ - "import anndata as ad\n", "adata_crested = crested.import_bigwigs(\n", " bigwigs_folder='/home/VIB.LOCAL/niklas.kempynck/nkemp/mouse/biccn/bigwigs/bws',\n", " regions_file='crested_regs.bed',\n", @@ -2202,7 +2200,7 @@ "source": [ "adata_crested_preds = evaluator_crested.predict(\n", " adata_crested, model_name=\"crested1\"\n", - ") " + ")" ] }, { @@ -2549,7 +2547,7 @@ "import pyBigWig\n", "\n", "gt = np.zeros(19)\n", - "for i, bw in enumerate(bw_list_mm): \n", + "for i, bw in enumerate(bw_list_mm):\n", " # Retrieve values from the bigWig file for the given interval\n", " #gt[i] = np.sum(np.nan_to_num(bw.values('chr18', 3901975+557, 3904089-557)))\n", " gt[i] = np.sum(bw.values('chr18', 3901975+557, 3904089-557))\n", @@ -2583,7 +2581,7 @@ "import pyBigWig\n", "\n", "gt = np.zeros(19)\n", - "for i, bw in enumerate(bw_list_mm): \n", + "for i, bw in enumerate(bw_list_mm):\n", " # Retrieve values from the bigWig file for the given interval\n", " gt[i] = np.sum(np.nan_to_num(bw.values(test_df.iloc[16+idx]['chr'], int(test_df.iloc[16+idx]['start'])+557, int(test_df.iloc[16+idx]['end'])-557)))\n", "\n", @@ -2650,7 +2648,7 @@ "kernelspec": { "display_name": "crested", "language": "python", - "name": "crested" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/pyproject.toml b/pyproject.toml index ec8766ac..b796b44e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,9 @@ dependencies = [ "loguru", "logomaker", "pybigtools", - "seaborn" + "seaborn", + "cmake", + "modisco-lite", ] [project.optional-dependencies] diff --git a/src/crested/__init__.py b/src/crested/__init__.py index 13b7c017..7d00443b 100644 --- a/src/crested/__init__.py +++ b/src/crested/__init__.py @@ -3,10 +3,10 @@ from importlib.metadata import version from . import pl, pp, tl -from ._io import import_bigwigs, import_topics +from ._io import import_beds, import_bigwigs from ._logging import setup_logging -__all__ = ["pl", "pp", "tl", "import_topics", "import_bigwigs", "setup_logging"] +__all__ = ["pl", "pp", "tl", "import_beds", "import_bigwigs", "setup_logging"] __version__ = version("crested") diff --git a/src/crested/_io.py b/src/crested/_io.py index c082b729..b4db32a5 100644 --- a/src/crested/_io.py +++ b/src/crested/_io.py @@ -1,8 +1,9 @@ -"""I/O functions for importing topics and bigWigs into AnnData objects.""" +"""I/O functions for importing beds and bigWigs into AnnData objects.""" from __future__ import annotations import os +import re from concurrent.futures import ProcessPoolExecutor from os import PathLike from pathlib import Path @@ -16,11 +17,11 @@ from scipy.sparse import csr_matrix -def _sort_topic_files(filename: str): - """Sorts topic files. +def _sort_files(filename: str): + """Sorts files. - Prioritizes numeric extraction from filenames of the format 'Topic_X.bed' (X=int). - Other filenames are sorted alphabetically, with 'Topic_' files coming last if numeric extraction fails. + Prioritizes numeric extraction from filenames of the format 'Class_X.bed' (X=int). 
+    Other filenames are sorted alphabetically, with 'Class_' files coming last if numeric extraction fails.
     """
     filename = Path(filename)
     parts = filename.stem.split("_")
@@ -32,13 +33,28 @@ def _sort_topic_files(filename: str):
             # If the numeric part is not an integer, handle gracefully
             return (True, filename.stem)
 
-    # Return True for the first element to sort non-'Topic_X' filenames alphabetically after 'Topic_X'
+    # Return True for the first element to sort non-'Class_X' filenames alphabetically after 'Class_X'
     return (
         True,
        filename.stem,
     )
 
 
+def _custom_region_sort(region: str) -> tuple[int, int, int]:
+    """Custom sorting function for regions in the format chr:start-end."""
+    chrom, pos = region.split(":")
+    start, _ = map(int, pos.split("-"))
+
+    # check if the chromosome part contains digits
+    numeric_match = re.match(r"chr(\d+)|chrom(\d+)", chrom, re.IGNORECASE)
+
+    if numeric_match:
+        chrom_num = int(numeric_match.group(1) or numeric_match.group(2))
+        return (0, chrom_num, start)
+    else:
+        return (1, chrom, start)
+
+
 def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]:
     """Read chromsizes file into a dictionary."""
     chromsizes = pd.read_csv(
@@ -130,46 +146,43 @@ def _create_temp_bed_file(
     return temp_bed_file
 
 
-def import_topics(
-    topics_folder: PathLike,
-    regions_file: PathLike,
+def import_beds(
+    beds_folder: PathLike,
+    regions_file: PathLike | None = None,
     chromsizes_file: PathLike | None = None,
-    topics_subset: list | None = None,
+    classes_subset: list | None = None,
     remove_empty_regions: bool = True,
     compress: bool = False,
 ) -> AnnData:
     """
-    Import topic and consensus regions BED files into AnnData format.
+    Import BED files and optionally a consensus regions BED file into AnnData format.
+
+    Expects a folder with BED files, where each file is named {class_name}.bed.
+    The result is an AnnData object with classes as rows and the regions as columns,
+    with the .X values indicating whether a region is open in a class.
 
-    This format is required to be able to train a topic prediction model.
-    The topic and consensus regions are the outputs from running pycisTopic
+    Note
+    ----
+    This is the default function to import topic BED files coming from running pycisTopic
     (https://pycistopic.readthedocs.io/en/latest/) on your data.
     The result is an AnnData object with topics as rows and consensus region as columns,
     with binary values indicating whether a region is present in a topic.
 
-    Example
-    -------
-    >>> anndata = crested.import_topics(
-    ...     topics_folder="path/to/topics",
-    ...     regions_file="path/to/regions.bed",
-    ...     chromsizes_file="path/to/chrom.sizes",
-    ...     topics_subset=["Topic_1", "Topic_2"],
-    ... )
-
     Parameters
     ----------
-    topics_folder
-        Folder path containing the topic BED files.
+    beds_folder
+        Folder path containing the BED files.
     regions_file
-        File path of the consensus regions BED file.
-    topics_subset
-        List of topics to include in the AnnData object. If None, all topics
+        File path of the consensus regions BED file to use as columns in the AnnData object.
+        If None, the regions will be extracted from the files.
+    classes_subset
+        List of classes to include in the AnnData object. If None, all files
         will be included.
-        Topics should be named after the topics file name without the extension.
+        Classes should be named after the file name without the extension.
     chromsizes_file
        File path of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries.
     remove_empty_regions
-        Remove regions that are not open in any topic.
+ Remove regions that are not open in any class (only possible if regions_file is provided) compress Compress the AnnData.X matrix. If True, the matrix will be stored as a sparse matrix. If False, the matrix will be stored as a dense matrix. @@ -179,15 +192,24 @@ def import_topics( Returns ------- - AnnData object with topics as rows and peaks as columns. + AnnData object with classes as rows and peaks as columns. + + Example + ------- + >>> anndata = crested.import_beds( + ... beds_folder="path/to/beds/folder/", + ... regions_file="path/to/regions.bed", + ... chromsizes_file="path/to/chrom.sizes", + ... classes_subset=["Topic_1", "Topic_2"], + ... ) """ - topics_folder = Path(topics_folder) - regions_file = Path(regions_file) + beds_folder = Path(beds_folder) + regions_file = Path(regions_file) if regions_file else None # Input checks - if not topics_folder.is_dir(): - raise FileNotFoundError(f"Directory '{topics_folder}' not found") - if not regions_file.is_file(): + if not beds_folder.is_dir(): + raise FileNotFoundError(f"Directory '{beds_folder}' not found") + if (regions_file is not None) and (not regions_file.is_file()): raise FileNotFoundError(f"File '{regions_file}' not found") if chromsizes_file is not None: chromsizes_file = Path(chromsizes_file) @@ -198,47 +220,93 @@ def import_topics( "Chromsizes file not provided. Will not check if regions are within chromosomes", stacklevel=1, ) - if topics_subset is not None: - for topic in topics_subset: - if not any(topics_folder.glob(f"{topic}.bed")): - raise FileNotFoundError( - f"Topic '{topic}' not found in '{topics_folder}'" + if classes_subset is not None: + for classname in classes_subset: + if not any(beds_folder.glob(f"{classname}.bed")): + raise FileNotFoundError(f"'{classname}' not found in '{beds_folder}'") + + if regions_file: + # Read consensus regions BED file and filter out regions not within chromosomes + consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file) + + binary_matrix = pd.DataFrame(0, index=[], columns=consensus_peaks["region"]) + file_paths = [] + + # Which regions are present in the consensus regions + logger.info( + f"Reading bed files from {beds_folder} and using {regions_file} as var_names..." 
+ ) + for file in sorted(beds_folder.glob("*.bed"), key=_sort_files): + class_name = file.stem + if classes_subset is None or class_name in classes_subset: + class_regions = pd.read_csv( + file, sep="\t", header=None, usecols=[0, 1, 2] + ) + class_regions["region"] = ( + class_regions[0].astype(str) + + ":" + + class_regions[1].astype(str) + + "-" + + class_regions[2].astype(str) ) - # Read consensus regions BED file and filter out regions not within chromosomes - consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file) + # Create binary row for the current topic + binary_row = binary_matrix.columns.isin(class_regions["region"]).astype( + int + ) + binary_matrix.loc[class_name] = binary_row + file_paths.append(str(file)) - binary_matrix = pd.DataFrame(0, index=[], columns=consensus_peaks["region"]) - topic_file_paths = [] + # else, get regions from the bed files + else: + file_paths = [] + all_regions = set() - # Which topic regions are present in the consensus regions - logger.info(f"Reading topics from {topics_folder}...") - for topic_file in sorted(topics_folder.glob("*.bed"), key=_sort_topic_files): - topic_name = topic_file.stem - if topics_subset is None or topic_name in topics_subset: - topic_peaks = pd.read_csv( - topic_file, sep="\t", header=None, usecols=[0, 1, 2] - ) - topic_peaks["region"] = ( - topic_peaks[0].astype(str) + # Collect all regions from the BED files + logger.info( + f"Reading bed files from {beds_folder} without consensus regions..." + ) + for file in sorted(beds_folder.glob("*.bed"), key=_sort_files): + class_name = file.stem + if classes_subset is None or class_name in classes_subset: + class_regions = pd.read_csv( + file, sep="\t", header=None, usecols=[0, 1, 2] + ) + class_regions["region"] = ( + class_regions[0].astype(str) + + ":" + + class_regions[1].astype(str) + + "-" + + class_regions[2].astype(str) + ) + all_regions.update(class_regions["region"].tolist()) + file_paths.append(str(file)) + + # Convert set to sorted list + all_regions = sorted(all_regions, key=_custom_region_sort) + binary_matrix = pd.DataFrame(0, index=[], columns=all_regions) + + # Populate the binary matrix + for file in file_paths: + class_name = Path(file).stem + class_regions = pd.read_csv(file, sep="\t", header=None, usecols=[0, 1, 2]) + class_regions["region"] = ( + class_regions[0].astype(str) + ":" - + topic_peaks[1].astype(str) + + class_regions[1].astype(str) + "-" - + topic_peaks[2].astype(str) + + class_regions[2].astype(str) ) - - # Create binary row for the current topic - binary_row = binary_matrix.columns.isin(topic_peaks["region"]).astype(int) - binary_matrix.loc[topic_name] = binary_row - topic_file_paths.append(str(topic_file)) + binary_row = binary_matrix.columns.isin(class_regions["region"]).astype(int) + binary_matrix.loc[class_name] = binary_row ann_data = AnnData( binary_matrix, ) - ann_data.obs["file_path"] = topic_file_paths + ann_data.obs["file_path"] = file_paths ann_data.obs["n_open_regions"] = ann_data.X.sum(axis=1) - ann_data.var["n_topics"] = ann_data.X.sum(axis=0) + ann_data.var["n_classes"] = ann_data.X.sum(axis=0) ann_data.var["chr"] = ann_data.var.index.str.split(":").str[0] ann_data.var["start"] = ( ann_data.var.index.str.split(":").str[1].str.split("-").str[0] @@ -251,18 +319,18 @@ def import_topics( ann_data.X = csr_matrix(ann_data.X) # Output checks - topics_no_open_regions = ann_data.obs[ann_data.obs["n_open_regions"] == 0] - if not topics_no_open_regions.empty: + classes_no_open_regions = 
ann_data.obs[ann_data.obs["n_open_regions"] == 0] + if not classes_no_open_regions.empty: raise ValueError( - f"Topics {topics_no_open_regions.index} have 0 open regions in the consensus peaks" + f"{classes_no_open_regions.index} have 0 open regions in the consensus peaks" ) - regions_no_topics = ann_data.var[ann_data.var["n_topics"] == 0] - if not regions_no_topics.empty: + regions_no_classes = ann_data.var[ann_data.var["n_classes"] == 0] + if not regions_no_classes.empty: if remove_empty_regions: logger.warning( - f"{len(regions_no_topics.index)} consensus regions are not open in any topic. Removing them from the AnnData object. Disable this behavior by setting 'remove_empty_regions=False'", + f"{len(regions_no_classes.index)} consensus regions are not open in any class. Removing them from the AnnData object. Disable this behavior by setting 'remove_empty_regions=False'", ) - ann_data = ann_data[:, ann_data.var["n_topics"] > 0] + ann_data = ann_data[:, ann_data.var["n_classes"] > 0] return ann_data @@ -285,16 +353,6 @@ def import_bigwigs( where the original region will still be used as the index. This is often useful to extract sequence information around the actual peak region. - Example - ------- - >>> anndata = crested.import_bigwigs( - ... bigwigs_folder="path/to/bigwigs", - ... regions_file="path/to/peaks.bed", - ... chromsizes_file="path/to/chrom.sizes", - ... target="max", - ... target_region_width=500, - ... ) - Parameters ---------- bigwigs_folder @@ -315,6 +373,16 @@ def import_bigwigs( Returns ------- AnnData object with bigWigs as rows and peaks as columns. + + Example + ------- + >>> anndata = crested.import_bigwigs( + ... bigwigs_folder="path/to/bigwigs", + ... regions_file="path/to/peaks.bed", + ... chromsizes_file="path/to/chrom.sizes", + ... target="max", + ... target_region_width=500, + ... ) """ bigwigs_folder = Path(bigwigs_folder) regions_file = Path(regions_file) diff --git a/src/crested/pl/__init__.py b/src/crested/pl/__init__.py index 52e7bbf4..9d346a6d 100644 --- a/src/crested/pl/__init__.py +++ b/src/crested/pl/__init__.py @@ -1,3 +1,2 @@ -from . import bar, heatmap, hist, scatter -from ._contribution_scores import contribution_scores +from . 
import bar, heatmap, hist, patterns, scatter from ._utils import render_plot diff --git a/src/crested/pl/_utils.py b/src/crested/pl/_utils.py index e14ce9ea..ec5ef9ff 100644 --- a/src/crested/pl/_utils.py +++ b/src/crested/pl/_utils.py @@ -2,9 +2,9 @@ from __future__ import annotations -import logomaker +import os + import matplotlib.pyplot as plt -import numpy as np def render_plot( @@ -51,49 +51,8 @@ def render_plot( plt.tight_layout() if save_path: + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) plt.savefig(save_path) plt.show() - - -def grad_times_input_to_df(x, grad, alphabet="ACGT"): - """Generate pandas dataframe for saliency plot based on grad x inputs""" - x_index = np.argmax(np.squeeze(x), axis=1) - grad = np.squeeze(grad) - L, A = grad.shape - - seq = "" - saliency = np.zeros(L) - for i in range(L): - seq += alphabet[x_index[i]] - saliency[i] = grad[i, x_index[i]] - - # create saliency matrix - saliency_df = logomaker.saliency_to_matrix(seq=seq, values=saliency) - return saliency_df - - -def grad_times_input_to_df_mutagenesis(x, grad, alphabet="ACGT"): - import pandas as pd - - """Generate pandas dataframe for mutagenesis plot based on grad x inputs""" - x = np.squeeze(x) # Ensure x is correctly squeezed - grad = np.squeeze(grad) - L, A = x.shape - - # Get original nucleotides' indices, ensure it's 1D - x_index = np.argmax(x, axis=1) - - # Convert index array to nucleotide letters - original_nucleotides = np.array([alphabet[idx] for idx in x_index]) - - # Data preparation for DataFrame - data = { - "Position": np.repeat(np.arange(L), A), - "Nucleotide": np.tile(list(alphabet), L), - "Effect": grad.reshape( - -1 - ), # Flatten grad assuming it matches the reshaped size - "Original": np.repeat(original_nucleotides, A), - } - df = pd.DataFrame(data) - return df + plt.close() diff --git a/src/crested/pl/bar/_normalization_weights.py b/src/crested/pl/bar/_normalization_weights.py index 11c2000d..7723864a 100644 --- a/src/crested/pl/bar/_normalization_weights.py +++ b/src/crested/pl/bar/_normalization_weights.py @@ -37,7 +37,7 @@ def normalization_weights(adata: AnnData, **kwargs): ... title="Normalization scaling factors", ... ) - .. image:: ../../../docs/_static/img/examples/bar_normalization_weights.png + .. image:: ../../../../docs/_static/img/examples/bar_normalization_weights.png """ @log_and_raise(ValueError) diff --git a/src/crested/pl/bar/_region.py b/src/crested/pl/bar/_region.py index 3c24e587..b7215cf9 100644 --- a/src/crested/pl/bar/_region.py +++ b/src/crested/pl/bar/_region.py @@ -50,7 +50,7 @@ def region_predictions( ... title="Region chr1:3094805-3095305" ... ) - .. image:: ../../../docs/_static/img/examples/bar_region_predictions.png + .. image:: ../../../../docs/_static/img/examples/bar_region_predictions.png """ @log_and_raise(ValueError) @@ -144,7 +144,7 @@ def region(adata: AnnData, region: str, target: str = "groundtruth", **kwargs) - ... figtitle="chr1:3094805-3095305", ... ) - .. image:: ../../../docs/_static/img/examples/bar_region.png + .. image:: ../../../../docs/_static/img/examples/bar_region.png """ @log_and_raise(ValueError) diff --git a/src/crested/pl/heatmap/_correlations.py b/src/crested/pl/heatmap/_correlations.py index f3f31674..890a9c69 100644 --- a/src/crested/pl/heatmap/_correlations.py +++ b/src/crested/pl/heatmap/_correlations.py @@ -65,7 +65,7 @@ def correlations_self( ... adata, log_transform=True, title="Self correlations heatmap" ... ) - .. 
image:: ../../../docs/_static/img/examples/heatmap_self_correlations.png + .. image:: ../../../../docs/_static/img/examples/heatmap_self_correlations.png """ x = adata.X classes = list(adata.obs_names) @@ -132,7 +132,7 @@ def correlations_predictions( ... title="Correlations: Predictions vs Ground Truth", ... ) - .. image:: ../../../docs/_static/img/examples/heatmap_correlations_predictions.png + .. image:: ../../../../docs/_static/img/examples/heatmap_correlations_predictions.png """ @log_and_raise(ValueError) diff --git a/src/crested/pl/hist/_distribution.py b/src/crested/pl/hist/_distribution.py index 88a6a616..51bc1c96 100644 --- a/src/crested/pl/hist/_distribution.py +++ b/src/crested/pl/hist/_distribution.py @@ -51,7 +51,7 @@ def distribution( ... adata, split="test", share_y=False, class_names=["Astro", "Vip"] ... ) - .. image:: ../../../docs/_static/img/examples/hist_distribution.png + .. image:: ../../../../docs/_static/img/examples/hist_distribution.png """ @log_and_raise(ValueError) diff --git a/src/crested/pl/patterns/__init__.py b/src/crested/pl/patterns/__init__.py new file mode 100644 index 00000000..08e7c743 --- /dev/null +++ b/src/crested/pl/patterns/__init__.py @@ -0,0 +1,2 @@ +from ._contribution_scores import contribution_scores +from ._modisco_results import modisco_results diff --git a/src/crested/pl/_contribution_scores.py b/src/crested/pl/patterns/_contribution_scores.py similarity index 66% rename from src/crested/pl/_contribution_scores.py rename to src/crested/pl/patterns/_contribution_scores.py index 0a4b577f..4209994a 100644 --- a/src/crested/pl/_contribution_scores.py +++ b/src/crested/pl/patterns/_contribution_scores.py @@ -2,47 +2,42 @@ from __future__ import annotations -import logomaker import matplotlib.pyplot as plt import numpy as np from loguru import logger from crested._logging import log_and_raise -from crested.pl._utils import grad_times_input_to_df +from crested.pl._utils import render_plot - -def _plot_attribution_map(saliency_df, ax=None, figsize=(20, 1)): - """Plot an attribution map using logomaker""" - logomaker.Logo(saliency_df, figsize=figsize, ax=ax) - if ax is None: - ax = plt.gca() - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - # ax.yaxis.set_ticks_position("none") - ax.xaxis.set_ticks_position("none") - plt.xticks([]) +from ._utils import _plot_attribution_map, grad_times_input_to_df @log_and_raise(ValueError) def _check_contrib_params( zoom_n_bases: int | None, scores: np.ndarray, + labels: list | None, ): """Check contribution scores parameters.""" if zoom_n_bases is not None and zoom_n_bases > scores.shape[2]: raise ValueError( f"zoom_n_bases ({zoom_n_bases}) must be less than or equal to the number of bases in the sequence ({scores.shape[2]})" ) + if labels: + if len(labels) != scores.shape[1]: + raise ValueError( + f"Number of plot labels ({len(labels)}) must match the number of classes ({scores.shape[1]}), since each class has a separate plot." + ) def contribution_scores( scores: np.ndarray, seqs_one_hot: np.ndarray, - class_names: list, + labels: list | None = None, zoom_n_bases: int | None = None, highlight_positions: list[tuple[int, int]] | None = None, ylim: tuple | None = None, - save_path: str | None = None, + **kwargs, ): """ Visualize interpretation scores with optional highlighted positions. @@ -51,36 +46,36 @@ def contribution_scores( Parameters ---------- - scores : np.ndarray + scores Contribution scores of shape (n_seqs, n_classes, n_bases, n_features). 
- seqs_one_hot : np.ndarray + seqs_one_hot One-hot encoded corresponding sequences of shape (n_seqs, n_bases, n_features). - class_names : list - List of class names to use as labels. - zoom_n_bases : int, optional + labels + List of labels to add to the plot. Should have the same length as the number of classes. + zoom_n_bases Number of center bases to zoom in on. Default is None (no zooming). - highlight_positions : list[tuple[int, int]], optional + highlight_positions List of tuples with start and end positions to highlight. Default is None. - ylim : tuple, optional + ylim Y-axis limits. Default is None. - save_path : str, optional - Path to save the plot. Default is None (only show the plot). Examples -------- >>> import numpy as np >>> scores = np.random.rand(1, 1, 100, 4) >>> seqs_one_hot = np.random.randint(0, 2, (1, 100, 4)) - >>> class_names = ["class1"] - >>> crested.pl.contribution_scores(scores, seqs_one_hot, class_names) + >>> labels = ["class1"] + >>> crested.pl.patterns.contribution_scores(scores, seqs_one_hot, labels) - .. image:: ../../../docs/_static/img/examples/contribution_scores.png + .. image:: ../../../../docs/_static/img/examples/contribution_scores.png """ # Center and zoom - _check_contrib_params(zoom_n_bases, scores) + _check_contrib_params(zoom_n_bases, scores, labels) if zoom_n_bases is None: zoom_n_bases = scores.shape[2] + if labels and not isinstance(labels, list): + labels = [str(labels)] center = int(scores.shape[2] / 2) start_idx = center - int(zoom_n_bases / 2) scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :] @@ -92,13 +87,17 @@ def contribution_scores( logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)") for seq in range(seqs_one_hot.shape[0]): fig_height_per_class = 2 - fig = plt.figure(figsize=(50, fig_height_per_class * len(class_names))) - for i, class_name in enumerate(class_names): + fig = plt.figure(figsize=(50, fig_height_per_class * scores.shape[1])) + for i in range(scores.shape[1]): seq_class_scores = scores[seq, i, :, :] seq_class_x = seqs_one_hot[seq, :, :] intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores) - ax = plt.subplot(len(class_names), 1, i + 1) - _plot_attribution_map(intgrad_df, ax=ax) + ax = plt.subplot(scores.shape[1], 1, i + 1) + _plot_attribution_map(intgrad_df, ax=ax, return_ax=False) + if labels: + class_name = labels[i] + else: + class_name = f"Class {i}" text_to_add = class_name if ylim: ax.set_ylim(ylim[0], ylim[1]) @@ -129,8 +128,14 @@ def contribution_scores( plt.xlabel("Position") plt.xticks(np.arange(0, zoom_n_bases, 50)) - if save_path: - plt.savefig(save_path) - plt.close(fig) - else: - plt.show() + + if "width" not in kwargs: + kwargs["width"] = 50 + if "height" not in kwargs: + kwargs["height"] = fig_height_per_class * scores.shape[1] + if "xlabel" not in kwargs: + kwargs["xlabel"] = "Position" + if "ylabel" not in kwargs: + kwargs["ylabel"] = "Scores" + + render_plot(fig, **kwargs) diff --git a/src/crested/pl/patterns/_modisco_results.py b/src/crested/pl/patterns/_modisco_results.py new file mode 100644 index 00000000..e050ee87 --- /dev/null +++ b/src/crested/pl/patterns/_modisco_results.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import h5py +import matplotlib.pyplot as plt +import modiscolite as modisco +import numpy as np +from loguru import logger + +from crested._logging import log_and_raise +from crested.pl._utils import render_plot + +from ._utils import _plot_attribution_map + + +@log_and_raise(ValueError) +def 
_trim_pattern_by_ic(
+    pattern: dict,
+    pos_pattern: bool,
+    min_v: float,
+    background: list[float] = None,
+    pseudocount: float = 1e-6,
+) -> dict:
+    """
+    Trims the pattern based on information content (IC).
+
+    Parameters
+    ----------
+    pattern
+        Dictionary containing the pattern data.
+    pos_pattern
+        Indicates if the pattern is a positive pattern.
+    min_v
+        Minimum value for trimming.
+    background
+        Background probabilities for each nucleotide.
+    pseudocount
+        Pseudocount for IC calculation.
+
+    Returns
+    -------
+    Trimmed pattern.
+    """
+    if background is None:
+        background = [0.27, 0.23, 0.23, 0.27]
+    contrib_scores = np.array(pattern["contrib_scores"])
+    if not pos_pattern:
+        contrib_scores = -contrib_scores
+    contrib_scores[contrib_scores < 0] = 0
+
+    ic = modisco.util.compute_per_position_ic(
+        ppm=np.array(contrib_scores), background=background, pseudocount=pseudocount
+    )
+    np.nan_to_num(ic, copy=False, nan=0.0)
+    v = (abs(np.array(contrib_scores)) * ic[:, None]).max(1)
+    v = (v - v.min()) / (v.max() - v.min() + 1e-9)
+
+    try:
+        start_idx = min(np.where(np.diff((v > min_v) * 1))[0])
+        end_idx = max(np.where(np.diff((v > min_v) * 1))[0]) + 1
+    except ValueError:
+        logger.error("No valid pattern found. Aborting...")
+        # Re-raise so the caller does not continue with undefined trim indices.
+        raise
+
+    return _trim(pattern, start_idx, end_idx)
+
+
+def _trim(pattern: dict, start_idx: int, end_idx: int) -> dict:
+    """
+    Trims the pattern to the specified start and end indices.
+
+    Parameters
+    ----------
+    pattern
+        Dictionary containing the pattern data.
+    start_idx
+        Start index for trimming.
+    end_idx
+        End index for trimming.
+
+    Returns
+    -------
+    Trimmed pattern.
+    """
+    return {
+        "sequence": np.array(pattern["sequence"])[start_idx:end_idx],
+        "contrib_scores": np.array(pattern["contrib_scores"])[start_idx:end_idx],
+        "hypothetical_contribs": np.array(pattern["hypothetical_contribs"])[
+            start_idx:end_idx
+        ],
+    }
+
+
+def _get_ic(
+    contrib_scores: np.ndarray,
+    pos_pattern: bool,
+    background: list[float] = None,
+) -> np.ndarray:
+    """
+    Computes the information content (IC) for the given contribution scores.
+
+    Parameters
+    ----------
+    contrib_scores
+        Array of contribution scores.
+    pos_pattern
+        Indicates if the pattern is a positive pattern.
+    background
+        Background probabilities for each nucleotide.
+
+    Returns
+    -------
+    Information content for the contribution scores.
+    """
+    if background is None:
+        background = [0.27, 0.23, 0.23, 0.27]
+    background = np.array(background)
+    if not pos_pattern:
+        contrib_scores = -contrib_scores
+    contrib_scores[contrib_scores < 0] = 1e-9
+    ppm = contrib_scores / np.sum(contrib_scores, axis=1)[:, None]
+
+    ic = (np.log((ppm + 0.001) / (1.004)) / np.log(2)) * ppm - (
+        np.log(background) * background / np.log(2)
+    )
+    return ppm * (np.sum(ic, axis=1)[:, None])
+
+
+def modisco_results(
+    classes: list[str],
+    contribution: str,
+    contribution_dir: str,
+    num_seq: int,
+    viz: str = "contrib",
+    min_seqlets: int = 0,
+    verbose: bool = False,
+    y_min: float = -0.05,
+    y_max: float = 0.25,
+    background: list[float] = None,
+    **kwargs,
+) -> None:
+    """
+    Plot genomic contributions for the given classes.
+
+    Requires the modisco results to be present in the specified directory.
+    The contribution scores are trimmed based on information content (IC).
+
+    Parameters
+    ----------
+    classes
+        List of classes to plot genomic contributions for.
+    contribution
+        Contribution type to plot. Choose either "positive" or "negative".
+    contribution_dir
+        Directory containing the modisco results.
+        Each class should have a separate modisco .h5 results file in the format {class}_modisco_results.h5.
+    num_seq
+        Total number of sequences used for the modisco run.
+        Necessary to calculate the percentage of sequences with the pattern.
+    viz
+        Visualization method. Choose either "contrib" or "pwm".
+    min_seqlets
+        Minimum number of seqlets required for a pattern to be considered.
+    verbose
+        Print verbose output.
+    y_min
+        Minimum y-axis limit for the plot if viz is "contrib".
+    y_max
+        Maximum y-axis limit for the plot if viz is "contrib".
+    background
+        Background probabilities for each nucleotide. Default is [0.27, 0.23, 0.23, 0.27].
+    kwargs
+        Additional keyword arguments for the plot.
+
+    Examples
+    --------
+    >>> crested.pl.patterns.modisco_results(
+    ...     classes=["Lamp5", "Pvalb", "Sst", "Sst-Chodl", "Vip"],
+    ...     contribution="positive",
+    ...     contribution_dir="/path/to/modisco_results",
+    ...     num_seq=1000,
+    ...     viz="pwm",
+    ...     save_path="/path/to/genomic_contributions.png",
+    ... )
+
+    .. image:: ../../../../docs/_static/img/examples/genomic_contributions.png
+
+    See Also
+    --------
+    crested.tl.tfmodisco
+    crested.pl.render_plot
+    """
+    if background is None:
+        background = [0.27, 0.23, 0.23, 0.27]
+    background = np.array(background)
+    pos_pat = contribution == "positive"
+
+    logger.info(f"Starting genomic contributions plot for classes: {classes}")
+
+    max_num_patterns = 0
+    all_patterns = []
+
+    for cell_type in classes:
+        if verbose:
+            logger.info(cell_type)
+        hdf5_results = h5py.File(
+            f"{contribution_dir}/{cell_type}_modisco_results.h5", "r"
+        )
+        metacluster_names = list(hdf5_results.keys())
+
+        if f"{contribution[:3]}_patterns" not in metacluster_names:
+            raise ValueError(
+                f"No {contribution[:3]}_patterns for {cell_type}. Aborting..."
+ ) + + for metacluster_name in [f"{contribution[:3]}_patterns"]: + all_pattern_names = list(hdf5_results[metacluster_name]) + max_num_patterns = max(max_num_patterns, len(all_pattern_names)) + + fig, axes = plt.subplots( + nrows=max_num_patterns, + ncols=len(classes), + figsize=(8 * len(classes), 2 * max_num_patterns), + ) + + if verbose: + logger.info(f"Max patterns for selected classes: {max_num_patterns}") + + motif_counter = 1 + + for idx, cell_type in enumerate(classes): + hdf5_results = h5py.File( + f"{contribution_dir}/{cell_type}_modisco_results.h5", "r" + ) + metacluster_names = list(hdf5_results.keys()) + + if verbose: + logger.info(metacluster_names) + + for metacluster_name in [f"{contribution[:3]}_patterns"]: + all_pattern_names = list(hdf5_results[metacluster_name]) + + for _pattern_idx, pattern_name in enumerate(all_pattern_names): + if len(classes) > 1: + ax = axes[motif_counter - 1, idx] + elif max_num_patterns > 1: + ax = axes[motif_counter - 1] + else: + ax = axes + motif_counter += 1 + all_patterns.append((metacluster_name, pattern_name)) + pattern = hdf5_results[metacluster_name][pattern_name] + num_seqlets = list( + hdf5_results[metacluster_name][pattern_name]["seqlets"]["n_seqlets"] + )[0] + if verbose: + logger.info(metacluster_name, pattern_name) + logger.info("total seqlets:", num_seqlets) + if num_seqlets < min_seqlets: + break + pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, 0.1) + if viz == "contrib": + ax = _plot_attribution_map( + ax=ax, + saliency_df=np.array(pattern_trimmed["contrib_scores"]), + return_ax=True, + figsize=None, + ) + ax.set_ylim([y_min, y_max]) + ax.set_title( + f"{cell_type}: {np.around(num_seqlets / num_seq * 100, 2)}% seqlet frequency" + ) + elif viz == "pwm": + pwm = _get_ic(np.array(pattern_trimmed["contrib_scores"]), pos_pat) + ax = _plot_attribution_map( + ax=ax, saliency_df=pwm, return_ax=True, figsize=None + ) + ax.set_title( + f"{cell_type}: {np.around(num_seqlets / num_seq * 100, 2)}% seqlet frequency - Average IC: {np.around(np.mean(pwm), 2)}" + ) + ax.set_ylim([0, 2]) + else: + raise ValueError( + 'Invalid visualization method. Choose either "contrib" or "pwm". Aborting...' 
+ ) + motif_counter = 1 + + plt.tight_layout() + if "width" not in kwargs: + kwargs["width"] = 6 * len(classes) + if "height" not in kwargs: + kwargs["height"] = 2 * max_num_patterns + + render_plot(fig, **kwargs) diff --git a/src/crested/pl/patterns/_utils.py b/src/crested/pl/patterns/_utils.py new file mode 100644 index 00000000..ede20c84 --- /dev/null +++ b/src/crested/pl/patterns/_utils.py @@ -0,0 +1,73 @@ +"""Sequence pattern utility functions for plotting.""" + +from __future__ import annotations + +import logomaker +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def grad_times_input_to_df(x, grad, alphabet="ACGT"): + """Generate pandas dataframe for saliency plot based on grad x inputs""" + x_index = np.argmax(np.squeeze(x), axis=1) + grad = np.squeeze(grad) + L, A = grad.shape + + seq = "" + saliency = np.zeros(L) + for i in range(L): + seq += alphabet[x_index[i]] + saliency[i] = grad[i, x_index[i]] + + # create saliency matrix + saliency_df = logomaker.saliency_to_matrix(seq=seq, values=saliency) + return saliency_df + + +def grad_times_input_to_df_mutagenesis(x, grad, alphabet="ACGT"): + """Generate pandas dataframe for mutagenesis plot based on grad x inputs""" + x = np.squeeze(x) # Ensure x is correctly squeezed + grad = np.squeeze(grad) + L, A = x.shape + + # Get original nucleotides' indices, ensure it's 1D + x_index = np.argmax(x, axis=1) + + # Convert index array to nucleotide letters + original_nucleotides = np.array([alphabet[idx] for idx in x_index]) + + # Data preparation for DataFrame + data = { + "Position": np.repeat(np.arange(L), A), + "Nucleotide": np.tile(list(alphabet), L), + "Effect": grad.reshape( + -1 + ), # Flatten grad assuming it matches the reshaped size + "Original": np.repeat(original_nucleotides, A), + } + df = pd.DataFrame(data) + return df + + +def _plot_attribution_map( + saliency_df, + ax=None, + return_ax: bool = True, + spines: bool = True, + figsize: tuple | None = (20, 1), +): + """Plot an attribution map using logomaker""" + if type(saliency_df) is not pd.DataFrame: + saliency_df = pd.DataFrame(saliency_df, columns=["A", "C", "G", "T"]) + if figsize is not None: + logomaker.Logo(saliency_df, figsize=figsize, ax=ax) + else: + logomaker.Logo(saliency_df, ax=ax) + if ax is None: + ax = plt.gca() + if not spines: + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + if return_ax: + return ax diff --git a/src/crested/pl/scatter/_class_density.py b/src/crested/pl/scatter/_class_density.py index fd65ba3d..f874ae60 100644 --- a/src/crested/pl/scatter/_class_density.py +++ b/src/crested/pl/scatter/_class_density.py @@ -60,7 +60,7 @@ def class_density( ... log_transform=True, ... ) - .. image:: ../../../docs/_static/img/examples/scatter_class_density.png + .. image:: ../../../../docs/_static/img/examples/scatter_class_density.png """ @log_and_raise(ValueError) diff --git a/src/crested/pp/_filter.py b/src/crested/pp/_filter.py index 4e2f9e91..fd9496fc 100644 --- a/src/crested/pp/_filter.py +++ b/src/crested/pp/_filter.py @@ -11,20 +11,17 @@ def filter_regions_on_specificity( adata: AnnData, gini_std_threshold: float = 1.0, + model_name: str | None = None, ) -> AnnData: """ - Filter bed regions & targets based on high Gini score. + Filter bed regions & targets/predictions based on high Gini score. This function filters regions based on their specificity using Gini scores. The regions with high Gini scores are retained, and a new AnnData object is created with the filtered data. 
- - Example - ------- - >>> filtered_adata = crested.pp.filter_regions_on_specificity( - ... adata, - ... gini_std_threshold=1.0, - ... ) + If model_name is provided, will look for the corresponding predictions in the + adata.layers[model_name] layer. Else, it will use the targets in adata.X to decide + which regions to keep. Parameters ---------- @@ -33,18 +30,35 @@ def filter_regions_on_specificity( gini_std_threshold The number of standard deviations above the mean Gini score used to determine the threshold for high variability. + model_name + The name of the model to look for in adata.layers[model_name] for predictions. + If None, will use the targets in adata.X to select specific regions. Returns ------- - ad.AnnData - A new AnnData object with the filtered matrix and updated variable names. + A new AnnData object with the filtered matrix and updated variable names. + + + Example + ------- + >>> filtered_adata = crested.pp.filter_regions_on_specificity( + ... adata, + ... gini_std_threshold=1.0, + ... ) """ - if isinstance(adata.X, csr_matrix): - target_matrix = ( - adata.X.toarray().T - ) # Convert to dense and transpose to (regions, cell types) + if model_name is None: + if isinstance(adata.X, csr_matrix): + target_matrix = ( + adata.X.toarray().T + ) # Convert to dense and transpose to (regions, cell types) + else: + target_matrix = adata.X.T else: - target_matrix = adata.X.T + if model_name not in adata.layers: + raise ValueError( + f"Model name {model_name} not found in adata.layers. Please provide a valid model name." + ) + target_matrix = adata.layers[model_name].T gini_scores = _calc_gini(target_matrix) mean = np.mean(np.max(gini_scores, axis=1)) @@ -60,10 +74,17 @@ def filter_regions_on_specificity( ) # Create a new AnnData object with the filtered data - if isinstance(adata.X, csr_matrix): - new_X = csr_matrix(target_matrix_filt.T) + if model_name is None: + if isinstance(adata.X, csr_matrix): + new_X = csr_matrix(target_matrix_filt.T) + else: + new_X = target_matrix_filt.T else: - new_X = target_matrix_filt.T + if isinstance(adata.X, csr_matrix): + new_X = csr_matrix(adata.X[:, selected_indices]) + else: + new_X = adata.X[:, selected_indices] + new_pred_matrix = target_matrix_filt.T filtered_adata = AnnData(new_X) filtered_adata.obs = adata.obs.copy() @@ -73,4 +94,7 @@ def filter_regions_on_specificity( # Copy over any other attributes or layers if needed filtered_adata.obsm = adata.obsm.copy() + if model_name is not None: + filtered_adata.layers[model_name] = new_pred_matrix + return filtered_adata diff --git a/src/crested/pp/_normalization.py b/src/crested/pp/_normalization.py index fe4fb0e5..a3422245 100644 --- a/src/crested/pp/_normalization.py +++ b/src/crested/pp/_normalization.py @@ -24,15 +24,6 @@ def normalize_peaks( a defined threshold and considering the variability within those peaks. Only used on continuous .X data. Modifies the input AnnData.X in place. - Example - ------- - >>> crested.pp.normalize_peaks( - ... adata, - ... peak_threshold=0, - ... gini_std_threshold=2.0, - ... top_k_percent=0.05, - ... ) - Parameters ---------- adata @@ -49,9 +40,16 @@ def normalize_peaks( Returns ------- - anndata.AnnData - The AnnData object with the normalized matrix and cell - type weights used for normalization in the obsm attribute. + The AnnData object with the normalized matrix and cell type weights used for normalization in the obsm attribute. + + Example + ------- + >>> crested.pp.normalize_peaks( + ... adata, + ... peak_threshold=0, + ... 
gini_std_threshold=2.0, + ... top_k_percent=0.05, + ... ) """ if isinstance(adata.X, csr_matrix): target_matrix = ( @@ -67,7 +65,7 @@ def normalize_peaks( gini_scores_all = [] overall_gini_scores = _calc_gini(target_matrix) - mean = np.mean(np.max(overall_gini_scores, axis=1)) + mean = np.mean(np.max(overall_gini_scores, axis=1)) std_dev = np.std(np.max(overall_gini_scores, axis=1)) gini_threshold = mean - gini_std_threshold * std_dev @@ -107,8 +105,6 @@ def normalize_peaks( filtered_regions_df = regions_df.iloc[list(all_low_gini_indices)] - adata.X = normalized_matrix return filtered_regions_df - diff --git a/src/crested/pp/_regions.py b/src/crested/pp/_regions.py index 3f3350fe..4cb2f1c0 100644 --- a/src/crested/pp/_regions.py +++ b/src/crested/pp/_regions.py @@ -45,8 +45,7 @@ def change_regions_width( Returns ------- - anndata.AnnData - The AnnData object with the modified regions. + The AnnData object with the modified regions. Example ------- diff --git a/src/crested/pp/_split.py b/src/crested/pp/_split.py index b691ed1f..4715fd59 100644 --- a/src/crested/pp/_split.py +++ b/src/crested/pp/_split.py @@ -201,24 +201,6 @@ def train_val_test_split( ---- Model training always requires a `split` column in the `.var` DataFrame. - Examples - -------- - >>> crested.train_val_test_split( - ... adata, - ... strategy="region", - ... val_size=0.1, - ... test_size=0.1, - ... shuffle=True, - ... random_state=42, - ... ) - - >>> crested.train_val_test_split( - ... adata, - ... strategy="chr", - ... val_chroms=["chr1", "chr2"], - ... test_chroms=["chr3", "chr4"], - ... ) - Parameters ---------- adata @@ -253,10 +235,26 @@ def train_val_test_split( Returns ------- - None - - Adds a new column to `adata.var`: + Adds a new column inplace to `adata.var`: 'split': 'train', 'val', or 'test' + + Examples + -------- + >>> crested.train_val_test_split( + ... adata, + ... strategy="region", + ... val_size=0.1, + ... test_size=0.1, + ... shuffle=True, + ... random_state=42, + ... ) + + >>> crested.train_val_test_split( + ... adata, + ... strategy="chr", + ... val_chroms=["chr1", "chr2"], + ... test_chroms=["chr3", "chr4"], + ... ) """ # Input checks if strategy not in ["region", "chr", "chr_auto"]: diff --git a/src/crested/tl/__init__.py b/src/crested/tl/__init__.py index 63b750ba..42f87fda 100644 --- a/src/crested/tl/__init__.py +++ b/src/crested/tl/__init__.py @@ -1,3 +1,4 @@ from . import data, losses, metrics, zoo from ._configs import TaskConfig, default_configs from ._crested import Crested +from ._tfmodisco import tfmodisco diff --git a/src/crested/tl/_configs.py b/src/crested/tl/_configs.py index dad82403..f5568386 100644 --- a/src/crested/tl/_configs.py +++ b/src/crested/tl/_configs.py @@ -101,6 +101,15 @@ class TaskConfig(NamedTuple): The TaskConfig class is a simple NamedTuple that holds the optimizer, loss, and metrics + Parameters + ---------- + optimizer + Optimizer used for training. + loss + Loss function used for training. + metrics + Metrics used for training. + Example ------- >>> optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3) @@ -113,15 +122,9 @@ class TaskConfig(NamedTuple): ... ] >>> configs = TaskConfig(optimizer, loss, metrics) - - Attributes - ---------- - optimizer : tf.keras.optimizers.Optimizer - Optimizer used for training. - loss : tf.keras.losses.Loss - Loss function used for training. - metrics : list[tf.keras.metrics.Metric] - Metrics used for training. 
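The Gini-based preprocessing functions above lean on a private `_calc_gini` helper that this diff does not show. For orientation only, a per-region Gini coefficient over a `(regions, cell_types)` matrix could look like the sketch below; the real helper appears to return a score per region *per class* (note the `np.max(gini_scores, axis=1)` calls), so treat this as an assumption, not the package's implementation:

```python
import numpy as np

def gini(values: np.ndarray) -> float:
    """Gini coefficient of a 1D array of non-negative values."""
    sorted_vals = np.sort(values)
    n = sorted_vals.size
    total = sorted_vals.sum()
    # Equivalent to the mean-absolute-difference formulation of the Gini index.
    return (2 * np.sum(np.arange(1, n + 1) * sorted_vals) - (n + 1) * total) / (
        n * total
    )

# One score per region: high Gini means the signal is concentrated in few
# cell types, i.e. the region is class-specific.
target_matrix = np.abs(np.random.default_rng(0).normal(size=(100, 10)))
region_gini = np.apply_along_axis(gini, 1, target_matrix)
```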
+ See Also + -------- + crested.tl.default_configs """ optimizer: tf.keras.optimizers.Optimizer @@ -136,8 +139,7 @@ def to_dict(self) -> dict: Returns ------- - dict - Dictionary representation of the TaskConfig. + Dictionary representation of the TaskConfig. """ optimizer_info = { "optimizer": self.optimizer.__class__.__name__, @@ -173,13 +175,16 @@ def default_configs( Parameters ---------- - task : str + tasks Task for which to get default components. Returns ------- - tuple - Optimizer, loss, and metrics for the given task. + Optimizer, loss, and metrics for the given task. + + See Also + -------- + crested.tl.TaskConfig """ task_classes = { "topic_classification": TopicClassificationConfig, diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index 8227080d..041c12d3 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -9,6 +9,7 @@ import tensorflow as tf from anndata import AnnData from loguru import logger +from tqdm import tqdm from crested._logging import log_and_raise from crested.tl import TaskConfig @@ -36,22 +37,22 @@ class Crested: Parameters ---------- - data : AnnDataModule + data AnndataModule object containing the data. - model : tf.keras.Model + model Model architecture to use for training. - config : TaskConfig + config Task configuration (optimizer, loss, and metrics) for use in tl.Crested. - project_name : str + project_name Name of the project. Used for logging and creating output directories. If not provided, the default project name "CREsted" will be used. - run_name : str + run_name Name of the run. Used for wandb logging and creating output directories. If not provided, the current date and time will be used. - logger : str + logger Logger to use for logging. Can be "wandb" or "tensorboard" (tensorboard not implemented yet) If not provided, no additional logging will be done. - seed : int + seed Seed to use for reproducibility. Examples @@ -84,7 +85,7 @@ class Crested: >>> trainer.predict(anndata, model_name="predictions") >>> # Calculate contribution scores - >>> scores, seqs_one_hot = trainer.calculate_contribution_scores( + >>> scores, seqs_one_hot = trainer.calculate_contribution_scores_regions( ... region_idx="chr1:1000-2000", ... class_indices=[0, 1, 2], ... method="integrated_grad", @@ -182,16 +183,18 @@ def _initialize_logger(logger_type: str | None, project_name: str, run_name: str return run, callbacks - def load_model(self, model_path: os.PathLike, compile: bool = False) -> None: + def load_model(self, model_path: os.PathLike, compile: bool = True) -> None: """ Load a (pretrained) model from a file. Parameters ---------- - model_path : os.PathLike + model_path Path to the model file. - compile: bool - Compile model after loading. + compile + Compile the model after loading. Set to False if you only want to load + the model weights (e.g. when finetuning a model). If False, you should + provide a TaskConfig to the Crested object before calling fit. """ self.model = tf.keras.models.load_model(model_path, compile=compile) @@ -203,7 +206,7 @@ def fit( model_checkpointing_best_only: bool = True, early_stopping: bool = True, early_stopping_patience: int = 10, - learning_rate_reduce: bool = False, + learning_rate_reduce: bool = True, learning_rate_reduce_patience: int = 5, custom_callbacks: list | None = None, ) -> None: @@ -212,23 +215,23 @@ def fit( Parameters ---------- - epochs : int + epochs Number of epochs to train the model. - mixed_precision : bool + mixed_precision Enable mixed precision training. 
- model_checkpointing : bool + model_checkpointing Save model checkpoints. - model_checkpointing_best_only : bool + model_checkpointing_best_only Save only the best model checkpoint. - early_stopping : bool + early_stopping Enable early stopping. - early_stopping_patience : int + early_stopping_patience Number of epochs with no improvement after which training will be stopped. - learning_rate_reduce : bool + learning_rate_reduce Enable learning rate reduction. - learning_rate_reduce_patience : int + learning_rate_reduce_patience Number of epochs with no improvement after which learning rate will be reduced. - custom_callbacks : list + custom_callbacks List of custom callbacks to use during training. """ self._check_fit_params() @@ -328,8 +331,12 @@ def test(self, return_metrics: bool = False) -> dict | None: Parameters ---------- - return_metrics : bool + return_metrics Return the evaluation metrics as a dictionary. + + Returns + ------- + Evaluation metrics as a dictionary or None if return_metrics is False. """ self._check_test_params() self._check_gpu_availability() @@ -358,14 +365,19 @@ def predict( """ Make predictions using the model on the full dataset - Adds the predictions to anndata as a .layers attribute. + If anndata and model_name are provided, will add the predictions to anndata as a .layers[model_name] attribute. + Else, will return the predictions as a numpy array. Parameters ---------- - anndata : AnnData + anndata Anndata object containing the data. - model_name : str + model_name Name that will be used to store the predictions in anndata.layers[model_name]. + + Returns + ------- + Predictions of shape (N, C) """ self._check_predict_params(anndata, model_name) self._check_gpu_availability() @@ -385,7 +397,7 @@ def predict( return predictions - def predict_region( + def predict_regions( self, region_idx: list[str] | str, ) -> np.ndarray: @@ -399,8 +411,7 @@ def predict_region( Returns ------- - np.ndarray - Predictions for the specified region(s) of shape (N, C) + Predictions for the specified region(s) of shape (N, C) """ if self.anndatamodule.predict_dataset is None: self.anndatamodule.setup("predict") @@ -421,34 +432,151 @@ def predict_region( def calculate_contribution_scores( self, - region_idx: str, - class_indices: list | None = None, - method: str = "integrated_grad", - return_one_hot: bool = True, - ) -> tuple[np.ndarray, np.ndarray] | np.ndarray: + anndata: AnnData | None = None, + class_names: list[str] | None = None, + method: str = "expected_integrated_grad", + store_in_varm: bool = False, + ) -> tuple[np.ndarray, np.ndarray] | None: + """ + Calculate contribution scores based on the given method for the full dataset. + + These scores can then be plotted to visualize the importance of each base in the dataset + using :func:`~crested.pl.patterns.contribution_scores`. + + Parameters + ---------- + anndata + Anndata object to store the contribution scores in as a .varm[class_name] attribute. + If None, will only return the contribution scores without storing them. + class_names + List of class names to calculate the contribution scores for (should match anndata.obs_names) + If None, the contribution scores for the 'combined' class will be calculated. + method + Method to use for calculating the contribution scores. + Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'. + + Returns + ------- + Contribution scores (N, C, L, 4) and one-hot encoded sequences (N, L, 4) or None if anndata is provided. 
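To illustrate the renamed `predict_regions` (previously `predict_region`) next to the reworked `predict` above, a hedged sketch; `trainer` is assumed to be a fitted `Crested` instance and the region strings are placeholders:

```python
# Store dataset-wide predictions in the AnnData object under a model name...
trainer.predict(anndata=adata, model_name="v1_predictions")

# ...or predict directly for one or more regions (note the plural name).
preds = trainer.predict_regions(["chr1:1000-3000", "chr2:5000-7000"])
print(preds.shape)  # (n_regions, n_classes)
```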
+ + See Also + -------- + crested.pl.patterns.contribution_scores + """ + self._check_contribution_scores_params(class_names) + self._check_gpu_availability() + + if self.anndatamodule.predict_dataset is None: + self.anndatamodule.setup("predict") + predict_loader = self.anndatamodule.predict_dataloader + + all_scores = [] + all_one_hot_sequences = [] + + all_class_names = list(self.anndatamodule.adata.obs_names) + + if class_names is not None: + n_classes = len(class_names) + class_indices = [ + all_class_names.index(class_name) for class_name in class_names + ] + varm_names = class_names + else: + n_classes = 1 # 'combined' class + class_indices = [None] + varm_names = ["combined"] + logger.info( + f"Calculating contribution scores for {n_classes} class(es) and {len(predict_loader)} batch(es) of regions." + ) + + for batch_index, (x, _) in enumerate( + tqdm( + predict_loader.data, + desc="Batch", + total=len(predict_loader), + ), + ): + all_one_hot_sequences.append(x) + + scores = np.zeros( + (x.shape[0], n_classes, x.shape[1], x.shape[2]) + ) # (N, C, W, 4) + + for i, class_index in enumerate(class_indices): + explainer = Explainer(self.model, class_index=class_index) + if method == "integrated_grad": + scores[:, i, :, :] = explainer.integrated_grad( + x, baseline_type="zeros" + ) + elif method == "mutagenesis": + scores[:, i, :, :] = explainer.mutagenesis( + x, class_index=class_index + ) + elif method == "expected_integrated_grad": + scores[:, i, :, :] = explainer.expected_integrated_grad( + x, num_baseline=25 + ) + all_scores.append(scores) + + # predict_loader.data is infinite, so limit the number of iterations + if batch_index == len(predict_loader) - 1: + break + + concatenated_scores = np.concatenate(all_scores, axis=0) + + if anndata: + for varm_name in varm_names: + logger.info(f"Adding contribution scores to anndata.varm[{varm_name}].") + if varm_name == "combined": + anndata.varm[varm_name] = concatenated_scores[:, 0] + else: + anndata.varm[varm_name] = concatenated_scores[ + :, class_names.index(varm_name) + ] + anndata.varm["one_hot_sequences"] = np.concatenate( + all_one_hot_sequences, axis=0 + ) + logger.info( + "Added one-hot encoded sequences and contribution scores per class to anndata.varm." + ) + else: + return concatenated_scores, np.concatenate(all_one_hot_sequences, axis=0) + + def calculate_contribution_scores_regions( + self, + region_idx: list[str] | str, + class_names: list[str] | None = None, + method: str = "expected_integrated_grad", + ) -> tuple[np.ndarray, np.ndarray]: """ Calculate contribution scores based on given method for a specified region. These scores can then be plotted to visualize the importance of each base in the region - using :func:`~crested.pl.contribution_scores`. + using :func:`~crested.pl.patterns.contribution_scores`. Parameters ---------- - region_idx : str - Region for which to calculate the contribution scores in the format "chr:start-end". - class_indices : list - List of class indices to calculate the contribution scores for. + region_idx + Region(s) for which to calculate the contribution scores in the format "chr:start-end". + class_names + List of class names to calculate the contribution scores for (should match anndata.obs_names) If None, the contribution scores for the 'combined' class will be calculated. - method : str + method Method to use for calculating the contribution scores. - Options are: 'integrated_grad', 'smooth_grad', 'mutagenesis', 'saliency', 'expected_integrated_grad'. 
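Putting the two contribution-score entry points side by side, a sketch assuming `trainer` is a `Crested` object with a loaded model and `adata` the matching AnnData; the class names are examples borrowed from the tfmodisco docstring later in this patch:

```python
# Whole-dataset scores, stored per class in adata.varm (plus the one-hot
# sequences under adata.varm["one_hot_sequences"]).
trainer.calculate_contribution_scores(
    anndata=adata,
    class_names=["Astro", "Vip"],
    method="expected_integrated_grad",
)

# Region-level scores, returned as arrays instead of stored.
scores, seqs_one_hot = trainer.calculate_contribution_scores_regions(
    region_idx="chr1:1000-3000",
    class_names=["Astro"],
    method="integrated_grad",
)
print(scores.shape, seqs_one_hot.shape)  # (N, C, L, 4), (N, L, 4)
```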
-        return_one_hot : bool
-            Return the one-hot encoded sequences along with the contribution scores.
-        """
+            Options are: 'integrated_grad', 'mutagenesis', 'expected_integrated_grad'.
+
+        Returns
+        -------
+        Contribution scores (N, C, L, 4) and one-hot encoded sequences (N, L, 4).
+
+        See Also
+        --------
+        crested.pl.patterns.contribution_scores
+        """
         self._check_contrib_params(method)
         if self.anndatamodule.predict_dataset is None:
             self.anndatamodule.setup("predict")
+        self._check_contribution_scores_params(class_names)
 
         if isinstance(region_idx, str):
             region_idx = [region_idx]
@@ -456,19 +584,30 @@
         all_scores = []
         all_one_hot_sequences = []
 
-        for region in region_idx:
+        all_class_names = list(self.anndatamodule.adata.obs_names)
+
+        if class_names is not None:
+            n_classes = len(class_names)
+            class_indices = [
+                all_class_names.index(class_name) for class_name in class_names
+            ]
+        else:
+            n_classes = 1  # 'combined' class
+            class_indices = [None]
+
+        logger.info(
+            f"Calculating contribution scores for {n_classes} class(es) and {len(region_idx)} region(s)."
+        )
+        for region in tqdm(
+            region_idx,
+            desc="Region",
+        ):
             sequence = self.anndatamodule.predict_dataset.sequence_loader.get_sequence(
                 region
             )
             x = one_hot_encode_sequence(sequence)
             all_one_hot_sequences.append(x)
 
-            if class_indices is not None:
-                n_classes = len(class_indices)
-            else:
-                n_classes = 1  # 'combined' class
-                class_indices = [None]
-
             scores = np.zeros(
                 (x.shape[0], n_classes, x.shape[1], x.shape[2])
             )  # (N, C, W, 4)
@@ -479,16 +618,10 @@
                     scores[:, i, :, :] = explainer.integrated_grad(
                         x, baseline_type="zeros"
                     )
-                elif method == "smooth_grad":
-                    scores[:, i, :, :] = explainer.smoothgrad(
-                        x, num_samples=50, mean=0.0, stddev=0.1
-                    )
                 elif method == "mutagenesis":
                     scores[:, i, :, :] = explainer.mutagenesis(
                         x, class_index=class_index
                     )
-                elif method == "saliency":
-                    scores[:, i, :, :] = explainer.saliency_maps(x)
                 elif method == "expected_integrated_grad":
                     scores[:, i, :, :] = explainer.expected_integrated_grad(
                         x, num_baseline=25
                     )
@@ -498,12 +631,9 @@
 
             all_scores.append(scores)
 
-        if return_one_hot:
-            return np.concatenate(all_scores, axis=0), np.concatenate(
-                all_one_hot_sequences, axis=0
-            )
-        else:
-            return np.concatenate(all_scores, axis=0)
+        return np.concatenate(all_scores, axis=0), np.concatenate(
+            all_one_hot_sequences, axis=0
+        )
 
     @staticmethod
     def _check_gpu_availability():
@@ -514,7 +644,11 @@
     @log_and_raise(ValueError)
     def _check_contrib_params(self, method):
-        if method not in ['integrated_grad', 'smooth_grad','mutagenesis', 'saliency', 'expected_integrated_grad']:
+        if method not in [
+            "integrated_grad",
+            "mutagenesis",
+            "expected_integrated_grad",
+        ]:
             raise ValueError(
-                "Contribution score method not implemented. Choose out of the following options: integrated_grad, smooth_grad, mutagenesis, saliency, expected_integrated_grad."
+                "Contribution score method not implemented. Choose from the following options: integrated_grad, mutagenesis, expected_integrated_grad."
             )
@@ -560,5 +696,20 @@ def _check_predict_params(self, anndata: AnnData | None, model_name: str | None)
                 "Both anndata and model_name must be provided if one of them is provided."
             )
 
+    @log_and_raise(ValueError)
+    def _check_contribution_scores_params(self, class_names: list | None):
+        """Check if the necessary parameters are set for the calculate_contribution_scores method."""
+        if not self.model:
+            raise ValueError(
+                "Model not set. Please load a pretrained model with Crested.load_model(...) before calling calculate_contribution_scores(_regions)."
+            )
+        if class_names is not None:
+            all_class_names = list(self.anndatamodule.adata.obs_names)
+            for class_name in class_names:
+                if class_name not in all_class_names:
+                    raise ValueError(
+                        f"Class name {class_name} not found in anndata.obs_names."
+                    )
+
     def __repr__(self):
         return f"Crested(data={self.anndatamodule is not None}, model={self.model is not None}, config={self.config is not None})"
diff --git a/src/crested/tl/_tfmodisco.py b/src/crested/tl/_tfmodisco.py
new file mode 100644
index 00000000..b7169867
--- /dev/null
+++ b/src/crested/tl/_tfmodisco.py
@@ -0,0 +1,135 @@
+"""Code adapted from https://github.com/jmschrei/tfmodisco-lite/blob/main/modisco."""
+
+from __future__ import annotations
+
+import os
+
+import anndata
+import modiscolite
+from loguru import logger
+
+from crested._logging import log_and_raise
+
+
+def _calculate_window_offsets(center: int, window_size: int) -> tuple:
+    return (center - window_size // 2, center + window_size // 2)
+
+
+@log_and_raise(Exception)
+def tfmodisco(
+    adata: anndata.AnnData,
+    class_names: list[str] | None = None,
+    output_dir: os.PathLike = "modisco_results",
+    max_seqlets: int = 2000,
+    window: int = 400,
+    n_leiden: int = 2,
+    report: bool = False,
+    meme_db: str | None = None,
+    verbose: bool = True,
+):
+    """
+    Run tf-modisco on one-hot encoded sequences and contribution scores stored in an AnnData object.
+
+    Parameters
+    ----------
+    adata
+        AnnData object containing the one-hot encoded sequences and contribution scores.
+        The one-hot encoded sequences should be stored in adata.varm["one_hot_sequences"].
+        The contribution scores should be stored in adata.varm[class_name], where class_name is one of the class names.
+    class_names
+        List of class names to process. If None, all adata.varm keys will be processed.
+    output_dir
+        Directory where output files will be saved.
+    max_seqlets
+        Maximum number of seqlets per metacluster.
+    window
+        The window surrounding the peak center that will be considered for motif discovery.
+    n_leiden
+        Number of Leiden clusterings to perform with different random seeds.
+    report
+        Generate a modisco report.
+    meme_db
+        Path to a MEME file (.meme) containing motifs. Required if report is True.
+    verbose
+        Print verbose output.
+
+    See Also
+    --------
+    crested.tl.Crested.calculate_contribution_scores
+
+    Examples
+    --------
+    >>> evaluator = crested.tl.Crested(...)
+    >>> evaluator.load_model("/path/to/trained/model.keras")
+    >>> evaluator.calculate_contribution_scores(
+    ...     adata, class_names=["Astro", "Vip"], method="integrated_grad"
+    ... )
+    >>> crested.tl.tfmodisco(
+    ...     adata, class_names=["Astro", "Vip"], output_dir="modisco_results"
+    ... )
+    """
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    if class_names is None:
+        class_names = list(adata.varm.keys())
+        if "one_hot_sequences" in class_names:
+            class_names.remove("one_hot_sequences")
+
+    one_hot_sequences = adata.varm["one_hot_sequences"]
+
+    if one_hot_sequences.shape[1] < window:
+        raise ValueError(
+            f"Window ({window}) cannot be longer than the sequences ({one_hot_sequences.shape[1]})"
+        )
+    for class_name in class_names:
+        try:
+            contribution_scores = adata.varm[class_name]
+
+            if one_hot_sequences.shape != contribution_scores.shape:
+                raise ValueError(
+                    f"Shape mismatch between sequences and scores for class {class_name}"
+                )
+
+            center = one_hot_sequences.shape[1] // 2
+            start, end = _calculate_window_offsets(center, window)
+
+            sequences = one_hot_sequences[:, start:end, :]
+            attributions = contribution_scores[:, start:end, :]
+
+            sequences = sequences.astype("float32")
+            attributions = attributions.astype("float32")
+
+            # Define filenames for the output files
+            output_file = os.path.join(output_dir, f"{class_name}_modisco_results.h5")
+            report_dir = os.path.join(output_dir, f"{class_name}_report")
+
+            # Check if the modisco results .h5 file does not exist for the cell type
+            if not os.path.exists(output_file):
+                logger.info(f"Running modisco for class: {class_name}")
+                pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco(
+                    hypothetical_contribs=attributions,
+                    one_hot=sequences,
+                    max_seqlets_per_metacluster=max_seqlets,
+                    sliding_window_size=20,
+                    flank_size=5,
+                    target_seqlet_fdr=0.05,
+                    n_leiden_runs=n_leiden,
+                    verbose=verbose,
+                )
+
+                modiscolite.io.save_hdf5(output_file, pos_patterns, neg_patterns)
+
+                # Generate the modisco report
+                if report:
+                    modiscolite.report.report_motifs(
+                        output_file,
+                        report_dir,
+                        meme_motif_db=meme_db,
+                        top_n_matches=3,
+                    )
+            else:
+                logger.info(f"Modisco results already exist for class: {class_name}")
+
+        except KeyError as e:
+            logger.error(f"Missing data for class: {class_name}, error: {e}")
diff --git a/src/crested/tl/losses/_cosinemse.py b/src/crested/tl/losses/_cosinemse.py
index c126ea22..c47ec019 100644
--- a/src/crested/tl/losses/_cosinemse.py
+++ b/src/crested/tl/losses/_cosinemse.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
+
 import tensorflow as tf
-from tensorflow.keras.losses import Loss
 
-class CosineMSELoss(Loss):
+
+class CosineMSELoss(tf.keras.losses.Loss):
     """Custom loss function that combines cosine similarity and mean squared error."""
 
     def __init__(self, max_weight=1.0, name="CustomMSELoss", reduction=None):
@@ -31,7 +32,7 @@ def call(self, y_true, y_pred):
         # Calculate cosine similarity loss
         cosine_loss = -tf.reduce_sum(y_true1 * y_pred1, axis=-1)
 
-        total_loss = weight * cosine_loss + mse_loss 
+        total_loss = weight * cosine_loss + mse_loss
 
         return total_loss
 
@@ -44,4 +45,4 @@ def get_config(self):
 
     @classmethod
     def from_config(cls, config):
-        return cls(**config)
\ No newline at end of file
+        return cls(**config)
diff --git a/src/crested/tl/zoo/_basenji.py b/src/crested/tl/zoo/_basenji.py
index b7dc0e87..7efb980e 100644
--- a/src/crested/tl/zoo/_basenji.py
+++ b/src/crested/tl/zoo/_basenji.py
@@ -1,7 +1,6 @@
 """Basenji model architecture."""
 
 import tensorflow as tf
-import tensorflow.keras.layers as layers
 
 from crested.tl.zoo.utils import conv_block_bs, dilated_residual
 
@@ -43,8 +42,7 @@ def basenji(
 
     Returns
     -------
-    tf.keras.Model
-        A TensorFlow Keras model.
+    A TensorFlow Keras model.
""" window_size = int(seq_len // 2) @@ -54,7 +52,7 @@ def basenji( else: pool_1 = 2 - sequence = layers.Input(shape=(seq_len, 4), name="sequence") + sequence = tf.keras.layers.Input(shape=(seq_len, 4), name="sequence") current = conv_block_bs( sequence, @@ -101,9 +99,9 @@ def basenji( bn_momentum=0.9, ) - current = layers.GlobalAveragePooling1D()(current) + current = tf.keras.layers.GlobalAveragePooling1D()(current) - outputs = layers.Dense( + outputs = tf.keras.layers.Dense( units=num_classes, use_bias=True, activation=output_activation, diff --git a/src/crested/tl/zoo/_chrombpnet.py b/src/crested/tl/zoo/_chrombpnet.py index 49e58ded..e3068e6c 100644 --- a/src/crested/tl/zoo/_chrombpnet.py +++ b/src/crested/tl/zoo/_chrombpnet.py @@ -1,8 +1,6 @@ """Chrombp net like model architecture for peak regression.""" import tensorflow as tf -import tensorflow.keras.layers as layers -from tensorflow.keras.backend import int_shape def chrombpnet( @@ -63,14 +61,13 @@ def chrombpnet( Returns ------- - tf.keras.Model - A TensorFlow Keras model. + A TensorFlow Keras model. """ # Model - inputs = layers.Input(shape=(seq_len, 4), name="sequence") + inputs = tf.keras.layers.Input(shape=(seq_len, 4), name="sequence") # Convolutional block without dilation - x = layers.Conv1D( + x = tf.keras.layers.Conv1D( filters=first_conv_filters, kernel_size=first_conv_filter_size, strides=1, @@ -80,18 +77,20 @@ def chrombpnet( kernel_regularizer=tf.keras.regularizers.l2(first_conv_l2), use_bias=False, )(inputs) - x = layers.BatchNormalization(momentum=0.9, gamma_initializer="ones")(x) - x = layers.Activation(first_conv_activation)(x) + x = tf.keras.layers.BatchNormalization(momentum=0.9, gamma_initializer="ones")(x) + x = tf.keras.layers.Activation(first_conv_activation)(x) if first_conv_pool_size > 1: - x = layers.MaxPooling1D(pool_size=first_conv_pool_size, padding="same")(x) - x = layers.Dropout(first_conv_dropout)(x) + x = tf.keras.layers.MaxPooling1D( + pool_size=first_conv_pool_size, padding="same" + )(x) + x = tf.keras.layers.Dropout(first_conv_dropout)(x) # Dilated convolutions layer_names = [str(i) for i in range(1, n_dil_layers + 1)] for i in range(1, n_dil_layers + 1): conv_layer_name = f"bpnet_{layer_names[i - 1]}conv" - conv_x = layers.Conv1D( + conv_x = tf.keras.layers.Conv1D( filters=num_filters, kernel_size=filter_size, strides=1, @@ -104,29 +103,31 @@ def chrombpnet( name=conv_layer_name, )(x) if batch_norm: - conv_x = layers.BatchNormalization( + conv_x = tf.keras.layers.BatchNormalization( momentum=0.9, gamma_initializer="ones", name=f"bpnet_{layer_names[i - 1]}bn", )(conv_x) if activation != "none": - conv_x = layers.Activation( + conv_x = tf.keras.layers.Activation( activation, name=f"bpnet_{layer_names[i - 1]}activation" )(conv_x) - x_len = int_shape(x)[1] - conv_x_len = int_shape(conv_x)[1] + x_len = tf.keras.backend.int_shape(x)[1] + conv_x_len = tf.keras.backend.int_shape(conv_x)[1] assert (x_len - conv_x_len) % 2 == 0 # for symmetric cropping - x = layers.Cropping1D( + x = tf.keras.layers.Cropping1D( (x_len - conv_x_len) // 2, name=f"bpnet_{layer_names[i - 1]}crop" )(x) - x = layers.add([conv_x, x]) + x = tf.keras.layers.add([conv_x, x]) if dropout > 0: - x = layers.Dropout(dropout, name=f"bpnet_{layer_names[i-1]}dropout")(x) + x = tf.keras.layers.Dropout( + dropout, name=f"bpnet_{layer_names[i-1]}dropout" + )(x) - x = layers.GlobalAveragePooling1D()(x) - outputs = layers.Dense( + x = tf.keras.layers.GlobalAveragePooling1D()(x) + outputs = tf.keras.layers.Dense( units=num_classes, 
activation="softplus", use_bias=dense_bias )(x) diff --git a/src/crested/tl/zoo/_deeptopic_cnn.py b/src/crested/tl/zoo/_deeptopic_cnn.py index b7b4f53f..2638dc9c 100644 --- a/src/crested/tl/zoo/_deeptopic_cnn.py +++ b/src/crested/tl/zoo/_deeptopic_cnn.py @@ -1,7 +1,6 @@ """Deeptopic CNN model architecture.""" import tensorflow as tf -import tensorflow.keras.layers as layers from crested.tl.zoo.utils import conv_block, dense_block @@ -59,10 +58,9 @@ def deeptopic_cnn( Returns ------- - tf.keras.Model - A TensorFlow Keras model. + A TensorFlow Keras model. """ - inputs = layers.Input(shape=(seq_len, 4), name="sequence") + inputs = tf.keras.layers.Input(shape=(seq_len, 4), name="sequence") x = conv_block( inputs, @@ -126,8 +124,8 @@ def deeptopic_cnn( batchnorm_momentum=0.9, ) - x = layers.Flatten()(x) - x = layers.Dropout(pre_dense_do)(x) + x = tf.keras.layers.Flatten()(x) + x = tf.keras.layers.Dropout(pre_dense_do)(x) x = dense_block( x, dense_out, @@ -137,6 +135,6 @@ def deeptopic_cnn( name_prefix="denseblock", use_bias=False, ) - logits = layers.Dense(num_classes, activation="linear", use_bias=True)(x) - outputs = layers.Activation("sigmoid")(logits) + logits = tf.keras.layers.Dense(num_classes, activation="linear", use_bias=True)(x) + outputs = tf.keras.layers.Activation("sigmoid")(logits) return tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/src/crested/tl/zoo/_simple_convnet.py b/src/crested/tl/zoo/_simple_convnet.py index 8d0bc27f..9f23b49f 100644 --- a/src/crested/tl/zoo/_simple_convnet.py +++ b/src/crested/tl/zoo/_simple_convnet.py @@ -1,7 +1,6 @@ """Simple convnet model architecture.""" import tensorflow as tf -import tensorflow.keras.layers as layers from crested.tl.zoo.utils import conv_block, dense_block @@ -78,10 +77,9 @@ def simple_convnet( Returns ------- - tf.keras.Model - A TensorFlow Keras model. + A TensorFlow Keras model. """ - inputs = layers.Input(shape=(seq_len, 4), name="sequence") + inputs = tf.keras.layers.Input(shape=(seq_len, 4), name="sequence") x = conv_block( inputs, @@ -108,9 +106,9 @@ def simple_convnet( ) if flatten: - x = layers.Flatten()(x) + x = tf.keras.layers.Flatten()(x) else: - x = layers.GlobalAveragePooling1D()(x) + x = tf.keras.layers.GlobalAveragePooling1D()(x) for _ in range(1, num_dense_blocks): x = dense_block( @@ -129,5 +127,5 @@ def simple_convnet( normalization=normalization, ) - outputs = layers.Dense(num_classes, activation=output_activation)(x) + outputs = tf.keras.layers.Dense(num_classes, activation=output_activation)(x) return tf.keras.Model(inputs=inputs, outputs=outputs) diff --git a/src/crested/tl/zoo/utils/_layers.py b/src/crested/tl/zoo/utils/_layers.py index 5dbd4ca7..28849204 100644 --- a/src/crested/tl/zoo/utils/_layers.py +++ b/src/crested/tl/zoo/utils/_layers.py @@ -4,8 +4,6 @@ import numpy as np import tensorflow as tf -import tensorflow.keras.layers as layers -from tensorflow.keras import regularizers __all__ = [ "dense_block", @@ -56,10 +54,9 @@ def dense_block( Returns ------- - tf.Tensor - The output tensor of the dense block. + The output tensor of the dense block. 
""" - x = layers.Dense( + x = tf.keras.layers.Dense( units, activation=None, use_bias=use_bias, @@ -69,22 +66,22 @@ def dense_block( )(inputs) if normalization == "batch": - x = layers.BatchNormalization( + x = tf.keras.layers.BatchNormalization( momentum=bn_momentum, gamma_initializer=bn_gamma, name=name_prefix + "_batchnorm" if name_prefix else None, )(x) elif normalization == "layer": - x = layers.LayerNormalization( + x = tf.keras.layers.LayerNormalization( name=name_prefix + "_layernorm" if name_prefix else None )(x) - x = layers.Activation( + x = tf.keras.layers.Activation( activation, name=name_prefix + "_activation" if name_prefix else None )(x) - x = layers.Dropout(dropout, name=name_prefix + "_dropout" if name_prefix else None)( - x - ) + x = tf.keras.layers.Dropout( + dropout, name=name_prefix + "_dropout" if name_prefix else None + )(x) return x @@ -134,35 +131,34 @@ def conv_block( Returns ------- - tf.Tensor - The output tensor of the convolution block. + The output tensor of the convolution block. """ if res: residual = inputs - x = layers.Convolution1D( + x = tf.keras.layers.Convolution1D( filters=filters, kernel_size=kernel_size, padding=padding, - kernel_regularizer=regularizers.L2(l2), + kernel_regularizer=tf.keras.regularizers.L2(l2), use_bias=conv_bias, )(inputs) if normalization == "batch": - x = layers.BatchNormalization(momentum=batchnorm_momentum)(x) + x = tf.keras.layers.BatchNormalization(momentum=batchnorm_momentum)(x) elif normalization == "layer": - x = layers.LayerNormalization()(x) - x = layers.Activation(activation)(x) + x = tf.keras.layers.LayerNormalization()(x) + x = tf.keras.layers.Activation(activation)(x) if res: if filters != residual.shape[2]: - residual = layers.Convolution1D(filters=filters, kernel_size=1, strides=1)( - residual - ) - x = layers.Add()([x, residual]) + residual = tf.keras.layers.Convolution1D( + filters=filters, kernel_size=1, strides=1 + )(residual) + x = tf.keras.layers.Add()([x, residual]) if pool_size > 1: - x = layers.MaxPooling1D(pool_size=pool_size, padding=padding)(x) + x = tf.keras.layers.MaxPooling1D(pool_size=pool_size, padding=padding)(x) if dropout > 0: - x = layers.Dropout(dropout)(x) + x = tf.keras.layers.Dropout(dropout)(x) return x @@ -182,8 +178,7 @@ def activate(current: tf.Tensor, activation: str, verbose: bool = False) -> tf.T Returns ------- - tf.Tensor - Output tensor after applying activation. + Output tensor after applying activation. """ if verbose: print("activate:", activation) @@ -269,8 +264,7 @@ def conv_block_bs( Returns ------- - tf.Tensor - Output tensor after applying the convolution block. + Output tensor after applying the convolution block. """ current = inputs @@ -373,8 +367,7 @@ def dilated_residual( Returns ------- - tf.Tensor - Output tensor after applying the dilated residual block. + Output tensor after applying the dilated residual block. 
""" current = inputs diff --git a/tests/test_io.py b/tests/test_io.py index 7e4fbf71..d67a625c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -11,9 +11,9 @@ def test_package_has_version(): assert crested.__version__ is not None -def test_import_topics_shape(): - ann_data = crested.import_topics( - topics_folder="tests/data/test_topics", +def test_import_beds_shape(): + ann_data = crested.import_beds( + beds_folder="tests/data/test_topics", regions_file="tests/data/test.regions.bed", ) # Test type @@ -28,36 +28,34 @@ def test_import_topics_shape(): # Test columns assert "file_path" in ann_data.obs.columns assert "n_open_regions" in ann_data.obs.columns - assert "n_topics" in ann_data.var.columns + assert "n_classes" in ann_data.var.columns -def test_import_topics_topics_subset(): - ann_data = crested.import_topics( - topics_folder="tests/data/test_topics", +def test_import_beds_classes_subset(): + ann_data = crested.import_beds( + beds_folder="tests/data/test_topics", regions_file="tests/data/test.regions.bed", - topics_subset=["Topic_1", "Topic_2"], + classes_subset=["Topic_1", "Topic_2"], ) assert ann_data.shape[0] == 2 -def test_import_topics_invalid_files(): +def test_import_beds_invalid_files(): with pytest.raises(FileNotFoundError): - crested.import_topics( - topics_folder="invalid_folder", regions_file="invalid_file" - ) + crested.import_beds(beds_folder="invalid_folder", regions_file="invalid_file") -def test_import_topics_compression(): - ann_data_c = crested.import_topics( - topics_folder="tests/data/test_topics", +def test_import_beds_compression(): + ann_data_c = crested.import_beds( + beds_folder="tests/data/test_topics", regions_file="tests/data/test.regions.bed", compress=True, ) assert ann_data_c.X.getformat() == "csr" assert ann_data_c.X.shape == (3, 23186) - ann_data = crested.import_topics( - topics_folder="tests/data/test_topics", + ann_data = crested.import_beds( + beds_folder="tests/data/test_topics", regions_file="tests/data/test.regions.bed", compress=False, ) @@ -67,9 +65,9 @@ def test_import_topics_compression(): assert sys.getsizeof(ann_data_c.X) < sys.getsizeof(ann_data.X) -def test_import_topics_chromsizes(): - ann_data = crested.import_topics( - topics_folder="tests/data/test_topics", +def test_import_beds_chromsizes(): + ann_data = crested.import_beds( + beds_folder="tests/data/test_topics", regions_file="tests/data/test.regions.bed", chromsizes_file="tests/data/test.chrom.sizes", compress=True,