Commit
added ism plot option, fixed contrib score plotting, started on spearman (still bugging)
nkempynck committed Jun 26, 2024
1 parent 33e5870 commit f4288eb
Showing 11 changed files with 1,001 additions and 878 deletions.
1,698 changes: 842 additions & 856 deletions docs/tutorials/mouse_biccn.ipynb

Large diffs are not rendered by default.

35 changes: 27 additions & 8 deletions src/crested/pl/patterns/_contribution_scores.py
@@ -9,7 +9,7 @@
from crested._logging import log_and_raise
from crested.pl._utils import render_plot

from ._utils import _plot_attribution_map, grad_times_input_to_df
from ._utils import _plot_attribution_map, _plot_mutagenesis_map, grad_times_input_to_df, grad_times_input_to_df_mutagenesis


@log_and_raise(ValueError)
@@ -37,6 +37,7 @@ def contribution_scores(
zoom_n_bases: int | None = None,
highlight_positions: list[tuple[int, int]] | None = None,
ylim: tuple | None = None,
method: str | None = None,
**kwargs,
):
"""
@@ -58,6 +59,8 @@
List of tuples with start and end positions to highlight. Default is None.
ylim
Y-axis limits. Default is None.
method
Method used to calculate the contribution scores. Set to "mutagenesis" to plot in silico mutagenesis scores as dot maps instead of letter-based attribution maps. Default is None.

Examples
--------
@@ -80,20 +83,36 @@
start_idx = center - int(zoom_n_bases / 2)
scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]

global_min = scores.min()
global_max = scores.max()

# Plot
logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)")
for seq in range(seqs_one_hot.shape[0]):
fig_height_per_class = 2
fig = plt.figure(figsize=(50, fig_height_per_class * scores.shape[1]))
seq_class_x = seqs_one_hot[seq, start_idx : start_idx + zoom_n_bases, :]

if method == 'mutagenesis':
# Per-sequence y-limits, padded by 25%
global_max = scores[seq].max() + 0.25 * np.abs(scores[seq].max())
global_min = scores[seq].min() - 0.25 * np.abs(scores[seq].min())
else:
# Shared y-limits across classes based on gradient x input, padded by 25%
mins = []
maxs = []
for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
mins.append(np.min(seq_class_scores * seq_class_x))
maxs.append(np.max(seq_class_scores * seq_class_x))
global_max = np.array(maxs).max() + 0.25 * np.abs(np.array(maxs).max())
global_min = np.array(mins).min() - 0.25 * np.abs(np.array(mins).min())

for i in range(scores.shape[1]):
seq_class_scores = scores[seq, i, :, :]
seq_class_x = seqs_one_hot[seq, :, :]
intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
ax = plt.subplot(scores.shape[1], 1, i + 1)
_plot_attribution_map(intgrad_df, ax=ax, return_ax=False)
if method == 'mutagenesis':
mutagenesis_df = grad_times_input_to_df_mutagenesis(seq_class_x, seq_class_scores)
_plot_mutagenesis_map(mutagenesis_df, ax=ax)
else:
intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
_plot_attribution_map(intgrad_df, ax=ax, return_ax=False)
if labels:
class_name = labels[i]
else:
@@ -102,11 +121,11 @@
if ylim:
ax.set_ylim(ylim[0], ylim[1])
x_pos = 5
y_pos = 0.75 * ylim[1]
y_pos = 0.5 * ylim[1]
else:
ax.set_ylim([global_min, global_max])
x_pos = 5
y_pos = 0.75 * global_max
y_pos = 0.5 * global_max
ax.text(x_pos, y_pos, text_to_add, fontsize=16, ha="left", va="center")

# Draw rectangles to highlight positions
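For reference, a usage sketch of the new method option (the module path and the leading scores/seqs_one_hot arguments are assumptions inferred from the file layout and the signature fragments above):

import crested

# scores, seqs_one_hot: outputs of a contribution-score calculation; class_names: list of labels
crested.pl.patterns.contribution_scores(scores, seqs_one_hot, labels=class_names)

# New: dot-based map for in silico mutagenesis (ISM) scores
crested.pl.patterns.contribution_scores(
    scores, seqs_one_hot, labels=class_names, method="mutagenesis"
)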
2 changes: 1 addition & 1 deletion src/crested/pl/patterns/_modisco_results.py
@@ -45,7 +45,7 @@ def _trim_pattern_by_ic(
contrib_scores = np.array(pattern["contrib_scores"])
if not pos_pattern:
contrib_scores = -contrib_scores
contrib_scores[contrib_scores < 0] = 0
contrib_scores[contrib_scores < 0] = 1e-9 # avoid division by zero

ic = modisco.util.compute_per_position_ic(
ppm=np.array(contrib_scores), background=background, pseudocount=pseudocount
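For context, a minimal numpy sketch of the failure mode the 1e-9 floor guards against: a pattern whose contributions are all negative previously became an all-zero matrix, and normalizing a zero row by its sum inside the IC computation then yields NaN (0/0).

import numpy as np

row = np.array([0.0, 0.0, 0.0, 0.0])  # a position whose contributions were all clipped to 0
print(row / row.sum())  # -> [nan nan nan nan]: 0/0 during normalization

row_safe = np.where(row < 1e-9, 1e-9, row)  # the floor applied in the diff above
print(row_safe / row_safe.sum())  # -> [0.25 0.25 0.25 0.25]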
25 changes: 25 additions & 0 deletions src/crested/pl/patterns/_utils.py
@@ -71,3 +71,28 @@ def _plot_attribution_map(
ax.spines["top"].set_visible(False)
if return_ax:
return ax

def _plot_mutagenesis_map(mutagenesis_df, ax=None):
"""Plot an attribution map for mutagenesis using different colored dots, with adjusted x-axis limits."""
colors = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'}
if ax is None:
ax = plt.gca()

# Add horizontal line at y=0
ax.axhline(0, color='gray', linewidth=1, linestyle='--')

# Scatter plot for each nucleotide type
for nuc, color in colors.items():
# Filter out dots where the variant is the same as the original nucleotide
subset = mutagenesis_df[(mutagenesis_df['Nucleotide'] == nuc) & (mutagenesis_df['Nucleotide'] != mutagenesis_df['Original'])]
ax.scatter(subset['Position'], subset['Effect'], color=color, label=nuc, s=10) # s is the size of the dot

# Set the limits of the x-axis to match exactly the first and last position
if not mutagenesis_df.empty:
ax.set_xlim(mutagenesis_df['Position'].min() - 0.5, mutagenesis_df['Position'].max() + 0.5)

ax.legend(title="Nucleotide", loc='upper right')
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.xaxis.set_ticks_position("none")
ax.set_xticks([])  # hide x-axis ticks for a cleaner look; use the passed-in ax rather than pyplot state
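A minimal usage sketch for the new helper (the column names 'Position', 'Nucleotide', 'Original', and 'Effect' are the ones the function indexes; the values below are made up):

import matplotlib.pyplot as plt
import pandas as pd

# Toy mutagenesis table: effect of substituting each variant base at each position
df = pd.DataFrame({
    "Position": [0, 0, 0, 1, 1, 1],
    "Nucleotide": ["C", "G", "T", "A", "G", "T"],
    "Original": ["A", "A", "A", "C", "C", "C"],
    "Effect": [-0.4, 0.1, 0.05, 0.3, -0.2, 0.0],
})

fig, ax = plt.subplots(figsize=(6, 2))
_plot_mutagenesis_map(df, ax=ax)
plt.show()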
17 changes: 13 additions & 4 deletions src/crested/tl/_configs.py
@@ -13,6 +13,7 @@
PearsonCorrelation,
PearsonCorrelationLog,
ZeroPenaltyMetric,
SpearmanCorrelationPerClass,
)


@@ -74,6 +75,9 @@ def metrics(self) -> list[tf.keras.metrics.Metric]:
class PeakRegressionConfig(BaseConfig):
"""Default configuration for peak regression task."""

def __init__(self, num_classes=None):
self.num_classes = num_classes

@property
def loss(self) -> tf.keras.losses.Loss:
return CosineMSELoss()
@@ -84,15 +88,18 @@ def optimizer(self) -> tf.keras.optimizers.Optimizer:

@property
def metrics(self) -> list[tf.keras.metrics.Metric]:
return [
metrics = [
tf.keras.metrics.MeanAbsoluteError(),
tf.keras.metrics.MeanSquaredError(),
tf.keras.metrics.CosineSimilarity(axis=1),
PearsonCorrelation(),
ConcordanceCorrelationCoefficient(),
PearsonCorrelationLog(),
ZeroPenaltyMetric(),
ZeroPenaltyMetric()
]
#if self.num_classes is not None:
# metrics.append(SpearmanCorrelationPerClass(num_classes=self.num_classes))
return metrics


class TaskConfig(NamedTuple):
@@ -156,7 +163,7 @@ def to_dict(self) -> dict:


def default_configs(
task: str,
task: str, num_classes: int | None = None
) -> TaskConfig:
"""
Get default loss, optimizer, and metrics for an existing task.
@@ -177,6 +184,8 @@
----------
task
Task for which to get default components.
num_classes
Number of output classes of the model. Used to build the per-class Spearman correlation metric for the peak_regression task. Default is None.

Returns
-------
@@ -196,7 +205,7 @@
f"Task '{task}' not supported. Only {list(task_classes.keys())} are supported."
)

task_class = task_classes[task]()
task_class = task_classes[task](num_classes=num_classes) if task == 'peak_regression' else task_classes[task]()
loss = task_class.loss
optimizer = task_class.optimizer
metrics = task_class.metrics
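A usage sketch for the extended signature (the TaskConfig fields follow the loss/optimizer/metrics attributes above; the import path and class count are assumptions):

from crested.tl import default_configs

configs = default_configs("peak_regression", num_classes=19)
print(configs.loss)  # CosineMSELoss
print([m.name for m in configs.metrics])

Note that the Spearman metric itself is still commented out in metrics above, so passing num_classes currently only stores the value.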
29 changes: 29 additions & 0 deletions src/crested/tl/_crested.py
@@ -430,6 +430,31 @@ def predict_regions(

return np.concatenate(all_predictions, axis=0)

def predict_sequence(
self,
sequence: str) -> np.ndarray:
"""
Make predictions with the model on the provided DNA sequence.

Parameters
----------
sequence : str
A string containing a DNA sequence (A, C, G, T).

Returns
-------
np.ndarray
Predictions for the provided sequence.
"""
# One-hot encode the sequence
x = one_hot_encode_sequence(sequence)

# Make prediction
predictions = self.model.predict(x)

return predictions

def calculate_contribution_scores(
self,
anndata: AnnData | None = None,
@@ -584,12 +609,16 @@ def calculate_contribution_scores_regions(
if isinstance(region_idx, str):
region_idx = [region_idx]

if isinstance(class_names, str):
class_names = [class_names]

all_scores = []
all_one_hot_sequences = []

all_class_names = list(self.anndatamodule.adata.obs_names)

if class_names is not None:
n_classes = len(class_names)
class_indices = [
all_class_names.index(class_name) for class_name in class_names
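A quick usage sketch for the new predict_sequence method (crested_model stands for a trained crested.tl.Crested instance set up elsewhere):

sequence = "ACGT" * 500  # toy sequence; use your model's expected input length
preds = crested_model.predict_sequence(sequence)
print(preds.shape)  # per-class predictions, assuming one_hot_encode_sequence adds a batch dimension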
2 changes: 1 addition & 1 deletion src/crested/tl/data/__init__.py
@@ -1,3 +1,3 @@
from ._anndatamodule import AnnDataModule
from ._dataloader import AnnDataLoader
from ._dataset import AnnDataset
from ._dataset import AnnDataset, SequenceLoader
7 changes: 3 additions & 4 deletions src/crested/tl/data/_dataset.py
@@ -27,9 +27,9 @@ def __init__(
self,
genome_file: PathLike,
chromsizes: dict | None,
in_memory: bool,
always_reverse_complement: bool,
max_stochastic_shift: int,
in_memory: bool = False,
always_reverse_complement: bool = False,
max_stochastic_shift: int = 0,
regions: list[str] = None,
):
self.genome = FastaFile(genome_file)
@@ -82,7 +82,6 @@ def get_sequence(self, region: str, strand: str = "+", shift: int = 0) -> str:
sequence = self.sequences[key]
else:
sequence = self._get_extended_sequence(region)

chrom, start_end = region.split(":")
start, end = map(int, start_end.split("-"))
start_idx = self.max_stochastic_shift + shift
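With the new defaults, a SequenceLoader needs only a genome file, chromsizes, and optionally regions; a sketch (the FASTA path and region are illustrative):

from crested.tl.data import SequenceLoader  # now re-exported via data/__init__.py

loader = SequenceLoader(
    genome_file="genome.fa",  # indexed FASTA, illustrative path
    chromsizes=None,
    regions=["chr1:1000-3000"],
)
seq = loader.get_sequence("chr1:1000-3000")  # forward strand, no shift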
8 changes: 4 additions & 4 deletions src/crested/tl/losses/_cosinemse.py
@@ -6,10 +6,10 @@
class CosineMSELoss(tf.keras.losses.Loss):
"""Custom loss function that combines cosine similarity and mean squared error."""

def __init__(self, max_weight=1.0, name="CustomMSELoss", reduction=None):
def __init__(self, max_weight=1.0, name="CustomMSELoss"):
super().__init__(name=name)
self.max_weight = max_weight
self.reduction=reduction
#self.reduction=reduction

@tf.function
def call(self, y_true, y_pred):
@@ -39,8 +39,8 @@ def call(self, y_true, y_pred):
def get_config(self):
config = super().get_config()
config.update({
"max_weight": self.max_weight,#})
"reduction":self.reduction})
"max_weight": self.max_weight})
#"reduction":self.reduction})
return config

@classmethod
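A compile-time usage sketch (import path assumed from the module layout); with the reduction pass-through commented out, the loss falls back to the Keras default reduction:

import tensorflow as tf
from crested.tl.losses import CosineMSELoss

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=CosineMSELoss(max_weight=1.0),
)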
1 change: 1 addition & 0 deletions src/crested/tl/metrics/__init__.py
@@ -2,3 +2,4 @@
from ._pearsoncorr import PearsonCorrelation
from ._pearsoncorrlog import PearsonCorrelationLog
from ._zeropenalty import ZeroPenaltyMetric
from ._spearmancorr import SpearmanCorrelationPerClass
55 changes: 55 additions & 0 deletions src/crested/tl/metrics/_spearmancorr.py
@@ -0,0 +1,55 @@
"""Spearman correlation metric."""

from __future__ import annotations
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="Metrics")
class SpearmanCorrelationPerClass(tf.keras.metrics.Metric):
"""Spearman correlation computed per class over nonzero targets, then averaged."""

def __init__(self, num_classes, name='spearman_correlation_per_class', **kwargs):
super().__init__(name=name, **kwargs)
self.num_classes = num_classes
self.correlation_sums = self.add_weight(name='correlation_sums', shape=(num_classes,), initializer='zeros')
self.update_counts = self.add_weight(name='update_counts', shape=(num_classes,), initializer='zeros')

def update_state(self, y_true, y_pred, sample_weight=None):
for i in range(self.num_classes):
y_true_class = tf.cast(y_true[:, i], tf.float32)
y_pred_class = tf.cast(y_pred[:, i], tf.float32)

# tf.where returns indices of shape (n, 1); take column 0 so the gathered tensors stay 1-D
non_zero_indices = tf.where(tf.not_equal(y_true_class, 0))[:, 0]
y_true_non_zero = tf.gather(y_true_class, non_zero_indices)
y_pred_non_zero = tf.gather(y_pred_class, non_zero_indices)

# Ensure sizes are constant by checking them before the operation
num_elements = tf.size(y_true_non_zero)
proceed = num_elements > 1

def compute():
return self.compute_correlation(y_true_non_zero, y_pred_non_zero)

def skip():
return 0.0

correlation = tf.cond(proceed, compute, skip)
self.correlation_sums[i].assign_add(correlation)
self.update_counts[i].assign_add(tf.cast(proceed, tf.float32))

def compute_correlation(self, y_true_non_zero, y_pred_non_zero):
ranks_true = tf.argsort(tf.argsort(y_true_non_zero))
ranks_pred = tf.argsort(tf.argsort(y_pred_non_zero))

rank_diffs = tf.cast(ranks_true, tf.float32) - tf.cast(ranks_pred, tf.float32)
rank_diffs_squared_sum = tf.reduce_sum(tf.square(rank_diffs))
n = tf.cast(tf.size(y_true_non_zero), tf.float32)

correlation = 1 - (6 * rank_diffs_squared_sum) / (n * (n*n - 1))
return tf.where(tf.math.is_nan(correlation), 0.0, correlation)

def result(self):
# divide_no_nan avoids NaN for classes that never received an update (they contribute 0)
avg_correlations = tf.math.divide_no_nan(self.correlation_sums, self.update_counts)
return tf.reduce_mean(avg_correlations)

def reset_state(self):
self.correlation_sums.assign(tf.zeros_like(self.correlation_sums))
self.update_counts.assign(tf.zeros_like(self.update_counts))

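An eager-mode sanity check for the metric against scipy (agreement holds only without ties, since the closed-form rank-difference formula above assumes untied ranks; y_true contains no zeros, so nothing is filtered out):

import numpy as np
import tensorflow as tf
from scipy.stats import spearmanr
from crested.tl.metrics import SpearmanCorrelationPerClass  # exported in __init__ above

y_true = np.array([[1.0, 2.0], [3.0, 1.0], [2.0, 5.0], [4.0, 3.0]], np.float32)
y_pred = np.array([[1.1, 2.2], [2.4, 0.7], [2.9, 4.0], [3.6, 3.1]], np.float32)

metric = SpearmanCorrelationPerClass(num_classes=2)
metric.update_state(tf.constant(y_true), tf.constant(y_pred))
print(metric.result().numpy())  # ~0.9

# Reference: mean of per-class Spearman correlations
print(np.mean([spearmanr(y_true[:, i], y_pred[:, i])[0] for i in range(2)]))  # 0.9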