diff --git a/Snakefile b/Snakefile index 1b978e1..99fa5cf 100644 --- a/Snakefile +++ b/Snakefile @@ -162,9 +162,9 @@ rule predict: ## This rule outputs correlation scores per filtered data group rule correlations: input: - PREDICTIONS_PATH, + ckpt_path=CKPT_PATH, output: - CORRELATION_FILES, + os.path.join(CORRELATION_DIR, "correlations_complete.log") params: command=TRAIN_COMMAND.lstrip('fit'), data=CORRELATION_DATA, @@ -172,7 +172,21 @@ rule correlations: shell: """ mkdir -p {CORRELATION_DIR} - cmmvae workflow correlations {params.command} --correlation_data {params.data} --save_dir {params.save_dir} + cmmvae workflow correlations {params.command} --ckpt_path {input.ckpt_path} --correlation_data {params.data} --save_dir {params.save_dir} + touch {output} + """ + +rule run_correlations: + input: + rules.correlations.output + output: + CORRELATION_FILES, + params: + directory=CORRELATION_DIR, + shell: + """ + mkdir -p {CORRELATION_DIR} + cmmvae workflow run-correlations --directory {params.directory} """ ## Define the rule for generating UMAP visualizations from the merged predictions. diff --git a/src/cmmvae/runners/correlations.py b/src/cmmvae/runners/correlations.py index 8febbf6..41b1b68 100644 --- a/src/cmmvae/runners/correlations.py +++ b/src/cmmvae/runners/correlations.py @@ -17,14 +17,12 @@ from cmmvae.runners.cli import CMMVAECli from cmmvae.data.local import SpeciesDataPipe -FILTERED_BY_CATEGORIES = ["assay", "cell_type", "tissue", "sex"] - -def setup_datapipes(data_dir: str): +def setup_datapipes(data_dir: str, human_masks: str, mouse_masks: str): human_pipe = SpeciesDataPipe( directory_path=data_dir, - npz_masks="human*.npz", - metadata_masks="human*.pkl", + npz_masks=[f"{human_mask}.npz" for human_mask in human_masks], + metadata_masks=[f"{human_mask}.pkl" for human_mask in human_masks], batch_size=10000, allow_partials=True, shuffle=False, @@ -33,8 +31,8 @@ def setup_datapipes(data_dir: str): ) mouse_pipe = SpeciesDataPipe( directory_path=data_dir, - npz_masks="mouse*.npz", - metadata_masks="mouse*.pkl", + npz_masks=[f"{mouse_mask}.npz" for mouse_mask in mouse_masks], + metadata_masks=[f"{mouse_mask}.pkl" for mouse_mask in mouse_masks], batch_size=10000, allow_partials=True, shuffle=False, @@ -45,7 +43,15 @@ def setup_datapipes(data_dir: str): def setup_dataloaders(data_dir: str): - human_pipe, mouse_pipe = setup_datapipes(data_dir) + + gids = np.random.choice( + np.arange(1, 252), size=25, replace=False + ) + + human_masks = [f"human*_{gid}" for gid in gids] + mouse_masks = [f"mouse*_{gid}" for gid in gids] + + human_pipe, mouse_pipe = setup_datapipes(data_dir, human_masks, mouse_masks) human_dataloader = DataLoader( dataset=human_pipe, @@ -53,7 +59,7 @@ def setup_dataloaders(data_dir: str): shuffle=False, collate_fn=lambda x: x, persistent_workers=False, - num_workers=6, + num_workers=2, ) mouse_dataloader = DataLoader( dataset=mouse_pipe, @@ -61,7 +67,7 @@ def setup_dataloaders(data_dir: str): shuffle=False, collate_fn=lambda x: x, persistent_workers=False, - num_workers=6, + num_workers=2, ) return human_dataloader, mouse_dataloader @@ -78,69 +84,15 @@ def convert_to_tensor(batch: sp.csr_matrix): else torch.device("cpu"), ) - -def calc_correlations(human_out: np.ndarray, mouse_out: np.ndarray, n_samples: int): - with np.errstate(divide="ignore", invalid="ignore"): - human_correlations = np.corrcoef(human_out) - mouse_correlations = np.corrcoef(mouse_out) - - human_cis = np.round( - np.nan_to_num(human_correlations[:n_samples, :n_samples]).mean(), 3 - ) - human_cross = np.round( - np.nan_to_num(human_correlations[n_samples:, n_samples:]).mean(), 3 - ) - human_comb = np.round( - np.nan_to_num(human_correlations[:n_samples, n_samples:]).mean(), 3 - ) - - mouse_cis = np.round( - np.nan_to_num(mouse_correlations[:n_samples, :n_samples]).mean(), 3 - ) - mouse_cross = np.round( - np.nan_to_num(mouse_correlations[n_samples:, n_samples:]).mean(), 3 - ) - mouse_comb = np.round( - np.nan_to_num(mouse_correlations[:n_samples, n_samples:]).mean(), 3 - ) - - return pd.DataFrame( - { - "human_cis": [human_cis], - "human_cross": [human_cross], - "human_comb": [human_comb], - "mouse_cis": [mouse_cis], - "mouse_cross": [mouse_cross], - "mouse_comb": [mouse_comb], - } - ) - - -def get_correlations(model: CMMVAEModel, data_dir: str): +def get_correlations(model: CMMVAEModel, data_dir: str, save_dir: str): human_dataloader, mouse_dataloader = setup_dataloaders(data_dir) - correlations = pd.DataFrame( - columns=[ - "group_id", - "num_samples", - "human_cis", - "human_cross", - "human_comb", - "mouse_cis", - "mouse_cross", - "mouse_comb", - "tag", - ] - ) - for (human_batch, human_metadata), (mouse_batch, mouse_metadata) in zip( human_dataloader, mouse_dataloader ): human_batch = human_batch.cuda() mouse_batch = mouse_batch.cuda() - n_samples = human_metadata["num_samples"].iloc[0] - model.module.eval() with torch.no_grad(): _, _, _, human_out, _ = model.module( @@ -149,35 +101,35 @@ def get_correlations(model: CMMVAEModel, data_dir: str): _, _, _, mouse_out, _ = model.module( mouse_batch, mouse_metadata, RK.MOUSE, cross_generate=True ) - - human_stacked_out = np.vstack( - (human_out[RK.HUMAN].cpu().numpy(), mouse_out[RK.HUMAN].cpu().numpy()) + save_data = convert_to_csr(human_out, mouse_out) + metadata = {RK.HUMAN: human_metadata, RK.MOUSE: mouse_metadata} + save_correlations(save_data, metadata, save_dir, gid=human_metadata["group_id"].iloc[0]) + +def convert_to_csr( + human_xhats: dict[str: torch.Tensor], + mouse_xhats: dict[str: torch.Tensor], +): + converted = {} + for output_species, xhat in human_xhats.items(): + nparray = xhat.cpu().numpy() + converted[f"human_to_{output_species}"] = sp.csr_matrix(nparray) + for output_species, xhat in mouse_xhats.items(): + nparray = xhat.cpu().numpy() + converted[f"mouse_to_{output_species}"] = sp.csr_matrix(nparray) + return converted + + +def save_correlations(data: dict[str, sp.csr_matrix], metadata: pd.DataFrame, save_dir: str, gid: int): + for file_name, data in data.items(): + sp.save_npz( + os.path.join(save_dir, f"{file_name}_{gid}.npz"), + data ) - mouse_stacked_out = np.vstack( - (mouse_out[RK.MOUSE].cpu().numpy(), human_out[RK.MOUSE].cpu().numpy()) - ) - - avg_correlations = calc_correlations( - human_stacked_out, mouse_stacked_out, n_samples + for species, md in metadata.items(): + md.to_pickle( + os.path.join(save_dir, f"{species}_metadata_{gid}.pkl") ) - avg_correlations["group_id"] = human_metadata["group_id"].iloc[0] - avg_correlations["num_samples"] = n_samples - avg_correlations["tag"] = " ".join( - [human_metadata[cat].iloc[0] for cat in FILTERED_BY_CATEGORIES] - ) - - correlations = pd.concat([correlations, avg_correlations], ignore_index=True) - - return correlations - - -def save_correlations(correlations: pd.DataFrame, save_dir: str): - correlations = correlations.sort_values("group_id") - correlations.to_csv(os.path.join(save_dir, "correlations.csv"), index=False) - correlations.to_pickle(os.path.join(save_dir, "correlations.pkl")) - - @click.command( context_settings=dict( ignore_unknown_options=True, @@ -207,8 +159,7 @@ def correlations(ctx: click.Context, correlation_data: str, save_dir: str): model = type(cli.model).load_from_checkpoint( cli.config["ckpt_path"], module=cli.model.module ) - correlations = get_correlations(model, correlation_data) - save_correlations(correlations, save_dir) + get_correlations(model, correlation_data, save_dir) if __name__ == "__main__": diff --git a/src/cmmvae/runners/run_correlations.py b/src/cmmvae/runners/run_correlations.py new file mode 100644 index 0000000..b354ffa --- /dev/null +++ b/src/cmmvae/runners/run_correlations.py @@ -0,0 +1,163 @@ +import os +import re +import sys +import click +import torch +import glob + +import numpy as np +import pandas as pd +import scipy.sparse as sp + +from collections import defaultdict +from cmmvae.constants import REGISTRY_KEYS as RK + +FILTERED_BY_CATEGORIES = ["assay", "cell_type", "tissue", "sex"] + +def calc_correlations(human_out: np.ndarray, mouse_out: np.ndarray, n_samples: int): + with np.errstate(divide="ignore", invalid="ignore"): + human_correlations = np.corrcoef(human_out) + mouse_correlations = np.corrcoef(mouse_out) + + human_cis = np.round( + np.nan_to_num(human_correlations[:n_samples, :n_samples]).mean(), 3 + ) + human_cross = np.round( + np.nan_to_num(human_correlations[n_samples:, n_samples:]).mean(), 3 + ) + human_comb = np.round( + np.nan_to_num(human_correlations[:n_samples, n_samples:]).mean(), 3 + ) + human_rel = np.round( + (2 * human_comb) / (human_cis + human_cross), 3 + ) + + mouse_cis = np.round( + np.nan_to_num(mouse_correlations[:n_samples, :n_samples]).mean(), 3 + ) + mouse_cross = np.round( + np.nan_to_num(mouse_correlations[n_samples:, n_samples:]).mean(), 3 + ) + mouse_comb = np.round( + np.nan_to_num(mouse_correlations[:n_samples, n_samples:]).mean(), 3 + ) + mouse_rel = np.round( + (2 * mouse_comb) / (mouse_cis + mouse_cross), 3 + ) + + return pd.DataFrame( + { + "human_cis": [human_cis], + "human_cross": [human_cross], + "human_comb": [human_comb], + "human_rel": [human_rel], + "mouse_cis": [mouse_cis], + "mouse_cross": [mouse_cross], + "mouse_comb": [mouse_comb], + "mouse_rel": [mouse_rel] + } + ) + +def save_correlations(correlations: pd.DataFrame, save_dir: str): + correlations = correlations.sort_values("group_id") + correlations.to_csv(os.path.join(save_dir, "correlations.csv"), index=False) + correlations.to_pickle(os.path.join(save_dir, "correlations.pkl")) + +def get_correlations( + data_files: dict[str: dict[str: sp.csr_matrix]], + metadata_files: dict[str: dict[str: pd.DataFrame]] +): + + correlations = pd.DataFrame( + columns=[ + "group_id", + "num_samples", + "human_cis", + "human_cross", + "human_comb", + "human_rel", + "mouse_cis", + "mouse_cross", + "mouse_comb", + "mouse_rel", + "tag", + ] + ) + # print(data_files) + # print(metadata_files) + for gid, data in data_files.items(): + # print(gid) + # print(metadata_files[gid]) + n_samples = metadata_files[gid][RK.HUMAN]["num_samples"].iloc[0] + human_stacked_out = np.vstack( + ( + data[f"{RK.HUMAN}_to_{RK.HUMAN}"].toarray(), + data[f"{RK.MOUSE}_to_{RK.HUMAN}"].toarray() + ) + ) + mouse_stacked_out = np.vstack( + ( + data[f"{RK.MOUSE}_to_{RK.MOUSE}"].toarray(), + data[f"{RK.HUMAN}_to_{RK.MOUSE}"].toarray() + ) + ) + + avg_correlations = calc_correlations( + human_stacked_out, mouse_stacked_out, n_samples + ) + + avg_correlations["group_id"] = gid + avg_correlations["num_samples"] = n_samples + avg_correlations["tag"] = " ".join( + [metadata_files[gid][RK.HUMAN][cat].iloc[0] for cat in FILTERED_BY_CATEGORIES] + ) + + correlations = pd.concat([correlations, avg_correlations], ignore_index=True) + + return correlations + +def correlations(directory: str): + data_files = defaultdict(dict) + files = glob.glob(os.path.join(directory, "*.npz")) + file_pattern = re.compile(r"(.*_to_.*)_(\d+).npz") + for file in files: + # print(file) + match = file_pattern.match(file) + if match: + label = match.group(1) + gid = int(match.group(2)) + label = os.path.basename(label) + # print(f"label: {label}, GID: {gid}") + data = sp.load_npz(file) + data_files[gid][label] = data + + metadata_files = defaultdict(dict) + files = glob.glob(os.path.join(directory, "*.pkl")) + file_pattern = re.compile(r"(.*)_metadata_(\d+).pkl") + for file in files: + # print(file) + match = file_pattern.match(file) + if match: + species = match.group(1) + gid = int(match.group(2)) + species = os.path.basename(species) + # print(f"Species: {species}, GID: {gid}") + data = pd.read_pickle(file) + metadata_files[gid][species] = data + + correlations = get_correlations(data_files, metadata_files) + save_correlations(correlations, directory) + +@click.command() +@click.option( + "--directory", + type=click.Path(exists=True), + required=True, + default=lambda: os.environ.get("DIRECTORY", ""), + help="Directory where outputs are saved.", +) +def run_correlations(**kwargs): + correlations(**kwargs) + +if __name__ == "__main__": + run_correlations() \ No newline at end of file diff --git a/src/cmmvae/runners/workflow.py b/src/cmmvae/runners/workflow.py index 70d2f3b..d70d58b 100644 --- a/src/cmmvae/runners/workflow.py +++ b/src/cmmvae/runners/workflow.py @@ -3,6 +3,7 @@ from cmmvae.runners.expression import expression from cmmvae.runners.cli import cli from cmmvae.runners.correlations import correlations +from cmmvae.runners.run_correlations import run_correlations from cmmvae.runners.umap_predictions import umap_predictions from cmmvae.runners.merge_predictions import merge_predictions from cmmvae.runners.meta_discriminators import meta_discriminator @@ -16,6 +17,7 @@ def workflow(): workflow.add_command(expression) workflow.add_command(cli) workflow.add_command(correlations) +workflow.add_command(run_correlations) workflow.add_command(umap_predictions) workflow.add_command(merge_predictions) workflow.add_command(meta_discriminator) diff --git a/workflow/profile/slurm/config.yaml b/workflow/profile/slurm/config.yaml index 66f253e..c8193fb 100644 --- a/workflow/profile/slurm/config.yaml +++ b/workflow/profile/slurm/config.yaml @@ -47,8 +47,13 @@ set-resources: correlations: partition: gpu mem: 179GB - gpus_per_node: tesla_v100s:1 + gpus_per_node: 1 cpus_per_task: 12 + run_correlations: + partition: cpu + mem: 179GB + gpus_per_node: "" + cpus_per_task: 1 umap_predictions: partition: all mem: 179GB