diff --git a/Snakefile b/Snakefile
index 93164e3..1419a87 100644
--- a/Snakefile
+++ b/Snakefile
@@ -57,6 +57,12 @@ META_DISC_DIR = os.path.join(RUN_DIR, META_DISC_PATH)
 ## Define a separate directory for merged outputs to avoid conflicts between different merge operations.
 MERGED_DIR = os.path.join(RUN_DIR, "merged")
 
+## Define the directory to store correlation outputs
+CORRELATION_PATH = config.get("correlation_dir", "correlations")
+CORRELATION_DIR = os.path.join(RUN_DIR, CORRELATION_PATH)
+
+CORRELATION_DATA = config["correlation_data"]
+
 ## Generate the paths for the embeddings and metadata files based on the merge keys.
 EMBEDDINGS_PATHS = [os.path.join(MERGED_DIR, f"{key}_embeddings.npz") for key in MERGE_KEYS]
 METADATA_PATHS = [os.path.join(MERGED_DIR, f"{key}_metadata.pkl") for key in MERGE_KEYS]
@@ -92,6 +98,7 @@ MD_FILES = expand(
 ## If a configuration directory is provided, it is included in the command; otherwise,
 ## individual parameters such as trainer, model, and data are passed explicitly.
 TRAIN_COMMAND = config["train_command"]
+CORRELATION_COMMAND = config["correlation_command"]
 
 TRAIN_COMMAND += str(
     f" --default_root_dir {ROOT_DIR} "
@@ -100,13 +107,31 @@ TRAIN_COMMAND += str(
     f"--predict_dir {PREDICT_SUBDIR} "
 )
 
+CORRELATION_COMMAND += str(
+    f" --default_root_dir {ROOT_DIR} "
+    f"--experiment_name {EXPERIMENT_NAME} --run_name {RUN_NAME} "
+    f"--seed_everything {SEED} "
+    f"--predict_dir {PREDICT_SUBDIR} "
+    f"--ckpt_path {CKPT_PATH} "
+)
+
+CORRELATION_FILES = expand(
+    "{correlation_dir}/correlations.csv",
+    correlation_dir=CORRELATION_DIR,
+)
+
+CORRELATION_FILES += expand(
+    "{correlation_dir}/correlations.pkl",
+    correlation_dir=CORRELATION_DIR,
+)
 
 ## Define the final output rule for Snakemake, specifying the target files that should be generated
 ## by the end of the workflow.
 rule all:
     input:
         EVALUATION_FILES,
-        MD_FILES
+        CORRELATION_FILES
+        # MD_FILES
 
 ## Define the rule for training the CMMVAE model.
 ## The output includes the configuration file, the checkpoint path, and the directory for predictions.
@@ -138,6 +163,23 @@ rule merge_predictions:
         cmmvae workflow merge-predictions --directory {input.predict_dir} --keys {params.merge_keys} --save_dir {MERGED_DIR}
         """
 
+## Define the rule for getting R^2 correlations on the filtered data
+## This rule outputs correlation scores per filtered data group
+rule correlations:
+    input:
+        embeddings_path=EMBEDDINGS_PATHS,
+    output:
+        CORRELATION_FILES,
+    params:
+        command=CORRELATION_COMMAND,
+        data=CORRELATION_DATA,
+        save_dir=CORRELATION_DIR,
+    shell:
+        """
+        mkdir -p {CORRELATION_DIR}
+        cmmvae workflow correlations {params.command} --correlation_data {params.data} --save_dir {params.save_dir}
+        """
+
 ## Define the rule for generating UMAP visualizations from the merged predictions.
 ## This rule produces UMAP images for each combination of category and merge key.
 rule umap_predictions:
@@ -156,16 +198,16 @@ rule umap_predictions:
         cmmvae workflow umap-predictions --directory {params.predict_dir} {params.categories} {params.merge_keys} --save_dir {params.save_dir}
         """
 
-rule meta_discriminators:
-    input:
-        CKPT_PATH
-    output:
-        MD_FILES,
-    params:
-        log_dir=META_DISC_DIR,
-        ckpt=CKPT_PATH,
-        config=TRAIN_CONFIG_FILE
-    shell:
-        """
-        cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config}
-        """
\ No newline at end of file
+# rule meta_discriminators:
+#     input:
+#         CKPT_PATH
+#     output:
+#         MD_FILES,
+#     params:
+#         log_dir=META_DISC_DIR,
+#         ckpt=CKPT_PATH,
+#         config=TRAIN_CONFIG_FILE
+#     shell:
+#         """
+#         cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config}
+#         """
\ No newline at end of file
diff --git a/scripts/data-preprocessing/census_data_filtering.py b/scripts/data-preprocessing/census_data_filtering.py
new file mode 100644
index 0000000..5bb9475
--- /dev/null
+++ b/scripts/data-preprocessing/census_data_filtering.py
@@ -0,0 +1,121 @@
+import os
+import re
+import glob
+import argparse as ap
+import multiprocessing as mp
+import numpy as np
+import pandas as pd
+import scipy.sparse as sp
+import data_filtering_functions as ff
+from collections import defaultdict
+from data_processing_functions import extract_file_number
+
+def main(
+    directory: str,
+    species: list[str],
+    train_dir: str,
+    save_dir: str,
+    seed: int,
+    skip_metadata: bool,
+):
+    if not skip_metadata:
+        np.random.seed(seed)
+        dataframes = {}
+        for specie in species:
+            # Process file paths for full and training data
+            train_files = tuple(glob.glob(os.path.join(train_dir, f"{specie}*.pkl")))
+            train_files = ff.filter_and_sort_train_files(train_files)
+
+            # Get somaIDs of training samples
+            train_ids = ff.get_train_data_ids(train_files)
+
+            # Prepare main df that is to be filtered
+            metadata_files = glob.glob(os.path.join(directory, f"{specie}*.pkl"))
+            metadata_files.sort(key=extract_file_number)
+            meta_dataframe = ff.load_and_merge_metadata(tuple(metadata_files))
+
+            # Filter out training samples from the full data
+            dataframes[specie] = ff.filter_train_ids(meta_dataframe, train_ids)
+
+        # Process groups and subset size if needed
+        grouped_data = ff.filter_into_groups(dataframes)
+        valid_groups = ff.validate_and_sample_groups(grouped_data, species[0])
+
+        # Save filtered metadata
+        ff.save_grouped_data(valid_groups, dataframes, save_dir)
+
+    # Load data and slice by metadata index
+    for specie in species:
+        data_files = glob.glob(os.path.join(directory, f"{specie}*.npz"))
+        data_files.sort(key=extract_file_number)
+        metadata_files = glob.glob(os.path.join(save_dir, f"{specie}*.pkl"))
+        metadata_files.sort(key=extract_file_number)
+        filtered_data = defaultdict(list)
+        for chunk_n, data_file in enumerate(data_files, start=1):
+            current_chunk = sp.load_npz(data_file)
+            for gid, metadata_file in enumerate(metadata_files, start=1):
+                current_df = pd.read_pickle(metadata_file)
+                idxes = current_df[current_df["chunk_source"] == chunk_n]["data_index"]
+                sliced_chunk = current_chunk[idxes, :]
+                filtered_data[gid].append(sliced_chunk)
+
+        # Save filtered data
+        for gid, data in filtered_data.items():
+            chunk_data = sp.vstack(data)
+            sp.save_npz(os.path.join(save_dir, f"{specie}_filtered_{gid}.npz"), chunk_data)
+
+if __name__ == "__main__":
+    parser = ap.ArgumentParser()
+    parser.add_argument(
+        "--directory",
+        type=str,
+        required=True,
+        help="Directory to load data from."
+    )
+    parser.add_argument(
+        "--species",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Species to load data for."
+    )
+    parser.add_argument(
+        "--train_data",
+        type=str,
+        required=False,
+        default=None,
+        help="Directory where the training data is stored. Defaults to '--directory'"
+    )
+    parser.add_argument(
+        "--save_directory",
+        type=str,
+        required=False,
+        default=None,
+        help="Directory to save filtered data in. Defaults to '--directory'"
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        required=False,
+        default=42,
+        help="Seed for the random module."
+    )
+    parser.add_argument(
+        "--skip_metadata",
+        action="store_true",
+        help="Whether to skip the metadata filtering and just slice the count data."
+    )
+
+    args = parser.parse_args()
+
+    train_dir = args.directory if args.train_data is None else args.train_data
+    save_dir = args.directory if args.save_directory is None else args.save_directory
+
+    main(
+        directory=args.directory,
+        species=args.species,
+        train_dir=train_dir,
+        save_dir=save_dir,
+        seed=args.seed,
+        skip_metadata=args.skip_metadata
+    )
\ No newline at end of file
diff --git a/scripts/data-preprocessing/data_filtering_functions.py b/scripts/data-preprocessing/data_filtering_functions.py
new file mode 100644
index 0000000..6b4e03c
--- /dev/null
+++ b/scripts/data-preprocessing/data_filtering_functions.py
@@ -0,0 +1,99 @@
+import os
+import re
+import csv
+import numpy as np
+import pandas as pd
+import pandas.core.groupby.generic as gb
+from collections import defaultdict
+from data_processing_functions import extract_file_number
+
+TRAIN_DATA_FILES = set(range(1, 14))
+GROUPING_COLUMNS = ["sex", "tissue", "cell_type", "assay"]
+MIN_SAMPLE_SIZE = 100
+MAX_SAMPLE_SIZE = 10000
+SPECIES_SAMPLE_SIZE = {"human": 60530, "mouse": 52437}
+
+def get_train_data_ids(files: tuple[str]) -> set[int]:
+
+    combined_df = pd.DataFrame()
+
+    for file in files:
+        df = pd.read_pickle(file)
+        combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+    return set(combined_df["soma_joinid"])
+
+def filter_train_ids(df: pd.DataFrame, ids: set[int]) -> pd.DataFrame:
+    filtered_df = df[~(df["soma_joinid"].isin(ids))]
+    filtered_df = filtered_df.reset_index(drop=True)
+    return filtered_df
+
+def filter_and_sort_train_files(unfiltered_files: tuple[str]) -> tuple[str]:
+
+    filtered_files = [
+        file for file in unfiltered_files
+        if (match := re.search(r'(\d+)\.pkl$', file)) and int(match.group(1)) in TRAIN_DATA_FILES
+    ]
+    filtered_files.sort(key=extract_file_number)
+    return tuple(filtered_files)
+
+def load_and_merge_metadata(files: tuple[str]) -> pd.DataFrame:
+
+    merged_df = pd.DataFrame()
+    for file in files:
+        df = pd.read_pickle(file)
+        df["chunk_source"] = extract_file_number(file)
+        df = df.reset_index(drop=False)
+        df.rename(columns={"index": "data_index"}, inplace=True)
+        merged_df = pd.concat([merged_df, df], ignore_index=True)
+
+    return merged_df
+
+def filter_into_groups(dfs: dict[str, pd.DataFrame]):
+
+    grouped = {}
+    for specie, data in dfs.items():
+        grouped[specie] = data.groupby(GROUPING_COLUMNS)
+
+    return grouped
+
+def validate_and_sample_groups(data_groups: dict[str, gb.DataFrameGroupBy], primary_species: str = None):
+
+    valid_groups = defaultdict(dict)
+    if primary_species is not None:
+        main_df = data_groups.pop(primary_species)
+    else:
+        primary_species, main_df = data_groups.popitem()
+
+    for gid, idxes in main_df.groups.items():
+        if len(idxes) < MIN_SAMPLE_SIZE:
+            continue
+        elif all(
+            gid in group.groups.keys() and len(group.groups[gid]) >= MIN_SAMPLE_SIZE
+            for group in data_groups.values()
+        ):
+            sample_size = min(
+                [len(idxes), MAX_SAMPLE_SIZE] + [len(group.groups[gid]) for group in data_groups.values()]
+            )
+
+            valid_groups[gid][primary_species] = np.random.choice(idxes, sample_size, replace=False)
+            for specie, group in data_groups.items():
+                valid_groups[gid][specie] = np.random.choice(group.groups[gid], sample_size, replace=False)
+
+    return valid_groups
+
+def save_grouped_data(groups: dict[tuple[str], dict[str, np.ndarray]], dfs: dict[str, pd.DataFrame], save_dir: str):
+
+    with open(os.path.join(save_dir, "group_references.csv"), "w") as file:
+        writer = csv.writer(file)
+        writer.writerow(["group_id", "num_samples"] + GROUPING_COLUMNS)
+        for i, gid in enumerate(groups.keys(), start=1):
+            for specie, idx in groups[gid].items():
+                df = dfs[specie].iloc[idx].copy()
+                df["group_id"] = i
+                df["num_samples"] = len(idx)
+                df = df.sort_values("chunk_source")
+                df.to_pickle(os.path.join(save_dir, f"{specie}_filtered_{i}.pkl"))
+            writer.writerow([i, len(idx)] + list(gid))
+    references = pd.read_csv(os.path.join(save_dir, "group_references.csv"))
+    references.to_pickle(os.path.join(save_dir, "group_references.pkl"))
diff --git a/src/cmmvae/data/local/__init__.py b/src/cmmvae/data/local/__init__.py
index d9ac7b7..691e9e8 100644
--- a/src/cmmvae/data/local/__init__.py
+++ b/src/cmmvae/data/local/__init__.py
@@ -5,9 +5,11 @@
 from cmmvae.data.local.cellxgene_datamodule import SpeciesDataModule
 from cmmvae.data.local.cellxgene_manager import SpeciesManager
+from cmmvae.data.local.cellxgene_datapipe import SpeciesDataPipe
 
 __all__ = [
     "SpeciesManager",
     "SpeciesDataModule",
+    "SpeciesDataPipe",
 ]
diff --git a/src/cmmvae/data/local/cellxgene_datapipe.py b/src/cmmvae/data/local/cellxgene_datapipe.py
index 0246d82..76dff65 100644
--- a/src/cmmvae/data/local/cellxgene_datapipe.py
+++ b/src/cmmvae/data/local/cellxgene_datapipe.py
@@ -172,7 +172,7 @@ def __iter__(self):
         for i in range(0, n_samples, self.batch_size):
             data_batch = sparse_matrix[i : i + self.batch_size]
 
-            if self.allow_partials and not data_batch.shape[0] == self.batch_size:
+            if data_batch.shape[0] != self.batch_size and not self.allow_partials:
                 continue
 
             tensor = torch.sparse_csr_tensor(
@@ -252,6 +252,7 @@ def __init__(
         npz_masks: Union[str, list[str]],
         metadata_masks: Union[str, list[str]],
         batch_size: int,
+        allow_partials: bool = False,
         shuffle: bool = True,
         return_dense: bool = False,
         verbose: bool = False,
@@ -305,6 +306,7 @@ def __init__(
                 print(path)
 
         self.batch_size = batch_size
+        self.allow_partials = allow_partials
         self.return_dense = return_dense
         self.verbose = verbose
         self._shuffle = shuffle
@@ -324,7 +326,7 @@ def __init__(
             dp = dp.shuffle_matrix_and_dataframe()
 
         dp = dp.batch_csr_matrix_and_dataframe(
-            self.batch_size, return_dense=self.return_dense
+            self.batch_size, return_dense=self.return_dense, allow_partials=self.allow_partials
         )
 
         # thought process on removal
diff --git a/src/cmmvae/runners/3d_umap.py b/src/cmmvae/runners/3d_umap.py
new file mode 100644
index 0000000..4736929
--- /dev/null
+++ b/src/cmmvae/runners/3d_umap.py
@@ -0,0 +1,114 @@
+import os
+import numpy as np
+import pandas as pd
+import pickle
+import umap
+import matplotlib.pyplot as plt
+import seaborn as sns
+from mpl_toolkits.mplot3d import Axes3D
+from matplotlib.animation import FuncAnimation
+
+def load_embeddings(npz_path, meta_path):
+    """Load embeddings and metadata from specified paths."""
+    embedding = np.load(npz_path)["embeddings"]
+    metadata = pd.read_pickle(meta_path)
+    return embedding, metadata
+
+def plot_3d_umap(embedding, metadata, column, output_path, fig_size=(10, 10), point_size=2):
+    fig = plt.figure(figsize=fig_size)
+    ax = fig.add_subplot(111, projection='3d')
+    scatter = ax.scatter(
+        embedding[:, 0], embedding[:, 1], embedding[:, 2],
+        c=metadata[column].astype('category').cat.codes,
+        cmap='hsv',
+        s=point_size
+    )
+    ax.set_title(f'3D UMAP projection colored by {column}\nFilename: {output_path.split("/")[-1]}')
+    ax.set_xlabel('UMAP 1')
+    ax.set_ylabel('UMAP 2')
+    ax.set_zlabel('UMAP 3')
+
+    def update(num):
+        ax.view_init(elev=30, azim=num)
+
+    ani = FuncAnimation(fig, update, frames=360, interval=20)
+    ani.save(output_path.replace('.png', '.gif'), writer='imagemagick')
+
+    plt.close()
+
+def plot_umap(
+    X,
+    metadata,
+    n_neighbors=30,
+    min_dist=0.3,
+    n_components=3,
+    metric="cosine",
+    low_memory=False,
+    n_jobs=40,
+    n_epochs=200,
+    n_largest=15,
+    method=None,
+    save_dir=None,
+    **umap_kwargs,
+):
+    """
+    Generate 3D UMAP embeddings and save them alongside their metadata.
+
+    Args:
+        X (np.ndarray): Embeddings to reduce with UMAP.
+        metadata (pd.DataFrame): Metadata aligned with the rows of X,
+            saved next to the reduced embeddings.
+        n_neighbors (int): Number of neighbors for UMAP.
+        min_dist (float): Minimum distance for UMAP.
+        n_components (int): Number of components for UMAP.
+        metric (str): Metric for UMAP.
+        low_memory (bool): Low memory setting for UMAP.
+        n_jobs (int): Number of CPUs available for UMAP.
+        n_epochs (int): Number of epochs to run UMAP.
+        n_largest (int): Number of most common categories to plot.
+        method (str): Method title to add to the graph.
+        save_dir (str): Directory to save UMAP plots.
+        **umap_kwargs: Extra kwargs passed to `umap.UMAP`.
+    """
+    import umap
+    # Fit and transform the data using UMAP
+    reducer = umap.UMAP(
+        n_neighbors=n_neighbors,
+        min_dist=min_dist,
+        n_components=n_components,
+        metric=metric,
+        low_memory=low_memory,
+        n_jobs=n_jobs,
+        n_epochs=n_epochs,
+        **umap_kwargs,
+    )
+
+    embedding = reducer.fit_transform(X)
+    embedding_file_name = "3d_z_umap_embeddings.npz"
+    embedding_path = os.path.join(save_dir, embedding_file_name)
+    metadata_file_name = "3d_z_umap_metadata.pkl"
+    metadata_path = os.path.join(save_dir, metadata_file_name)
+    os.makedirs(save_dir, exist_ok=True)
+    np.savez(embedding_path, embeddings=embedding)
+    metadata.to_pickle(metadata_path)
+
+directory = "/mnt/projects/debruinz_project/tony_boos/MMVAE/lightning_logs/baseline/1e7Beta_256rclt_noadv/"
+embedding, metadata = load_embeddings(
+    npz_path=os.path.join(directory, 'merged/z_embeddings.npz'),
+    meta_path=os.path.join(directory, 'merged/z_metadata.pkl')
+)
+
+plot_umap(embedding, metadata, save_dir=os.path.join(directory, 'umap'))
+
+embedding, metadata = load_embeddings(
+    npz_path=os.path.join(directory, 'umap/3d_z_umap_embeddings.npz'),
+    meta_path=os.path.join(directory, 'umap/3d_z_umap_metadata.pkl')
+)
+categories = ['donor_id', 'assay', 'dataset_id', 'cell_type', 'tissue', 'species']
+for col in categories:
+    plot_3d_umap(
+        embedding=embedding,
+        metadata=metadata,
+        column=col,
+        output_path=os.path.join(directory, f"umap/3D_UMAP_by_{col}.png")
+    )
\ No newline at end of file
diff --git a/src/cmmvae/runners/cli.py b/src/cmmvae/runners/cli.py
index 6d0d38b..d0cb834 100644
--- a/src/cmmvae/runners/cli.py
+++ b/src/cmmvae/runners/cli.py
@@ -18,7 +18,7 @@ def __init__(self, extra_parser_kwargs: dict = {}, **kwargs):
         Handles loading trainer, model, and data modules from config file,
         while linking common arguments for ease of access.
""" - self.is_run = not kwargs.get("run", False) + self.is_run = not kwargs.get("run", True) super().__init__( parser_kwargs={ @@ -52,7 +52,7 @@ def add_arguments_to_parser(self, parser): "--predict_dir", required=True, help="Where to store predictions after fit" ) - if not self.is_run: + if self.is_run: parser.add_argument( "--ckpt_path", required=True, help="Ckpt path to be passed to model" ) diff --git a/src/cmmvae/runners/correlations.py b/src/cmmvae/runners/correlations.py new file mode 100644 index 0000000..bcc848a --- /dev/null +++ b/src/cmmvae/runners/correlations.py @@ -0,0 +1,168 @@ +""" +Get R^2 correlations between cis and cross species generations +""" +import os +import sys +import click +import torch + +import numpy as np +import pandas as pd +import scipy.sparse as sp + +from torch.utils.data import DataLoader + +from cmmvae.models import CMMVAEModel +from cmmvae.constants import REGISTRY_KEYS as RK +from cmmvae.runners.cli import CMMVAECli +from cmmvae.data.local import SpeciesDataPipe + +FILTERED_BY_CATEGORIES = ["assay", "cell_type", "tissue", "sex"] + +def setup_datapipes(data_dir: str): + human_pipe = SpeciesDataPipe( + directory_path= data_dir, + npz_masks= "human*.npz", + metadata_masks= "human*.pkl", + batch_size= 10000, + allow_partials=True, + shuffle= False, + return_dense= False, + verbose= True, + ) + mouse_pipe = SpeciesDataPipe( + directory_path= data_dir, + npz_masks= "mouse*.npz", + metadata_masks= "mouse*.pkl", + batch_size= 10000, + allow_partials=True, + shuffle= False, + return_dense= False, + verbose= True, + ) + return human_pipe, mouse_pipe + +def setup_dataloaders(data_dir: str): + human_pipe, mouse_pipe = setup_datapipes(data_dir) + + human_dataloader = DataLoader( + dataset=human_pipe, + batch_size=None, + shuffle=False, + collate_fn=lambda x: x, + persistent_workers=False, + num_workers=6, + ) + mouse_dataloader = DataLoader( + dataset=mouse_pipe, + batch_size=None, + shuffle=False, + collate_fn=lambda x: x, + persistent_workers=False, + num_workers=6, + ) + + return human_dataloader, mouse_dataloader + +def convert_to_tensor(batch: sp.csr_matrix): + return torch.sparse_csr_tensor( + crow_indices=batch.indptr, + col_indices=batch.indices, + values=batch.data, + size=batch.shape, + ) + +def calc_correlations(human_out: np.ndarray, mouse_out: np.ndarray, n_samples: int): + with np.errstate(divide="ignore", invalid="ignore"): + human_correlations = np.corrcoef(human_out) + mouse_correlations = np.corrcoef(mouse_out) + + human_cis = np.round(np.nan_to_num(human_correlations[:n_samples, :n_samples]).mean(), 3) + human_cross = np.round(np.nan_to_num(human_correlations[n_samples:, n_samples:]).mean(), 3) + human_comb = np.round(np.nan_to_num(human_correlations[:n_samples, n_samples:]).mean(), 3) + + mouse_cis = np.round(np.nan_to_num(mouse_correlations[:n_samples, :n_samples]).mean(), 3) + mouse_cross = np.round(np.nan_to_num(mouse_correlations[n_samples:, n_samples:]).mean(), 3) + mouse_comb = np.round(np.nan_to_num(mouse_correlations[:n_samples, n_samples:]).mean(), 3) + + return pd.DataFrame( + { + "human_cis": [human_cis], + "human_cross": [human_cross], + "human_comb": [human_comb], + "mouse_cis": [mouse_cis], + "mouse_cross": [mouse_cross], + "mouse_comb": [mouse_comb] + } + ) + +def get_correlations(model: CMMVAEModel, data_dir: str): + + human_dataloader, mouse_dataloader = setup_dataloaders(data_dir) + + correlations = pd.DataFrame( + columns=[ + "tag", "group_id", "num_samples", + "human_cis", "human_cross", "human_comb", + "mouse_cis", 
"mouse_cross", "mouse_comb" + ] + ) + + for (human_batch, human_metadata), (mouse_batch, mouse_metadata) in zip(human_dataloader, mouse_dataloader): + + n_samples = human_metadata["num_samples"].iloc[0] + + model.module.eval() + with torch.no_grad(): + _, _, _, human_out, _ = model.module(human_batch, human_metadata, RK.HUMAN, cross_generate=True) + _, _, _, mouse_out, _ = model.module(mouse_batch, mouse_metadata, RK.MOUSE, cross_generate=True) + + human_stacked_out = np.vstack((human_out[RK.HUMAN].numpy(), mouse_out[RK.HUMAN].numpy())) + mouse_stacked_out = np.vstack((mouse_out[RK.MOUSE].numpy(), human_out[RK.MOUSE].numpy())) + + avg_correlations = calc_correlations(human_stacked_out, mouse_stacked_out, n_samples) + + avg_correlations["tag"] = " ".join([human_metadata[cat].iloc[0] for cat in FILTERED_BY_CATEGORIES]) + avg_correlations["group_id"] = human_metadata["group_id"].iloc[0] + avg_correlations["num_samples"] = n_samples + + correlations = pd.concat([correlations, avg_correlations], ignore_index=True) + + return correlations + +def save_correlations(correlations: pd.DataFrame, save_dir: str): + correlations = correlations.sort_values("group_id") + correlations.to_csv(os.path.join(save_dir, "correlations.csv"), index=False) + correlations.to_pickle(os.path.join(save_dir, "correlations.pkl")) + +@click.command( + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + ) +) +@click.option( + "--correlation_data", + type=click.Path(exists=True), + required=True, + help="Directory where filtered correlation data is saved" +) +@click.option( + "--save_dir", + type=click.Path(exists=True), + required=True, + help="Directory where correlation outputs are saved" +) +@click.pass_context +def correlations(ctx: click.Context, correlation_data: str, save_dir: str): + """Run using the LightningCli.""" + if ctx.args: + # Ensure `args` is passed as the command-line arguments + sys.argv = [sys.argv[0]] + ctx.args + + cli = CMMVAECli(run=False) + correlations = get_correlations(cli.model, correlation_data) + save_correlations(correlations, save_dir) + +if __name__ == "__main__": + correlations() \ No newline at end of file diff --git a/src/cmmvae/runners/workflow.py b/src/cmmvae/runners/workflow.py index 59689dc..6608c1e 100644 --- a/src/cmmvae/runners/workflow.py +++ b/src/cmmvae/runners/workflow.py @@ -1,6 +1,7 @@ import click from cmmvae.runners.cli import cli +from cmmvae.runners.correlations import correlations from cmmvae.runners.umap_predictions import umap_predictions from cmmvae.runners.merge_predictions import merge_predictions from cmmvae.runners.meta_discriminators import meta_discriminator @@ -12,6 +13,7 @@ def workflow(): workflow.add_command(cli) +workflow.add_command(correlations) workflow.add_command(umap_predictions) workflow.add_command(merge_predictions) workflow.add_command(meta_discriminator) diff --git a/workflow/config.yaml b/workflow/config.yaml index 48f79ba..06ee36a 100644 --- a/workflow/config.yaml +++ b/workflow/config.yaml @@ -8,12 +8,17 @@ run_name: version000 predict_dir: samples train_command: fit --trainer configs/trainer/config.test.yaml --model configs/model/config.yaml --data configs/data/local.yaml +correlation_command: --trainer configs/trainer/config.test.yaml --model configs/model/config.yaml categories: - 'donor_id' - 'assay' - 'dataset_id' - 'cell_type' +- 'tissue' +- 'species' merge_keys: - z + +correlation_data: /mnt/projects/debruinz_project/july2024_census_data/filtered \ No newline at end of file diff --git 
a/workflow/profile/slurm/config.yaml b/workflow/profile/slurm/config.yaml index 7d9c6e3..999e16f 100644 --- a/workflow/profile/slurm/config.yaml +++ b/workflow/profile/slurm/config.yaml @@ -39,6 +39,11 @@ set-resources: mem: 179GB gpus_per_node: "" cpus_per_task: 1 + correlations: + partition: gpu + mem: 179GB + gpus_per_node: tesla_v100s:1 + cpus_per_task: 12 umap_predictions: partition: all mem: 179GB
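
Note on the new correlations step: calc_correlations builds a sample-by-sample Pearson correlation matrix (np.corrcoef) over the stacked cis and cross-species reconstructions, then averages its cis/cis, cross/cross, and cis/cross blocks. A minimal, self-contained sketch of that block-averaging convention on made-up toy arrays (the names and sizes below are illustrative only, not the model's real outputs):

    import numpy as np

    # Toy stand-ins for one filtered group's cis and cross-species reconstructions
    # (5 cells x 20 genes), playing the role of human_out[RK.HUMAN] and mouse_out[RK.HUMAN].
    rng = np.random.default_rng(0)
    n_samples, n_genes = 5, 20
    cis = rng.poisson(2.0, size=(n_samples, n_genes)).astype(float)
    cross = cis + rng.normal(scale=0.5, size=cis.shape)

    stacked = np.vstack((cis, cross))   # rows 0..n-1 are cis, rows n..2n-1 are cross
    corr = np.corrcoef(stacked)         # (2n, 2n) correlation matrix over samples

    cis_score = np.nan_to_num(corr[:n_samples, :n_samples]).mean()    # cis vs cis block
    cross_score = np.nan_to_num(corr[n_samples:, n_samples:]).mean()  # cross vs cross block
    comb_score = np.nan_to_num(corr[:n_samples, n_samples:]).mean()   # cis vs cross block
    print(round(cis_score, 3), round(cross_score, 3), round(comb_score, 3))

get_correlations applies the same pattern once per filtered group, separately for human and mouse outputs, and the Snakemake correlations rule writes the resulting table to correlations.csv and correlations.pkl.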