From 1c844476599b41b219f4f452f0ec2347a9758102 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Wed, 11 Sep 2024 18:32:40 -0400 Subject: [PATCH 01/14] Resolves Issue #37 --- src/cmmvae/models/cmmvae_model.py | 4 ++-- src/cmmvae/modules/cmmvae.py | 14 +++----------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/cmmvae/models/cmmvae_model.py b/src/cmmvae/models/cmmvae_model.py index 6f3c248..8e0a10f 100644 --- a/src/cmmvae/models/cmmvae_model.py +++ b/src/cmmvae/models/cmmvae_model.py @@ -118,7 +118,7 @@ def training_step( adversarial_optimizers = optims["adversarials"] # Perform forward pass and compute the loss - qz, pz, z, xhats, cg_xhats, hidden_representations = self.module( + qz, pz, z, xhats, hidden_representations = self.module( x, metadata, expert_id ) @@ -214,7 +214,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, pd.DataFrame, str]) -> None expert_label = self.module.experts.labels[expert_id] # Perform forward pass and compute the loss - qz, pz, z, xhats, cg_xhats, hidden_representations = self.module( + qz, pz, z, xhats, hidden_representations = self.module( x, metadata, expert_id ) diff --git a/src/cmmvae/modules/cmmvae.py b/src/cmmvae/modules/cmmvae.py index cb656d1..de7dc45 100644 --- a/src/cmmvae/modules/cmmvae.py +++ b/src/cmmvae/modules/cmmvae.py @@ -68,7 +68,7 @@ def forward( metadata (pd.DataFrame): Metadata associated with the input data. expert_id (str): Identifier for the expert network to use. cross_generate (bool, optional): - Flag to enable cross-generation between experts. + Flag to enable cross-generation across experts. Defaults to False. Returns: @@ -80,8 +80,6 @@ def forward( - z (torch.Tensor): Sampled latent variable. - xhats (Dict[str, torch.Tensor]): Reconstructed outputs for each expert. - - cg_xhats (Dict[str, torch.Tensor]): - Cross-generated outputs if cross_generate is True. - hidden_representations (List[torch.Tensor]): Hidden representations from the VAE. 
""" @@ -92,7 +90,6 @@ def forward( qz, pz, z, shared_xhat, hidden_representations = self.vae(shared_x, metadata) xhats = {} - cg_xhats = {} # Perform cross-generation if enabled if cross_generate: @@ -105,20 +102,15 @@ def forward( """ ) + # Decode using all avaialble experts for expert in self.experts: xhats[expert] = self.experts[expert].decode(shared_xhat) - for xhat_expert_id, xhat_expert in xhats.items(): - if xhat_expert_id == expert_id: - continue - shared_x = self.experts[xhat_expert_id].encode(xhat_expert) - _, _, _, shared_xhat, _ = self.vae(shared_x, metadata) - cg_xhats[xhat_expert_id] = self.experts[expert_id].decode(shared_xhat) else: # Decode using the specified expert xhats[expert_id] = self.experts[expert_id].decode(shared_xhat) - return qz, pz, z, xhats, cg_xhats, hidden_representations + return qz, pz, z, xhats, hidden_representations @torch.no_grad() def get_latent_embeddings( From 750ba33701c7f7748533bf3b613e41e887b3fbf2 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Thu, 12 Sep 2024 12:07:47 -0400 Subject: [PATCH 02/14] Resolves UMAP runtime issues and VAE optimizer parameters --- src/cmmvae/models/base_model.py | 1 + src/cmmvae/models/cmmvae_model.py | 2 +- src/cmmvae/modules/cmmvae.py | 3 +++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/cmmvae/models/base_model.py b/src/cmmvae/models/base_model.py index 5735654..c8dfa27 100644 --- a/src/cmmvae/models/base_model.py +++ b/src/cmmvae/models/base_model.py @@ -338,3 +338,4 @@ def _save_paired_predictions(self): self.predict_dir, f"{key}_metadata_{self._curr_save_idx}.pkl" ), ) + self._running_predictions.clear() diff --git a/src/cmmvae/models/cmmvae_model.py b/src/cmmvae/models/cmmvae_model.py index 8e0a10f..24ebc9e 100644 --- a/src/cmmvae/models/cmmvae_model.py +++ b/src/cmmvae/models/cmmvae_model.py @@ -294,7 +294,7 @@ def configure_optimizers(self) -> List[torch.optim.Optimizer]: for expert_id in self.module.experts } optimizers["vae"] = torch.optim.Adam( - self.module.vae.encoder.parameters(), lr=1e-3, weight_decay=1e-6 + self.module.vae.parameters(), lr=1e-3, weight_decay=1e-6 ) optimizers["adversarials"] = { i: torch.optim.Adam(module.parameters(), lr=1e-3, weight_decay=1e-6) diff --git a/src/cmmvae/modules/cmmvae.py b/src/cmmvae/modules/cmmvae.py index de7dc45..b4e183f 100644 --- a/src/cmmvae/modules/cmmvae.py +++ b/src/cmmvae/modules/cmmvae.py @@ -136,4 +136,7 @@ def get_latent_embeddings( # Encode using the VAE _, z, _ = self.vae.encode(x) + # Tag the metadata with the expert_id + metadata['species'] = expert_id + return {RK.Z: z, f"{RK.Z}_{RK.METADATA}": metadata} From e4f90f89dd1a2da81097773459298727eaed698e Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Fri, 27 Sep 2024 10:00:21 -0400 Subject: [PATCH 03/14] Removes outdata data pre-processing scripts --- .../data-preprocessing/3m_sampler_human.py | 65 ------------------- .../data-preprocessing/3m_sampler_mouse.py | 64 ------------------ .../cell_census_mouse_download.R | 50 -------------- .../mouse_data_R_to_Python.R | 22 ------- .../mouse_metadata_conversion.R | 4 -- 5 files changed, 205 deletions(-) delete mode 100755 scripts/data-preprocessing/3m_sampler_human.py delete mode 100755 scripts/data-preprocessing/3m_sampler_mouse.py delete mode 100644 scripts/data-preprocessing/cell_census_mouse_download.R delete mode 100644 scripts/data-preprocessing/mouse_data_R_to_Python.R delete mode 100644 scripts/data-preprocessing/mouse_metadata_conversion.R diff --git a/scripts/data-preprocessing/3m_sampler_human.py 
b/scripts/data-preprocessing/3m_sampler_human.py deleted file mode 100755 index 97e5cd1..0000000 --- a/scripts/data-preprocessing/3m_sampler_human.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -2/7/2024 - Anthony Boos -The 3m Human Sampler takes 30k random cells from the full human dataset and produces a 3million cell subset -""" - -import random, csv -import scipy.sparse as sp - - -def load_data(n): - print(f"Loading chunk {n}") - with open( - f"/active/debruinz_project/human_data/python_data/chunk{n}_metadata.csv", "r" - ) as metadata_file: - metadata_reader = csv.reader(metadata_file) - metadata = list(metadata_reader) - return ( - sp.load_npz( - f"/active/debruinz_project/human_data/python_data/human_chunk_{n}.npz" - ), - metadata, - ) - - -def write_data(chunk, metadata, n): - print("Writing to disk") - sp.save_npz( - f"/active/debruinz_project/CellCensus_3M/3m_human_chunk_{n}.npz", - sp.vstack(chunk), - ) - with open( - f"/active/debruinz_project/CellCensus_3M/3m_human_metadata_{n}.csv", "w" - ) as subset_metadata_file: - metadata_writer = csv.writer(subset_metadata_file) - metadata_writer.writerows(metadata) - print(f"Chunk {n} complete!") - - -def main(): - print("Beginning human sampler") - n_written = 1 # Track what chunk is to be written - n_selections = 30000 # Sampling 30K cells per chunk - max_index = 285341 # Controls the generated range of indicies - chunk_slices = [] - metadata_slices = [] - for i in range(1, 101): - # Loop and load all 100 chunks of the full dataset - chunk, metadata = load_data(i) - if i == 100: - max_index = 285256 # Chunk 100 is a different size from the rest - # Get 30k random indicies to slice cells from the current chunk - indicies = random.sample(range(0, max_index), n_selections) - for idx in indicies: - chunk_slices.append(chunk[idx, :]) - metadata_slices.append( - metadata[idx + 1] - ) # Adjusting for t he column header row in the metadata - if len(chunk_slices) == 100000: - write_data(chunk_slices, metadata_slices, n_written) - chunk_slices = [] - metadata_slices = [] - n_written += 1 - - -main() diff --git a/scripts/data-preprocessing/3m_sampler_mouse.py b/scripts/data-preprocessing/3m_sampler_mouse.py deleted file mode 100755 index b948d59..0000000 --- a/scripts/data-preprocessing/3m_sampler_mouse.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -2/7/2024 - Anthony Boos -The 3m Mouse Sampler takes 30k random cells from the full mouse dataset and produces a 3million cell subset -""" -import random, csv -import scipy.sparse as sp - - -def load_data(n): - print(f"Loading chunk {n}") - with open( - f"/active/debruinz_project/mouse_data/python_data/mouse_metadata_{n}.csv", "r" - ) as metadata_file: - metadata_reader = csv.reader(metadata_file) - metadata = list(metadata_reader) - return ( - sp.load_npz( - f"/active/debruinz_project/mouse_data/python_data/mouse_chunk_{n}.npz" - ), - metadata, - ) - - -def write_data(chunk, metadata, n): - print("Writing to disk") - sp.save_npz( - f"/active/debruinz_project/CellCensus_3M/3m_mouse_chunk_{n}.npz", - sp.vstack(chunk), - ) - with open( - f"/active/debruinz_project/CellCensus_3M/3m_mouse_metadata_{n}.csv", "w" - ) as subset_metadata_file: - metadata_writer = csv.writer(subset_metadata_file) - metadata_writer.writerows(metadata) - print(f"Chunk {n} complete!") - - -def main(): - print("Beginning mouse sampler") - n_written = 1 # Track what chunk is to be written - n_selections = 30000 # Sampling 30K cells per chunk - max_index = 38525 # Controls the generated range of indicies - chunk_slices = [] - metadata_slices = [] - 
for i in range(1, 101): - # Loop and load all 100 chunks of the full dataset - chunk, metadata = load_data(i) - if i == 100: - max_index = 38461 # Chunk 100 is a different size from the rest - # Get 30k random indicies to slice cells from the current chunk - indicies = random.sample(range(0, max_index), n_selections) - for idx in indicies: - chunk_slices.append(chunk[idx, :]) - metadata_slices.append( - metadata[idx + 1] - ) # Adjusting for the column header row in the metadata - if len(chunk_slices) == 100000: - write_data(chunk_slices, metadata_slices, n_written) - chunk_slices = [] - metadata_slices = [] - n_written += 1 - - -main() diff --git a/scripts/data-preprocessing/cell_census_mouse_download.R b/scripts/data-preprocessing/cell_census_mouse_download.R deleted file mode 100644 index a2ea98f..0000000 --- a/scripts/data-preprocessing/cell_census_mouse_download.R +++ /dev/null @@ -1,50 +0,0 @@ -# Mouse data pulled on 2/5/2024 - -library("tiledb") - -library("tiledbsoma") - -library("cellxgene.census") - -census <- open_soma() - -df <- as.data.frame( - - census$get("census_data")$get("mus_musculus")$obs$read( - - value_filter = "is_primary_data == TRUE", - - )$concat()) - -# filter to only assays mentioning "10x" and including "microwell-seq" - -df<- df [grepl("10x", df$assay) | df$assay == "microwell-seq",] - -# scramble df - -df <- df[sample(1:nrow(df), nrow(df)), ] - -# break up df into 100 chunks - -n <- ceiling(nrow(df) / 100) - -chunks <- split(df$soma_joinid, rep(1:ceiling(nrow(df)/n), each = n, length.out = nrow(df))) -metadata_chunks <- split(df, rep(1:ceiling(nrow(df)/n), each = n, length.out = nrow(df))) - -for(i in 1:length(chunks)){ - - cat("CHUNK", i, "/", length(chunks),"\n") - saveRDS(metadata_chunks[[i]], paste0("/active/debruinz_project/mouse_data/r_data/mouse_metadata_", i, ".rds")) - query <- census$get("census_data")$get("mus_musculus")$axis_query( - - measurement_name = "RNA", - - obs_query = SOMAAxisQuery$new(coords = chunks[[i]]) - - ) - - A <- query$to_sparse_matrix(collection = "X", layer_name = "raw", obs_index = "soma_joinid", var_index = "feature_id" ) - A <- as(A, "dgCMatrix") - saveRDS(A, paste0("/active/debruinz_project/mouse_data/r_data/mouse_chunk_", i, ".rds")) - -} diff --git a/scripts/data-preprocessing/mouse_data_R_to_Python.R b/scripts/data-preprocessing/mouse_data_R_to_Python.R deleted file mode 100644 index b14a841..0000000 --- a/scripts/data-preprocessing/mouse_data_R_to_Python.R +++ /dev/null @@ -1,22 +0,0 @@ -library(reticulate) -# Ensures that scipy is available -use_virtualenv("/active/debruinz_project/tony_boos/torch_venv/", required = TRUE) -reticulate::py_config() -library(nmslibR) -scipy <- import("scipy") - -for (i in 1:100) { - - print(paste0("Chunk ", i)) - - dgCMatrix <- readRDS(paste0("/active/debruinz_project/mouse_data/r_data/mouse_chunk_", i, ".rds")) - - # Convert R dgCMatrix to a scipy sparse matrix - scipy_matrix <- nmslibR::TO_scipy_sparse(dgCMatrix) - - # Convert the scipy matrix to CSR format - csr_matrix <- scipy_matrix$tocsr() - - # Save the CSR matrix to an npz file - scipy$sparse$save_npz(paste0("/active/debruinz_project/mouse_data/python_data/mouse_chunk_", i, ".npz"), csr_matrix) -} diff --git a/scripts/data-preprocessing/mouse_metadata_conversion.R b/scripts/data-preprocessing/mouse_metadata_conversion.R deleted file mode 100644 index 2ef97e9..0000000 --- a/scripts/data-preprocessing/mouse_metadata_conversion.R +++ /dev/null @@ -1,4 +0,0 @@ -for(i in 1:100){ - metadata = 
readRDS(paste0("/active/debruinz_project/mouse_data/r_data/mouse_metadata_", i, ".rds")) - write.csv(metadata, paste0("/active/debruinz_project/mouse_data/python_data/mouse_metadata_", i, ".csv")) -} \ No newline at end of file From b15bcc5896f2b1ee6a2c932f57572bd41500460c Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Fri, 27 Sep 2024 13:44:45 -0400 Subject: [PATCH 04/14] Adds data download scripts --- .../census_data_download.py | 62 +++++++++++++++++ .../data_processing_functions.py | 69 +++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 scripts/data-preprocessing/census_data_download.py create mode 100644 scripts/data-preprocessing/data_processing_functions.py diff --git a/scripts/data-preprocessing/census_data_download.py b/scripts/data-preprocessing/census_data_download.py new file mode 100644 index 0000000..7d4594a --- /dev/null +++ b/scripts/data-preprocessing/census_data_download.py @@ -0,0 +1,62 @@ +import argparse as ap +import cellxgene_census +import numpy as np +import os +import pandas as pd +import scipy.sparse as sp +import tiledbsoma as tdb + +from data_processing_functions import normalize_data, save_data_to_disk, verify_data + +VALUE_FILTER = 'is_primary_data == True and assay in ["microwell-seq", "10x 3\' v1", "10x 3\' v2", "10x 3\' v3", "10x 3\' transcription profiling", "10x 5\' transcription profiling", "10x 5\' v1", "10x 5\' v2"]' +VALID_SPECIES = ["homo_sapiens", "mus_musculus"] +SPECIES_MAP = {"homo_sapies": "human", "mus_musculus": "mouse"} +CHUNK_SIZE = 499968 + +def process_chunk(census: tdb.Collection, species: str, ids: list[int], chunk_n: int, save_dir: os.PathLike): + adata = cellxgene_census.get_anndata(census, species, "RNA", "raw", obs_value_filter=VALUE_FILTER, obs_coords=ids) + normalize_data(adata.X) + permutation = np.random.permutation(range(adata.X.shape[0])) + save_data_to_disk( + data_path=os.path.join(save_dir, f'{SPECIES_MAP[species]}_counts_{chunk_n}.npz'), + data=adata.X[permutation, :], + metadata_path=os.path.join(save_dir, f'{SPECIES_MAP[species]}_metadata_{chunk_n}.pkl'), + metdata=adata.obs[permutation].reset_index(drop=True) + ) + +def main(directory: os.PathLike, species: str, seed: int): + + if species not in VALID_SPECIES: + raise ValueError(f"Error: Invalid species provided - {species}. 
Valid values are: {VALID_SPECIES}") + + if not os.path.exists(directory): + raise FileExistsError("Error: Provided directory does not exist!") + + with cellxgene_census.open_soma(census_version="2024-07-01") as census: + + soma_ids = census["census_data"][species].obs.read(value_filter=VALUE_FILTER, column_names=["soma_joinid"]).concat().to_pandas() + soma_ids = list(soma_ids['soma_joinid']) + + num_samples = len(soma_ids) + set_ids = set(soma_ids) + assert num_samples == len(set_ids) + + np.random.seed(seed) + np.random.shuffle(soma_ids) + + chunk_count = 0 + for i in range(0, len(soma_ids), CHUNK_SIZE): + chunk_count += 1 + process_chunk(census, species, soma_ids[i:i+CHUNK_SIZE], chunk_count, directory) + + verify_data(directory, SPECIES_MAP[species], set_ids, CHUNK_SIZE, chunk_count, num_samples - (CHUNK_SIZE * (chunk_count - 1))) + + +if __name__ == "__main__": + parser = ap.ArgumentParser() + parser.add_argument("--directory", type=str, required=True, help="Directory to save data in.") + parser.add_argument("--species", type=str, required=True, help="Species to download data for.") + parser.add_argument("--seed", type=int, required=False, default=42) + + args = parser.parse_args() + main(args.directory, args.species, args.seed) \ No newline at end of file diff --git a/scripts/data-preprocessing/data_processing_functions.py b/scripts/data-preprocessing/data_processing_functions.py new file mode 100644 index 0000000..f82adfe --- /dev/null +++ b/scripts/data-preprocessing/data_processing_functions.py @@ -0,0 +1,69 @@ +import glob +import os +import re +import numpy as np +import pandas as pd +import scipy.sparse as sp + +def normalize_data(data: sp.csr_matrix): + for row in range(data.shape[0]): + start_idx = data.indptr[row] + end_idx = data.indptr[row + 1] + + row_data = data.getrow(row).data + row_sum = row_data.sum() + + # Apply the transformation and log1p + data.data[start_idx:end_idx] = np.log1p((row_data * 1e4) / row_sum) + + +def save_data_to_disk(data_path: os.PathLike, metadata_path: os.PathLike, data: sp.csr_matrix, metdata: pd.DataFrame): + sp.save_npz(data_path, data) + metdata.to_pickle(metadata_path, compression=None) + +def record_stats(path: os.PathLike): + pass + +def gather_stats(data: sp.csr_matrix, metadata: pd.DataFrame): + pass + +def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_size: int, last_chunk: int = None, last_size: int = None): + + def extract_number(filename): + match = re.search(r'_(\d+)', filename) # Look for digits after an underscore (_) + return int(match.group(1)) if match else 0 + + data_files = glob.glob(os.path.join(directory, f'{species}*.npz')) + metadata_files = glob.glob(os.path.join(directory, f'{species}*.pkl')) + data_files.sort(key=extract_number) + metadata_files.sort(key=extract_number) + + for i, (data_path, metadata_path) in enumerate(zip(data_files, metadata_files), start=1): + errors_detected = False + + data = sp.load_npz(data_path) + metadata = pd.read_pickle(metadata_path) + + if data.shape[0] != expected_size: + if i != last_chunk: + print(f"Chunk size mismatch in chunk #{i}!!!!!") + errors_detected = True + elif data.shape[0] != last_size: + print(f"Chunk size mismatch in chunk #{i}!!!!!") + errors_detected = True + + soma_ids = list(metadata['soma_joinid']) + num_ids = len(soma_ids) + if num_ids != data.shape[0]: + print(f"Metadata mismatch in chunk #{i}!!!!!") + errors_detected = True + + set_ids = set(soma_ids) + if len(set_ids) != num_ids: + print(f'Duplicate IDs found in chunk #{i}!!!!!') + 
errors_detected = True + + ids.intersection_update(set_ids) + + if not errors_detected: + print(f"No issues found in Chunk #{i}.") \ No newline at end of file From ae58e043536c3b89e01e840aeb92e6b8b14b3c25 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 30 Sep 2024 14:57:35 -0400 Subject: [PATCH 05/14] Adds option to download a random subset --- .../census_data_download.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/scripts/data-preprocessing/census_data_download.py b/scripts/data-preprocessing/census_data_download.py index 7d4594a..bf5ce0c 100644 --- a/scripts/data-preprocessing/census_data_download.py +++ b/scripts/data-preprocessing/census_data_download.py @@ -10,8 +10,7 @@ VALUE_FILTER = 'is_primary_data == True and assay in ["microwell-seq", "10x 3\' v1", "10x 3\' v2", "10x 3\' v3", "10x 3\' transcription profiling", "10x 5\' transcription profiling", "10x 5\' v1", "10x 5\' v2"]' VALID_SPECIES = ["homo_sapiens", "mus_musculus"] -SPECIES_MAP = {"homo_sapies": "human", "mus_musculus": "mouse"} -CHUNK_SIZE = 499968 +SPECIES_MAP = {"homo_sapiens": "human", "mus_musculus": "mouse"} def process_chunk(census: tdb.Collection, species: str, ids: list[int], chunk_n: int, save_dir: os.PathLike): adata = cellxgene_census.get_anndata(census, species, "RNA", "raw", obs_value_filter=VALUE_FILTER, obs_coords=ids) @@ -21,10 +20,10 @@ def process_chunk(census: tdb.Collection, species: str, ids: list[int], chunk_n: data_path=os.path.join(save_dir, f'{SPECIES_MAP[species]}_counts_{chunk_n}.npz'), data=adata.X[permutation, :], metadata_path=os.path.join(save_dir, f'{SPECIES_MAP[species]}_metadata_{chunk_n}.pkl'), - metdata=adata.obs[permutation].reset_index(drop=True) + metdata=adata.obs.iloc[permutation].reset_index(drop=True) ) -def main(directory: os.PathLike, species: str, seed: int): +def main(directory: os.PathLike, species: str, chunk_size: int, seed: int, sample_size: int): if species not in VALID_SPECIES: raise ValueError(f"Error: Invalid species provided - {species}. Valid values are: {VALID_SPECIES}") @@ -44,19 +43,31 @@ def main(directory: os.PathLike, species: str, seed: int): np.random.seed(seed) np.random.shuffle(soma_ids) + if sample_size is not None: + soma_ids = np.random.choice(soma_ids, sample_size, replace=False) + chunk_count = 0 - for i in range(0, len(soma_ids), CHUNK_SIZE): + for i in range(0, len(soma_ids), chunk_size): chunk_count += 1 - process_chunk(census, species, soma_ids[i:i+CHUNK_SIZE], chunk_count, directory) + process_chunk(census, species, soma_ids[i:i+chunk_size], chunk_count, directory) - verify_data(directory, SPECIES_MAP[species], set_ids, CHUNK_SIZE, chunk_count, num_samples - (CHUNK_SIZE * (chunk_count - 1))) + verify_data(directory, SPECIES_MAP[species], set_ids, chunk_size, chunk_count, num_samples % chunk_size) if __name__ == "__main__": parser = ap.ArgumentParser() parser.add_argument("--directory", type=str, required=True, help="Directory to save data in.") parser.add_argument("--species", type=str, required=True, help="Species to download data for.") - parser.add_argument("--seed", type=int, required=False, default=42) + parser.add_argument("--chunk_size", type=int, required=True, help="Number of samples to save in each chunk of data.") + parser.add_argument("--seed", type=int, required=False, default=42, help="Seed for the random module") + parser.add_argument("--sample_size", type=int, required=False, default=None, help="Number of samples to grab out of total. 
Omit to get all data.") args = parser.parse_args() - main(args.directory, args.species, args.seed) \ No newline at end of file + + main( + directory= args.directory, + species= args.species, + chunk_size= args.chunk_size, + seed= args.seed, + sample_size= args.sample_size + ) \ No newline at end of file From da5060230da575fd6e3deb862297c2b3df0592c7 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 30 Sep 2024 16:59:13 -0400 Subject: [PATCH 06/14] Adds stat taking functions --- .../data-preprocessing/census_data_stats.py | 56 +++++++++++++++++++ .../data_processing_functions.py | 42 +++++++++++--- 2 files changed, 90 insertions(+), 8 deletions(-) create mode 100644 scripts/data-preprocessing/census_data_stats.py diff --git a/scripts/data-preprocessing/census_data_stats.py b/scripts/data-preprocessing/census_data_stats.py new file mode 100644 index 0000000..8aa402a --- /dev/null +++ b/scripts/data-preprocessing/census_data_stats.py @@ -0,0 +1,56 @@ +import glob +import os +import argparse as ap +import pandas as pd +import scipy.sparse as sp + +from typing import Union +from collections import defaultdict +from data_processing_functions import extract_file_number, gather_stats, record_stats, DATA_CATEGORIES + +def update_totals(totals: dict[str, dict[str, int]], data: dict[str, Union[int, dict[str, int]]]): + + for category in DATA_CATEGORIES: + for key, value in data[category]: + totals[category][key] += value + +def main(directory: os.PathLike, species: str): + + data_files = glob.glob(os.path.join(directory, f'{species}*.npz')) + metadata_files = glob.glob(os.path.join(directory, f'{species}*.pkl')) + data_files.sort(key=extract_file_number) + metadata_files.sort(key=extract_file_number) + + totals = { + cat: defaultdict(int) for cat in DATA_CATEGORIES + } + + for data_path, metadata_path in zip(data_files, metadata_files): + data = sp.load_npz(data_path) + metadata = pd.read_pickle(metadata_path) + + chunk_stats = gather_stats(data, metadata) + filename = f'{os.path.basename(data_path).split(".")[0]}_stats.csv' + + update_totals(totals, chunk_stats) + + record_stats( + path= os.path.join(directory, filename), + data= chunk_stats + ) + + record_stats( + path=os.path.join(directory, f'{species}_stat_totals.csv'), + data=totals + ) + +if __name__ == '__main__': + parser = ap.ArgumentParser() + parser.add_argument("--directory", type=str, required=True, help="Directory to load data from.") + parser.add_argument("--species", type=str, required=True, help="Species to check data for.") + + args = parser.parse_args() + main( + directory= args.directory, + species= args.species + ) \ No newline at end of file diff --git a/scripts/data-preprocessing/data_processing_functions.py b/scripts/data-preprocessing/data_processing_functions.py index f82adfe..aaf34b4 100644 --- a/scripts/data-preprocessing/data_processing_functions.py +++ b/scripts/data-preprocessing/data_processing_functions.py @@ -1,3 +1,4 @@ +import csv import glob import os import re @@ -5,6 +6,10 @@ import pandas as pd import scipy.sparse as sp +from typing import Union + +DATA_CATEGORIES = ['assay', 'cell_type', 'tissue'] + def normalize_data(data: sp.csr_matrix): for row in range(data.shape[0]): start_idx = data.indptr[row] @@ -21,22 +26,43 @@ def save_data_to_disk(data_path: os.PathLike, metadata_path: os.PathLike, data: sp.save_npz(data_path, data) metdata.to_pickle(metadata_path, compression=None) -def record_stats(path: os.PathLike): - pass +def record_stats(path: os.PathLike, data: dict[str, Union[int, dict[str, int]]]): + + 
with open(path, mode='w', newline='') as file: + writer = csv.writer(file) + writer.writerow(['Stat', 'Value(s)']) + + for key, value in data.items(): + if isinstance(value, dict): + merged = ', '.join([f'{k} - {v}' for k, v in value.items()]) + else: + writer.writerow([key, value]) + +def get_data_stats(data: sp.csr_matrix): + row_sums = data.sum(axis=1).A.flatten() + return row_sums.mean(), np.std(row_sums) def gather_stats(data: sp.csr_matrix, metadata: pd.DataFrame): - pass + chunk_stats = {} + mean, std = get_data_stats(data) + chunk_stats["mean"] = mean + chunk_stats["std"] = std -def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_size: int, last_chunk: int = None, last_size: int = None): - - def extract_number(filename): + for col in DATA_CATEGORIES: + chunk_stats[col] = metadata[col].value_counts().to_dict() + + return chunk_stats + +def extract_file_number(filename): match = re.search(r'_(\d+)', filename) # Look for digits after an underscore (_) return int(match.group(1)) if match else 0 + +def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_size: int, last_chunk: int = None, last_size: int = None): data_files = glob.glob(os.path.join(directory, f'{species}*.npz')) metadata_files = glob.glob(os.path.join(directory, f'{species}*.pkl')) - data_files.sort(key=extract_number) - metadata_files.sort(key=extract_number) + data_files.sort(key=extract_file_number) + metadata_files.sort(key=extract_file_number) for i, (data_path, metadata_path) in enumerate(zip(data_files, metadata_files), start=1): errors_detected = False From b609acd0dd171546a5b9b4d8aa25245a2b7d1f98 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Wed, 2 Oct 2024 13:02:29 -0400 Subject: [PATCH 07/14] Adds multi-processing to the data download --- .../census_data_download.py | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/scripts/data-preprocessing/census_data_download.py b/scripts/data-preprocessing/census_data_download.py index bf5ce0c..3b96fb3 100644 --- a/scripts/data-preprocessing/census_data_download.py +++ b/scripts/data-preprocessing/census_data_download.py @@ -1,5 +1,6 @@ import argparse as ap import cellxgene_census +import multiprocessing as mp import numpy as np import os import pandas as pd @@ -8,12 +9,13 @@ from data_processing_functions import normalize_data, save_data_to_disk, verify_data -VALUE_FILTER = 'is_primary_data == True and assay in ["microwell-seq", "10x 3\' v1", "10x 3\' v2", "10x 3\' v3", "10x 3\' transcription profiling", "10x 5\' transcription profiling", "10x 5\' v1", "10x 5\' v2"]' +VALUE_FILTER = 'is_primary_data == True' VALID_SPECIES = ["homo_sapiens", "mus_musculus"] SPECIES_MAP = {"homo_sapiens": "human", "mus_musculus": "mouse"} -def process_chunk(census: tdb.Collection, species: str, ids: list[int], chunk_n: int, save_dir: os.PathLike): - adata = cellxgene_census.get_anndata(census, species, "RNA", "raw", obs_value_filter=VALUE_FILTER, obs_coords=ids) +def process_chunk(species: str, ids: list[int], chunk_n: int, save_dir: os.PathLike): + with cellxgene_census.open_soma(census_version="2024-07-01") as census: + adata = cellxgene_census.get_anndata(census, species, "RNA", "raw", obs_value_filter=VALUE_FILTER, obs_coords=ids) normalize_data(adata.X) permutation = np.random.permutation(range(adata.X.shape[0])) save_data_to_disk( @@ -23,7 +25,7 @@ def process_chunk(census: tdb.Collection, species: str, ids: list[int], chunk_n: metdata=adata.obs.iloc[permutation].reset_index(drop=True) ) -def 
main(directory: os.PathLike, species: str, chunk_size: int, seed: int, sample_size: int): +def main(directory: os.PathLike, species: str, chunk_size: int, processes: int, seed: int, sample_size: int): if species not in VALID_SPECIES: raise ValueError(f"Error: Invalid species provided - {species}. Valid values are: {VALID_SPECIES}") @@ -36,30 +38,38 @@ def main(directory: os.PathLike, species: str, chunk_size: int, seed: int, sampl soma_ids = census["census_data"][species].obs.read(value_filter=VALUE_FILTER, column_names=["soma_joinid"]).concat().to_pandas() soma_ids = list(soma_ids['soma_joinid']) - num_samples = len(soma_ids) - set_ids = set(soma_ids) - assert num_samples == len(set_ids) + total_data = len(soma_ids) + set_ids = set(soma_ids) + assert total_data == len(set_ids) + assert chunk_size <= total_data + if sample_size is not None: + assert sample_size <= total_data + assert chunk_size <= sample_size - np.random.seed(seed) - np.random.shuffle(soma_ids) + np.random.seed(seed) + np.random.shuffle(soma_ids) - if sample_size is not None: - soma_ids = np.random.choice(soma_ids, sample_size, replace=False) + if sample_size is not None: + soma_ids = np.random.choice(soma_ids, sample_size, replace=False) + mp.set_start_method('spawn') + with mp.Pool(processes= processes) as pool: chunk_count = 0 for i in range(0, len(soma_ids), chunk_size): chunk_count += 1 - process_chunk(census, species, soma_ids[i:i+chunk_size], chunk_count, directory) + pool.apply_async(process_chunk, args=(species, soma_ids[i:i+chunk_size], chunk_count, directory)) + pool.close() + pool.join() - verify_data(directory, SPECIES_MAP[species], set_ids, chunk_size, chunk_count, num_samples % chunk_size) + verify_data(directory, SPECIES_MAP[species], set_ids, chunk_size, chunk_count, total_data % chunk_size) - if __name__ == "__main__": parser = ap.ArgumentParser() parser.add_argument("--directory", type=str, required=True, help="Directory to save data in.") parser.add_argument("--species", type=str, required=True, help="Species to download data for.") parser.add_argument("--chunk_size", type=int, required=True, help="Number of samples to save in each chunk of data.") - parser.add_argument("--seed", type=int, required=False, default=42, help="Seed for the random module") + parser.add_argument("--processes", type=int, required=False, default=1, help="Number of sub-processes to use for the download.") + parser.add_argument("--seed", type=int, required=False, default=42, help="Seed for the random module.") parser.add_argument("--sample_size", type=int, required=False, default=None, help="Number of samples to grab out of total. 
Omit to get all data.") args = parser.parse_args() @@ -68,6 +78,7 @@ def main(directory: os.PathLike, species: str, chunk_size: int, seed: int, sampl directory= args.directory, species= args.species, chunk_size= args.chunk_size, + processes= args.processes, seed= args.seed, sample_size= args.sample_size ) \ No newline at end of file From d0a9b41453c7e14fad5a07d6fa124ccafe4ba7cc Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 7 Oct 2024 17:05:00 -0400 Subject: [PATCH 08/14] Updates stat functions to be per-category --- .../data-preprocessing/census_data_stats.py | 82 +++++++++++----- .../data_processing_functions.py | 96 +++++++++++++------ 2 files changed, 126 insertions(+), 52 deletions(-) diff --git a/scripts/data-preprocessing/census_data_stats.py b/scripts/data-preprocessing/census_data_stats.py index 8aa402a..bb7b712 100644 --- a/scripts/data-preprocessing/census_data_stats.py +++ b/scripts/data-preprocessing/census_data_stats.py @@ -1,53 +1,85 @@ import glob import os import argparse as ap +from collections import defaultdict + import pandas as pd import scipy.sparse as sp -from typing import Union -from collections import defaultdict from data_processing_functions import extract_file_number, gather_stats, record_stats, DATA_CATEGORIES -def update_totals(totals: dict[str, dict[str, int]], data: dict[str, Union[int, dict[str, int]]]): - - for category in DATA_CATEGORIES: - for key, value in data[category]: - totals[category][key] += value - def main(directory: os.PathLike, species: str): + """ + Main function for gathering stats on downloaded CellxGene Census data. + + The mean and standard deviation for each file are recorded to a csv. + The counts for each unique value for each type of data specified in + DATA_CATEGORIES is also recorded per-file alongside total counts. + + The is meant as a sanity check on randomly sampled subsets of the + full data and is not intended to gather stats on all available + CellxGene census data. - data_files = glob.glob(os.path.join(directory, f'{species}*.npz')) - metadata_files = glob.glob(os.path.join(directory, f'{species}*.pkl')) + Args: + directory (PathLike): + The directory holding the downloaded data. + species (str): + The name of the species whose data is being checked. This + is used to perform pattern matching in the directory. 
+ """ + data_files = glob.glob(os.path.join(directory, f"{species}*.npz")) + metadata_files = glob.glob(os.path.join(directory, f"{species}*.pkl")) + + # Sort by file number rather than filename data_files.sort(key=extract_file_number) metadata_files.sort(key=extract_file_number) - totals = { - cat: defaultdict(int) for cat in DATA_CATEGORIES + # Storage for data on all files + count_stats = defaultdict(dict) + total_stats = { + category: defaultdict(lambda: defaultdict(int)) + for category in DATA_CATEGORIES } + filenames = [] for data_path, metadata_path in zip(data_files, metadata_files): + + file = os.path.basename(data_path).split(".")[0] + filenames.append(file) + data = sp.load_npz(data_path) metadata = pd.read_pickle(metadata_path) - + chunk_stats = gather_stats(data, metadata) - filename = f'{os.path.basename(data_path).split(".")[0]}_stats.csv' - update_totals(totals, chunk_stats) + count_stats["mean"][file] = chunk_stats["mean"] + count_stats["std"][file] = chunk_stats["std"] - record_stats( - path= os.path.join(directory, filename), - data= chunk_stats - ) + # Update the total counts with the current files data + for category in DATA_CATEGORIES: + for value, count in chunk_stats[category].items(): + total_stats[category][value][file] = count - record_stats( - path=os.path.join(directory, f'{species}_stat_totals.csv'), - data=totals + features_df = pd.DataFrame(count_stats) + features_df.to_csv( + os.path.join(directory, + f"{species}_data_distribution.csv") ) -if __name__ == '__main__': + for category in DATA_CATEGORIES: + df = pd.DataFrame(data[category].T.fillna(0)) + df.columns = filenames + df["Total"] = df.sum(axis=1) + df.to_csv(os.path.join(directory, f"{species}_{category}_distribution.csv")) + +if __name__ == "__main__": parser = ap.ArgumentParser() - parser.add_argument("--directory", type=str, required=True, help="Directory to load data from.") - parser.add_argument("--species", type=str, required=True, help="Species to check data for.") + parser.add_argument( + "--directory", type=str, required=True, help="Directory to load data from." + ) + parser.add_argument( + "--species", type=str, required=True, help="Species to check data for." + ) args = parser.parse_args() main( diff --git a/scripts/data-preprocessing/data_processing_functions.py b/scripts/data-preprocessing/data_processing_functions.py index aaf34b4..b0b9726 100644 --- a/scripts/data-preprocessing/data_processing_functions.py +++ b/scripts/data-preprocessing/data_processing_functions.py @@ -2,15 +2,25 @@ import glob import os import re + import numpy as np import pandas as pd import scipy.sparse as sp -from typing import Union +DATA_CATEGORIES = ["assay", "cell_type", "tissue"] -DATA_CATEGORIES = ['assay', 'cell_type', 'tissue'] +def normalize_data( + data: sp.csr_matrix, +): + """ + Function to perform normalization on the RAW counts. Each row is summed, + and then every value in the row is multiplied by 1e4 before being divided + by the row sum. These values are passed to the standard log1p function. -def normalize_data(data: sp.csr_matrix): + Args: + data (scipy.sparse.csr_matrix): + The data to normalize. Normalization happens in-place! 
+ """ for row in range(data.shape[0]): start_idx = data.indptr[row] end_idx = data.indptr[row + 1] @@ -18,31 +28,33 @@ def normalize_data(data: sp.csr_matrix): row_data = data.getrow(row).data row_sum = row_data.sum() - # Apply the transformation and log1p data.data[start_idx:end_idx] = np.log1p((row_data * 1e4) / row_sum) - -def save_data_to_disk(data_path: os.PathLike, metadata_path: os.PathLike, data: sp.csr_matrix, metdata: pd.DataFrame): +def save_data_to_disk( + data_path: os.PathLike, + metadata_path: os.PathLike, + data: sp.csr_matrix, + metdata: pd.DataFrame, +): + """ + Helper function to save a pair of count data and its associated + metadata to disk at the provided paths. + """ sp.save_npz(data_path, data) metdata.to_pickle(metadata_path, compression=None) -def record_stats(path: os.PathLike, data: dict[str, Union[int, dict[str, int]]]): - - with open(path, mode='w', newline='') as file: - writer = csv.writer(file) - writer.writerow(['Stat', 'Value(s)']) - - for key, value in data.items(): - if isinstance(value, dict): - merged = ', '.join([f'{k} - {v}' for k, v in value.items()]) - else: - writer.writerow([key, value]) - def get_data_stats(data: sp.csr_matrix): + """ + Helper function to get the mean and std of each file. + """ row_sums = data.sum(axis=1).A.flatten() return row_sums.mean(), np.std(row_sums) def gather_stats(data: sp.csr_matrix, metadata: pd.DataFrame): + """ + Gathers counts of unique values for each metadata category specified + in DATA_CATEGORIES. Also gets the mean and std of the feature expressions. + """ chunk_stats = {} mean, std = get_data_stats(data) chunk_stats["mean"] = mean @@ -54,13 +66,43 @@ def gather_stats(data: sp.csr_matrix, metadata: pd.DataFrame): return chunk_stats def extract_file_number(filename): - match = re.search(r'_(\d+)', filename) # Look for digits after an underscore (_) - return int(match.group(1)) if match else 0 - -def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_size: int, last_chunk: int = None, last_size: int = None): - - data_files = glob.glob(os.path.join(directory, f'{species}*.npz')) - metadata_files = glob.glob(os.path.join(directory, f'{species}*.pkl')) + """ + Helper function to sort filenames by their file number. + """ + match = re.search(r"_(\d+)", filename) + return int(match.group(1)) if match else 0 + +def verify_data( + directory: os.PathLike, + species: str, + ids: set[int], + expected_size: int, + last_chunk: int = None, + last_size: int = None, +): + """ + Checks that all data 'chunks' are of the expected size and that they match + their associated metadata. It also checks that no duplicate data exists. + + Args: + directory (PathLike): + The directory holding the downloaded data. + species (str): + The name of the species whose data is being checked. This + is used to perform pattern matching in the directory. + ids (set[int]): + Set containing all expected soma_joinids. Used to ensure no + duplicate data is found in a chunk. + expected_size (int): + The expected number of samples in the data and metadata. + last_chunk (int, optional): + The final chunk to be checked. Used to signal that the last + chunk has a different sample size from the rest. + last_size (int, optional): + The expected number of samples in the last chunk. 
+ """ + data_files = glob.glob(os.path.join(directory, f"{species}*.npz")) + metadata_files = glob.glob(os.path.join(directory, f"{species}*.pkl")) data_files.sort(key=extract_file_number) metadata_files.sort(key=extract_file_number) @@ -78,7 +120,7 @@ def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_si print(f"Chunk size mismatch in chunk #{i}!!!!!") errors_detected = True - soma_ids = list(metadata['soma_joinid']) + soma_ids = list(metadata["soma_joinid"]) num_ids = len(soma_ids) if num_ids != data.shape[0]: print(f"Metadata mismatch in chunk #{i}!!!!!") @@ -86,7 +128,7 @@ def verify_data(directory: os.PathLike, species: str, ids: set[int], expected_si set_ids = set(soma_ids) if len(set_ids) != num_ids: - print(f'Duplicate IDs found in chunk #{i}!!!!!') + print(f"Duplicate IDs found in chunk #{i}!!!!!") errors_detected = True ids.intersection_update(set_ids) From 001da81e89b4f2b028fb066f7ad4776058449711 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 7 Oct 2024 17:05:22 -0400 Subject: [PATCH 09/14] Adds documentation --- .../census_data_download.py | 145 ++++++++++++++++-- 1 file changed, 129 insertions(+), 16 deletions(-) diff --git a/scripts/data-preprocessing/census_data_download.py b/scripts/data-preprocessing/census_data_download.py index 3b96fb3..4c6a4c4 100644 --- a/scripts/data-preprocessing/census_data_download.py +++ b/scripts/data-preprocessing/census_data_download.py @@ -9,13 +9,46 @@ from data_processing_functions import normalize_data, save_data_to_disk, verify_data -VALUE_FILTER = 'is_primary_data == True' +CENSUS_VERSION = "2024-07-01" +VALUE_FILTER = 'is_primary_data == True' # Metadata filters applied when retrieving data VALID_SPECIES = ["homo_sapiens", "mus_musculus"] SPECIES_MAP = {"homo_sapiens": "human", "mus_musculus": "mouse"} -def process_chunk(species: str, ids: list[int], chunk_n: int, save_dir: os.PathLike): - with cellxgene_census.open_soma(census_version="2024-07-01") as census: - adata = cellxgene_census.get_anndata(census, species, "RNA", "raw", obs_value_filter=VALUE_FILTER, obs_coords=ids) +def process_chunk( + save_dir: os.PathLike, + species: str, + ids: list[int], + chunk_n: int, +): + """ + Helper function to download and save a chunk of data. + + This is the spawn point of the processes when MultiProcessing + is used. + + Args: + save_dir (PathLike): + Directory to save the data to. + species (str): + The name of the species whose data is being downloaded. + NOTE: The taxonomic name is expected and is later mapped + to the English name when saving the data. + ids (list[int]): + The soma_joinids to query the data that is to be downloaded. + chunk_n (int): + Current chunk being processes. Appended to the filename when + its saved to disk. 
+ """ + with cellxgene_census.open_soma(census_version=CENSUS_VERSION) as census: + adata = cellxgene_census.get_anndata( + census= census, + organism= species, + measurement_name= "RNA", + X_name= "raw", + obs_value_filter=VALUE_FILTER, + obs_coords=ids + ) + normalize_data(adata.X) permutation = np.random.permutation(range(adata.X.shape[0])) save_data_to_disk( @@ -25,17 +58,54 @@ def process_chunk(species: str, ids: list[int], chunk_n: int, save_dir: os.PathL metdata=adata.obs.iloc[permutation].reset_index(drop=True) ) -def main(directory: os.PathLike, species: str, chunk_size: int, processes: int, seed: int, sample_size: int): +def main( + directory: os.PathLike, + species: str, + chunk_size: int, + processes: int, + seed: int, + sample_size: int, +): + """ + Main downloading function for CellxGene census data. + Can be used to download all available data that matches the provided + filters or a random subset. + + Args: + directory (PathLike): + The directory to save the downloaded data in. + species (str): + The name of the species whose data is being downloaded. + NOTE: The taxonomic name is expected and is later mapped + to the English name when saving the data. + chunk_size (int): + Number of samples to retrieve and save to disk in a single + file. Breaks up the total available data into 'chunks'. + processes (int, optional): + Number of download processes to run at a time. Helps speed + up the download but uses more RAM. Defaults to 1. + seed (int, optional): + Seed for NumPy's Random module. Set for reproducibility. + Defaults to 42. + sample_size (int, optional): + Number of samples to randomly grab out of the total available + data. Defaults to None, and all available data is retrieved. + """ if species not in VALID_SPECIES: - raise ValueError(f"Error: Invalid species provided - {species}. Valid values are: {VALID_SPECIES}") + raise ValueError( + f"Error: Invalid species provided - {species}. 
Valid values are: {VALID_SPECIES}" + ) if not os.path.exists(directory): raise FileExistsError("Error: Provided directory does not exist!") - with cellxgene_census.open_soma(census_version="2024-07-01") as census: + with cellxgene_census.open_soma(census_version=CENSUS_VERSION) as census: - soma_ids = census["census_data"][species].obs.read(value_filter=VALUE_FILTER, column_names=["soma_joinid"]).concat().to_pandas() + soma_ids = census["census_data"][species].obs.read( + value_filter=VALUE_FILTER, + column_names=["soma_joinid"] + ).concat().to_pandas() soma_ids = list(soma_ids['soma_joinid']) total_data = len(soma_ids) @@ -57,20 +127,63 @@ def main(directory: os.PathLike, species: str, chunk_size: int, processes: int, chunk_count = 0 for i in range(0, len(soma_ids), chunk_size): chunk_count += 1 - pool.apply_async(process_chunk, args=(species, soma_ids[i:i+chunk_size], chunk_count, directory)) + pool.apply_async( + func= process_chunk, + args=(directory, species, soma_ids[i:i+chunk_size], chunk_count) + ) pool.close() pool.join() - verify_data(directory, SPECIES_MAP[species], set_ids, chunk_size, chunk_count, total_data % chunk_size) + verify_data( + directory= directory, + species= SPECIES_MAP[species], + ids= set_ids, + expected_size= chunk_size, + last_chunk= chunk_count, + last_size= total_data % chunk_size + ) if __name__ == "__main__": parser = ap.ArgumentParser() - parser.add_argument("--directory", type=str, required=True, help="Directory to save data in.") - parser.add_argument("--species", type=str, required=True, help="Species to download data for.") - parser.add_argument("--chunk_size", type=int, required=True, help="Number of samples to save in each chunk of data.") - parser.add_argument("--processes", type=int, required=False, default=1, help="Number of sub-processes to use for the download.") - parser.add_argument("--seed", type=int, required=False, default=42, help="Seed for the random module.") - parser.add_argument("--sample_size", type=int, required=False, default=None, help="Number of samples to grab out of total. Omit to get all data.") + parser.add_argument( + "--directory", + type=str, + required=True, + help="Directory to save data in." + ) + parser.add_argument( + "--species", + type=str, + required=True, + help="Species to download data for." + ) + parser.add_argument( + "--chunk_size", + type=int, + required=True, + help="Number of samples to save in each chunk of data." + ) + parser.add_argument( + "--processes", + type=int, + required=False, + default=1, + help="Number of sub-processes to use for the download." + ) + parser.add_argument( + "--seed", + type=int, + required=False, + default=42, + help="Seed for the random module." + ) + parser.add_argument( + "--sample_size", + type=int, + required=False, + default=None, + help="Number of samples to grab out of total. Omit to get all data." 
+ ) args = parser.parse_args() From df68c42fb58c9f11c7ddb37e1347d3dafd24eebb Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 7 Oct 2024 17:49:31 -0400 Subject: [PATCH 10/14] Fixes bug with transposing the metadata counts dataframe --- scripts/data-preprocessing/census_data_stats.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/scripts/data-preprocessing/census_data_stats.py b/scripts/data-preprocessing/census_data_stats.py index bb7b712..720870a 100644 --- a/scripts/data-preprocessing/census_data_stats.py +++ b/scripts/data-preprocessing/census_data_stats.py @@ -6,7 +6,7 @@ import pandas as pd import scipy.sparse as sp -from data_processing_functions import extract_file_number, gather_stats, record_stats, DATA_CATEGORIES +from data_processing_functions import extract_file_number, gather_stats, DATA_CATEGORIES def main(directory: os.PathLike, species: str): """ @@ -61,16 +61,20 @@ def main(directory: os.PathLike, species: str): total_stats[category][value][file] = count features_df = pd.DataFrame(count_stats) + features_df.columns = ["mean", "std"] features_df.to_csv( - os.path.join(directory, - f"{species}_data_distribution.csv") + os.path.join(directory, f"{species}_data_distribution.csv"), + index_label="file" ) for category in DATA_CATEGORIES: - df = pd.DataFrame(data[category].T.fillna(0)) + df = pd.DataFrame(total_stats[category]).T.fillna(0) df.columns = filenames df["Total"] = df.sum(axis=1) - df.to_csv(os.path.join(directory, f"{species}_{category}_distribution.csv")) + df.to_csv( + os.path.join(directory, f"{species}_{category}_distribution.csv"), + index_label=category + ) if __name__ == "__main__": parser = ap.ArgumentParser() From 029791247f2a71cfe30764fc162a183cfda20355 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Thu, 10 Oct 2024 09:45:42 -0400 Subject: [PATCH 11/14] Updates default configs to match the new census data release --- configs/data/local.yaml | 124 +++++++++++++++++++------------------- configs/model/config.yaml | 8 +-- 2 files changed, 66 insertions(+), 66 deletions(-) diff --git a/configs/data/local.yaml b/configs/data/local.yaml index f825a88..205dc2c 100644 --- a/configs/data/local.yaml +++ b/configs/data/local.yaml @@ -6,78 +6,78 @@ init_args: name: human return_dense: true batch_size: 128 - directory_path: /mnt/projects/debruinz_project/summer_census_data/3m_subset + directory_path: /mnt/projects/debruinz_project/july2024_census_data/subset train_npz_masks: - - 3m_human_counts_1.npz - - 3m_human_counts_2.npz - - 3m_human_counts_3.npz - - 3m_human_counts_4.npz - - 3m_human_counts_5.npz - - 3m_human_counts_6.npz - - 3m_human_counts_7.npz - - 3m_human_counts_8.npz - - 3m_human_counts_9.npz - - 3m_human_counts_10.npz - - 3m_human_counts_11.npz - - 3m_human_counts_12.npz - - 3m_human_counts_13.npz + - human_counts_1.npz + - human_counts_2.npz + - human_counts_3.npz + - human_counts_4.npz + - human_counts_5.npz + - human_counts_6.npz + - human_counts_7.npz + - human_counts_8.npz + - human_counts_9.npz + - human_counts_10.npz + - human_counts_11.npz + - human_counts_12.npz + - human_counts_13.npz train_metadata_masks: - - 3m_human_metadata_1.pkl - - 3m_human_metadata_2.pkl - - 3m_human_metadata_3.pkl - - 3m_human_metadata_4.pkl - - 3m_human_metadata_5.pkl - - 3m_human_metadata_6.pkl - - 3m_human_metadata_7.pkl - - 3m_human_metadata_8.pkl - - 3m_human_metadata_9.pkl - - 3m_human_metadata_10.pkl - - 3m_human_metadata_11.pkl - - 3m_human_metadata_12.pkl - - 3m_human_metadata_13.pkl - val_npz_masks: 3m_human_counts_14.npz - 
val_metadata_masks: 3m_human_metadata_14.pkl - test_npz_masks: 3m_human_counts_15.npz - test_metadata_masks: 3m_human_metadata_15.pkl + - human_metadata_1.pkl + - human_metadata_2.pkl + - human_metadata_3.pkl + - human_metadata_4.pkl + - human_metadata_5.pkl + - human_metadata_6.pkl + - human_metadata_7.pkl + - human_metadata_8.pkl + - human_metadata_9.pkl + - human_metadata_10.pkl + - human_metadata_11.pkl + - human_metadata_12.pkl + - human_metadata_13.pkl + val_npz_masks: human_counts_14.npz + val_metadata_masks: human_metadata_14.pkl + test_npz_masks: human_counts_15.npz + test_metadata_masks: human_metadata_15.pkl verbose: false - class_path: cmmvae.data.local.SpeciesManager init_args: name: mouse return_dense: true batch_size: 128 - directory_path: /mnt/projects/debruinz_project/summer_census_data/3m_subset + directory_path: /mnt/projects/debruinz_project/july2024_census_data/subset train_npz_masks: - - 3m_mouse_counts_1.npz - - 3m_mouse_counts_2.npz - - 3m_mouse_counts_3.npz - - 3m_mouse_counts_4.npz - - 3m_mouse_counts_5.npz - - 3m_mouse_counts_6.npz - - 3m_mouse_counts_7.npz - - 3m_mouse_counts_8.npz - - 3m_mouse_counts_9.npz - - 3m_mouse_counts_10.npz - - 3m_mouse_counts_11.npz - - 3m_mouse_counts_12.npz - - 3m_mouse_counts_13.npz + - mouse_counts_1.npz + - mouse_counts_2.npz + - mouse_counts_3.npz + - mouse_counts_4.npz + - mouse_counts_5.npz + - mouse_counts_6.npz + - mouse_counts_7.npz + - mouse_counts_8.npz + - mouse_counts_9.npz + - mouse_counts_10.npz + - mouse_counts_11.npz + - mouse_counts_12.npz + - mouse_counts_13.npz train_metadata_masks: - - 3m_mouse_metadata_1.pkl - - 3m_mouse_metadata_2.pkl - - 3m_mouse_metadata_3.pkl - - 3m_mouse_metadata_4.pkl - - 3m_mouse_metadata_5.pkl - - 3m_mouse_metadata_6.pkl - - 3m_mouse_metadata_7.pkl - - 3m_mouse_metadata_8.pkl - - 3m_mouse_metadata_9.pkl - - 3m_mouse_metadata_10.pkl - - 3m_mouse_metadata_11.pkl - - 3m_mouse_metadata_12.pkl - - 3m_mouse_metadata_13.pkl - val_npz_masks: 3m_mouse_counts_14.npz - val_metadata_masks: 3m_mouse_metadata_14.pkl - test_npz_masks: 3m_mouse_counts_15.npz - test_metadata_masks: 3m_mouse_metadata_15.pkl + - mouse_metadata_1.pkl + - mouse_metadata_2.pkl + - mouse_metadata_3.pkl + - mouse_metadata_4.pkl + - mouse_metadata_5.pkl + - mouse_metadata_6.pkl + - mouse_metadata_7.pkl + - mouse_metadata_8.pkl + - mouse_metadata_9.pkl + - mouse_metadata_10.pkl + - mouse_metadata_11.pkl + - mouse_metadata_12.pkl + - mouse_metadata_13.pkl + val_npz_masks: mouse_counts_14.npz + val_metadata_masks: mouse_metadata_14.pkl + test_npz_masks: mouse_counts_15.npz + test_metadata_masks: mouse_metadata_15.pkl verbose: false num_workers: 2 n_val_workers: 1 diff --git a/configs/model/config.yaml b/configs/model/config.yaml index 0049ffb..4c47fce 100644 --- a/configs/model/config.yaml +++ b/configs/model/config.yaml @@ -53,7 +53,7 @@ init_args: encoder_config: class_path: cmmvae.modules.base.FCBlockConfig init_args: - layers: [ 60664, 1024, 512 ] + layers: [ 60530, 1024, 512 ] dropout_rate: [ 0.1, 0.1 ] use_batch_norm: True use_layer_norm: False @@ -61,7 +61,7 @@ init_args: decoder_config: class_path: cmmvae.modules.base.FCBlockConfig init_args: - layers: [ 512, 1024, 60664 ] + layers: [ 512, 1024, 60530 ] dropout_rate: 0.0 use_batch_norm: False use_layer_norm: False @@ -72,7 +72,7 @@ init_args: encoder_config: class_path: cmmvae.modules.base.FCBlockConfig init_args: - layers: [ 52417, 1024, 512 ] + layers: [ 52437, 1024, 512 ] dropout_rate: [ 0.1, 0.1 ] use_batch_norm: True use_layer_norm: False @@ -80,7 +80,7 @@ init_args: 
decoder_config: class_path: cmmvae.modules.base.FCBlockConfig init_args: - layers: [ 512, 1024, 52417 ] + layers: [ 512, 1024, 52437 ] dropout_rate: 0.0 use_batch_norm: False use_layer_norm: False From 42bb1bc1b03cdfa2350bd2e03c4dfb6309a63d98 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Mon, 21 Oct 2024 13:48:44 -0400 Subject: [PATCH 12/14] Inital correlations setup for snakemake --- Snakefile | 60 +++++++++++++++++++++++++++++++++----------- workflow/config.yaml | 2 ++ 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/Snakefile b/Snakefile index 93164e3..f60edba 100644 --- a/Snakefile +++ b/Snakefile @@ -57,6 +57,12 @@ META_DISC_DIR = os.path.join(RUN_DIR, META_DISC_PATH) ## Define a separate directory for merged outputs to avoid conflicts between different merge operations. MERGED_DIR = os.path.join(RUN_DIR, "merged") +## Define the directory to store correlation outputs +CORRELATION_PATH = config.get("correlation_dir", "correlations") +CORRELATION_DIR = os.path.join(RUN_DIR, CORRELATION_PATH) + +CORRELATION_DATA = config["correlations_data"] + ## Generate the paths for the embeddings and metadata files based on the merge keys. EMBEDDINGS_PATHS = [os.path.join(MERGED_DIR, f"{key}_embeddings.npz") for key in MERGE_KEYS] METADATA_PATHS = [os.path.join(MERGED_DIR, f"{key}_metadata.pkl") for key in MERGE_KEYS] @@ -100,13 +106,23 @@ TRAIN_COMMAND += str( f"--predict_dir {PREDICT_SUBDIR} " ) +CORRELATION_COMMAND = str( + f"{TRAIN_COMMAND} " + f"--ckpt_path {CKPT_PATH} " +) + +CORRELATION_FILES = expand( + "{correlation_dir}/correlations.csv", + correlation_dir=CORRELATION_DIR, +) + ## Define the final output rule for Snakemake, specifying the target files that should be generated ## by the end of the workflow. rule all: input: EVALUATION_FILES, - MD_FILES + # MD_FILES ## Define the rule for training the CMMVAE model. ## The output includes the configuration file, the checkpoint path, and the directory for predictions. @@ -138,6 +154,22 @@ rule merge_predictions: cmmvae workflow merge-predictions --directory {input.predict_dir} --keys {params.merge_keys} --save_dir {MERGED_DIR} """ +## Define the rule for getting R^2 correlations on the filtered data +## This rule outputs correlation scores per filtered data group +rule correlations: + input: + embeddings_path=EMBEDDINGS_PATHS, + output: + CORRELATION_FILES, + params: + command=CORRELATION_COMMAND, + data=CORRELATION_DATA, + shell: + """ + mkdir -p {CORRELATION_DIR} + cmmvae workflow correlations {params.command} --correlation_data {params.data} + """ + ## Define the rule for generating UMAP visualizations from the merged predictions. ## This rule produces UMAP images for each combination of category and merge key. 
rule umap_predictions: @@ -156,16 +188,16 @@ rule umap_predictions: cmmvae workflow umap-predictions --directory {params.predict_dir} {params.categories} {params.merge_keys} --save_dir {params.save_dir} """ -rule meta_discriminators: - input: - CKPT_PATH - output: - MD_FILES, - params: - log_dir=META_DISC_DIR, - ckpt=CKPT_PATH, - config=TRAIN_CONFIG_FILE - shell: - """ - cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config} - """ \ No newline at end of file +# rule meta_discriminators: +# input: +# CKPT_PATH +# output: +# MD_FILES, +# params: +# log_dir=META_DISC_DIR, +# ckpt=CKPT_PATH, +# config=TRAIN_CONFIG_FILE +# shell: +# """ +# cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config} +# """ \ No newline at end of file diff --git a/workflow/config.yaml b/workflow/config.yaml index 48f79ba..0d6a90a 100644 --- a/workflow/config.yaml +++ b/workflow/config.yaml @@ -17,3 +17,5 @@ categories: merge_keys: - z + +correlations_data: /mnt/projects/debruinz_project/july2024_census_data/filtered \ No newline at end of file From 8eb744dcf98693d77e7294119e1d9b065de804ad Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Tue, 22 Oct 2024 16:41:04 -0400 Subject: [PATCH 13/14] Adds filtering scripts to organize data for R^2 correlation calculations --- .../census_data_filtering.py | 121 ++++++++++++++++++ .../data_filtering_functions.py | 98 ++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 scripts/data-preprocessing/census_data_filtering.py create mode 100644 scripts/data-preprocessing/data_filtering_functions.py diff --git a/scripts/data-preprocessing/census_data_filtering.py b/scripts/data-preprocessing/census_data_filtering.py new file mode 100644 index 0000000..5bb9475 --- /dev/null +++ b/scripts/data-preprocessing/census_data_filtering.py @@ -0,0 +1,121 @@ +import os +import re +import glob +import argparse as ap +import multiprocessing as mp +import numpy as np +import pandas as pd +import scipy.sparse as sp +import data_filtering_functions as ff +from collections import defaultdict +from data_processing_functions import extract_file_number + +def main( + directory: str, + species: list[str], + train_dir: str, + save_dir: str, + seed: int, + skip_metadata: bool, +): + if not skip_metadata: + np.random.seed(seed) + dataframes = {} + for specie in species: + # Process file paths for full and training data + train_files = tuple(glob.glob(os.path.join(train_dir, f"{specie}*.pkl"))) + train_files = ff.filter_and_sort_train_files(train_files) + + # Get somaIDs of training samples + train_ids = ff.get_train_data_ids(train_files) + + # Prepare main df that is to be filtered + metadata_files = glob.glob(os.path.join(directory, f"{specie}*.pkl")) + metadata_files.sort(key=extract_file_number) + meta_dataframe = ff.load_and_merge_metadata(tuple(metadata_files)) + + # Filter out training samples from the full data + dataframes[specie] = ff.filter_train_ids(meta_dataframe, train_ids) + + # Process groups and subset size if needed + grouped_data = ff.filter_into_groups(dataframes) + valid_groups = ff.validate_and_sample_groups(grouped_data, species[0]) + + # Save filtered metadata + ff.save_grouped_data(valid_groups, dataframes, save_dir) + + # Load data and slice by metadata index + for specie in species: + data_files = glob.glob(os.path.join(directory, f"{specie}*.npz")) + data_files.sort(key=extract_file_number) + metadata_files = glob.glob(os.path.join(save_dir, f"{specie}*.pkl")) + 
metadata_files.sort(key=extract_file_number) + filtered_data = defaultdict(list) + for chunk_n, data_file in enumerate(data_files, start=1): + current_chunk = sp.load_npz(data_file) + for gid, metadata_file in enumerate(metadata_files, start=1): + current_df = pd.read_pickle(metadata_file) + idxes = current_df[current_df["chunk_source"] == chunk_n]["data_index"] + sliced_chunk = current_chunk[idxes, :] + filtered_data[gid].append(sliced_chunk) + + # Save filtered data + for gid, data in filtered_data.items(): + chunk_data = sp.vstack(data) + sp.save_npz(os.path.join(save_dir, f"{specie}_filtered_{gid}.npz"), chunk_data) + +if __name__ == "__main__": + parser = ap.ArgumentParser() + parser.add_argument( + "--directory", + type=str, + required=True, + help="Directory to load data from." + ) + parser.add_argument( + "--species", + type=str, + nargs="+", + required=True, + help="Species to load data for." + ) + parser.add_argument( + "--train_data", + type=str, + required=False, + default=None, + help="Directory where the training data is stored. Defaults to '--directory'" + ) + parser.add_argument( + "--save_directory", + type=str, + required=False, + default=None, + help="Directory to save filtered data in. Defaults to '--directory'" + ) + parser.add_argument( + "--seed", + type=int, + required=False, + default=42, + help="Seed for the random module." + ) + parser.add_argument( + "--skip_metadata", + action="store_true", + help="Whether to skip the metadata filtering and just slice the count data." + ) + + args = parser.parse_args() + + train_dir = args.directory if args.train_data is None else args.train_data + save_dir = args.directory if args.save_directory is None else args.save_directory + + main( + directory= args.directory, + species= args.species, + train_dir= train_dir, + save_dir= save_dir, + seed= args.seed, + skip_metadata= args.skip_metadata + ) \ No newline at end of file diff --git a/scripts/data-preprocessing/data_filtering_functions.py b/scripts/data-preprocessing/data_filtering_functions.py new file mode 100644 index 0000000..3fa2141 --- /dev/null +++ b/scripts/data-preprocessing/data_filtering_functions.py @@ -0,0 +1,98 @@ +import os +import re +import csv +import numpy as np +import pandas as pd +import pandas.core.groupby.generic as gb +from collections import defaultdict +from data_processing_functions import extract_file_number + +TRAIN_DATA_FILES = set(range(1, 14)) +GROUPING_COLUMNS = ["sex", "tissue", "cell_type", "assay"] +MIN_SAMPLE_SIZE = 100 +MAX_SAMPLE_SIZE = 10000 +SPECIES_SAMPLE_SIZE = {"human": 60530, "mouse": 52437} + +def get_train_data_ids(files: tuple[str]) -> set[int]: + + combined_df = pd.DataFrame() + + for file in files: + df = pd.read_pickle(file) + combined_df = pd.concat([combined_df, df], ignore_index=True) + + return set(combined_df["soma_joinid"]) + +def filter_train_ids(df: pd.DataFrame, ids: set[int]) -> pd.DataFrame: + filtered_df = df[~(df["soma_joinid"].isin(ids))] + filtered_df = filtered_df.reset_index(drop=True) + return filtered_df + +def filter_and_sort_train_files(unfiltered_files: tuple[str]) -> tuple[str]: + + filtered_files = [ + file for file in unfiltered_files + if (match := re.search(r'(\d+)\.pkl$', file)) and int(match.group(1)) in TRAIN_DATA_FILES + ] + filtered_files.sort(key=extract_file_number) + return tuple(filtered_files) + +def load_and_merge_metadata(files: tuple[str]) -> pd.DataFrame: + + merged_df = pd.DataFrame() + for file in files: + df = pd.read_pickle(file) + df["chunk_source"] = extract_file_number(file) + df 
= df.reset_index(drop=False) + df.rename(columns={"index": "data_index"}, inplace=True) + merged_df = pd.concat([merged_df, df], ignore_index=True) + + return merged_df + +def filter_into_groups(dfs: dict[str, pd.DataFrame]): + + grouped = {} + for specie, data in dfs.items(): + grouped[specie] = data.groupby(GROUPING_COLUMNS) + + return grouped + +def validate_and_sample_groups(data_groups: dict[str, gb.DataFrameGroupBy], primary_species: str = None): + + valid_groups = defaultdict(dict) + if primary_species is not None: + main_df = data_groups.pop(primary_species) + else: + primary_species, main_df = data_groups.popitem() + + for gid, idxes in main_df.groups.items(): + if len(idxes) < MIN_SAMPLE_SIZE: + continue + elif all( + gid in group.groups.keys() and len(group.groups[gid]) >= MIN_SAMPLE_SIZE + for group in data_groups.values() + ): + sample_size = min( + [len(idxes), MAX_SAMPLE_SIZE] + [len(group.groups[gid]) for group in data_groups.values()] + ) + + valid_groups[gid][primary_species] = np.random.choice(idxes, sample_size, replace= False) + for specie, group in data_groups.items(): + valid_groups[gid][specie] = np.random.choice(group.groups[gid], sample_size, replace= False) + + return valid_groups + +def save_grouped_data(groups: dict[tuple[str], dict[str, np.ndarray]], dfs: dict[str, pd.DataFrame], save_dir: str): + + with open(os.path.join(save_dir, "group_references.csv"), "w") as file: + writer = csv.writer(file) + writer.writerow(["group_id", "num_samples"] + GROUPING_COLUMNS) + for i, gid in enumerate(groups.keys(), start=1): + for specie, idx in groups[gid].items(): + df = dfs[specie].iloc[idx] + df["group_id"] = i + df = df.sort_values("chunk_source") + df.to_pickle(os.path.join(save_dir, f"{specie}_filtered_{i}.pkl")) + writer.writerow([i, len(idx)] + list(gid)) + file = pd.read_csv(os.path.join(save_dir, "group_references.csv")) + file.to_pickle(os.path.join(save_dir, "group_references.pkl")) From 3dae57152f37a2cbb3a5d730dec0a6a63b28c873 Mon Sep 17 00:00:00 2001 From: Tony Boos Date: Wed, 23 Oct 2024 16:38:56 -0400 Subject: [PATCH 14/14] Adds pearson correlation functions to SnakeMake and updates the data filtering functions --- Snakefile | 18 +- .../data_filtering_functions.py | 1 + src/cmmvae/data/local/__init__.py | 2 + src/cmmvae/data/local/cellxgene_datapipe.py | 6 +- src/cmmvae/runners/3d_umap.py | 114 ++++++++++++ src/cmmvae/runners/cli.py | 4 +- src/cmmvae/runners/correlations.py | 168 ++++++++++++++++++ src/cmmvae/runners/workflow.py | 2 + workflow/config.yaml | 5 +- workflow/profile/slurm/config.yaml | 5 + 10 files changed, 316 insertions(+), 9 deletions(-) create mode 100644 src/cmmvae/runners/3d_umap.py create mode 100644 src/cmmvae/runners/correlations.py diff --git a/Snakefile b/Snakefile index f60edba..1419a87 100644 --- a/Snakefile +++ b/Snakefile @@ -61,7 +61,7 @@ MERGED_DIR = os.path.join(RUN_DIR, "merged") CORRELATION_PATH = config.get("correlation_dir", "correlations") CORRELATION_DIR = os.path.join(RUN_DIR, CORRELATION_PATH) -CORRELATION_DATA = config["correlations_data"] +CORRELATION_DATA = config["correlation_data"] ## Generate the paths for the embeddings and metadata files based on the merge keys. EMBEDDINGS_PATHS = [os.path.join(MERGED_DIR, f"{key}_embeddings.npz") for key in MERGE_KEYS] @@ -98,6 +98,7 @@ MD_FILES = expand( ## If a configuration directory is provided, it is included in the command; otherwise, ## individual parameters such as trainer, model, and data are passed explicitly. 
TRAIN_COMMAND = config["train_command"] +CORRELATION_COMMAND = config["correlation_command"] TRAIN_COMMAND += str( f" --default_root_dir {ROOT_DIR} " @@ -106,8 +107,11 @@ TRAIN_COMMAND += str( f"--predict_dir {PREDICT_SUBDIR} " ) -CORRELATION_COMMAND = str( - f"{TRAIN_COMMAND} " +CORRELATION_COMMAND += str( + f" --default_root_dir {ROOT_DIR} " + f"--experiment_name {EXPERIMENT_NAME} --run_name {RUN_NAME} " + f"--seed_everything {SEED} " + f"--predict_dir {PREDICT_SUBDIR} " f"--ckpt_path {CKPT_PATH} " ) @@ -116,12 +120,17 @@ CORRELATION_FILES = expand( correlation_dir=CORRELATION_DIR, ) +CORRELATION_FILES += expand( + "{correlation_dir}/correlations.pkl", + correlation_dir=CORRELATION_DIR, +) ## Define the final output rule for Snakemake, specifying the target files that should be generated ## by the end of the workflow. rule all: input: EVALUATION_FILES, + CORRELATION_FILES # MD_FILES ## Define the rule for training the CMMVAE model. @@ -164,10 +173,11 @@ rule correlations: params: command=CORRELATION_COMMAND, data=CORRELATION_DATA, + save_dir=CORRELATION_DIR, shell: """ mkdir -p {CORRELATION_DIR} - cmmvae workflow correlations {params.command} --correlation_data {params.data} + cmmvae workflow correlations {params.command} --correlation_data {params.data} --save_dir {params.save_dir} """ ## Define the rule for generating UMAP visualizations from the merged predictions. diff --git a/scripts/data-preprocessing/data_filtering_functions.py b/scripts/data-preprocessing/data_filtering_functions.py index 3fa2141..6b4e03c 100644 --- a/scripts/data-preprocessing/data_filtering_functions.py +++ b/scripts/data-preprocessing/data_filtering_functions.py @@ -91,6 +91,7 @@ def save_grouped_data(groups: dict[tuple[str], dict[str, np.ndarray]], dfs: dict for specie, idx in groups[gid].items(): df = dfs[specie].iloc[idx] df["group_id"] = i + df["num_samples"] = len(idx) df = df.sort_values("chunk_source") df.to_pickle(os.path.join(save_dir, f"{specie}_filtered_{i}.pkl")) writer.writerow([i, len(idx)] + list(gid)) diff --git a/src/cmmvae/data/local/__init__.py b/src/cmmvae/data/local/__init__.py index d9ac7b7..691e9e8 100644 --- a/src/cmmvae/data/local/__init__.py +++ b/src/cmmvae/data/local/__init__.py @@ -5,9 +5,11 @@ from cmmvae.data.local.cellxgene_datamodule import SpeciesDataModule from cmmvae.data.local.cellxgene_manager import SpeciesManager +from cmmvae.data.local.cellxgene_datapipe import SpeciesDataPipe __all__ = [ "SpeciesManager", "SpeciesDataModule", + "SpeciesDataPipe", ] diff --git a/src/cmmvae/data/local/cellxgene_datapipe.py b/src/cmmvae/data/local/cellxgene_datapipe.py index 0246d82..76dff65 100644 --- a/src/cmmvae/data/local/cellxgene_datapipe.py +++ b/src/cmmvae/data/local/cellxgene_datapipe.py @@ -172,7 +172,7 @@ def __iter__(self): for i in range(0, n_samples, self.batch_size): data_batch = sparse_matrix[i : i + self.batch_size] - if self.allow_partials and not data_batch.shape[0] == self.batch_size: + if data_batch.shape[0] != self.batch_size and not self.allow_partials: continue tensor = torch.sparse_csr_tensor( @@ -252,6 +252,7 @@ def __init__( npz_masks: Union[str, list[str]], metadata_masks: Union[str, list[str]], batch_size: int, + allow_partials = False, shuffle: bool = True, return_dense: bool = False, verbose: bool = False, @@ -305,6 +306,7 @@ def __init__( print(path) self.batch_size = batch_size + self.allow_partials = allow_partials self.return_dense = return_dense self.verbose = verbose self._shuffle = shuffle @@ -324,7 +326,7 @@ def __init__( dp = 
dp.shuffle_matrix_and_dataframe() dp = dp.batch_csr_matrix_and_dataframe( - self.batch_size, return_dense=self.return_dense + self.batch_size, return_dense=self.return_dense, allow_partials=self.allow_partials ) # thought process on removal diff --git a/src/cmmvae/runners/3d_umap.py b/src/cmmvae/runners/3d_umap.py new file mode 100644 index 0000000..4736929 --- /dev/null +++ b/src/cmmvae/runners/3d_umap.py @@ -0,0 +1,114 @@ +import os +import numpy as np +import pandas as pd +import pickle +import umap +import matplotlib.pyplot as plt +import seaborn as sns +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.animation import FuncAnimation + +def load_embeddings(npz_path, meta_path): + """Load embeddings and metadata from specified paths.""" + embedding = np.load(npz_path)["embeddings"] + metadata = pd.read_pickle(meta_path) + return embedding, metadata + +def plot_3d_umap(embedding, metadata, column, output_path, fig_size=(10, 10), point_size=2): + fig = plt.figure(figsize=fig_size) + ax = fig.add_subplot(111, projection='3d') + scatter = ax.scatter( + embedding[:, 0], embedding[:, 1], embedding[:, 2], + c=metadata[column].astype('category').cat.codes, + cmap='hsv', + s=point_size + ) + ax.set_title(f'3D UMAP projection colored by {column}\nFilename: {output_path.split("/")[-1]}') + ax.set_xlabel('UMAP 1') + ax.set_ylabel('UMAP 2') + ax.set_zlabel('UMAP 3') + + def update(num): + ax.view_init(elev=30, azim=num) + + ani = FuncAnimation(fig, update, frames=360, interval=20) + ani.save(output_path.replace('.png', '.gif'), writer='imagemagick') + + plt.close() + +def plot_umap( + X, + metadata, + n_neighbors=30, + min_dist=0.3, + n_components=3, + metric="cosine", + low_memory=False, + n_jobs=40, + n_epochs=200, + n_largest=15, + method=None, + save_dir=None, + **umap_kwargs, +): + """ + Generate UMAP embeddings and plot them. + + Args: + directory (str): Directory where embeddings are stored. + keys (list[str]): List of embedding keys. + categories (list[str]): List of categories to color by. + n_neighbors (int): Number of neighbors for UMAP. + min_dist (float): Minimum distance for UMAP. + n_components (int): Number of components for UMAP. + metric (str): Metric for UMAP. + low_memory (bool): Low memory setting for UMAP. + n_jobs (int): Number of CPUs available for UMAP. + n_epochs (int): Number of epochs to run UMAP. + n_largest (int): Number of most common categories to plot. + method (str): Method title to add to the graph. + save_dir (str): Directory to save UMAP plots. + **umap_kwargs: Extra kwargs passed to `umap.UMAP`. 
+ """ + import umap + # Fit and transform the data using UMAP + reducer = umap.UMAP( + n_neighbors=n_neighbors, + min_dist=min_dist, + n_components=n_components, + metric=metric, + low_memory=low_memory, + n_jobs=n_jobs, + n_epochs=n_epochs, + **umap_kwargs, + ) + + embedding = reducer.fit_transform(X) + embedding_file_name = "3d_z_umap_embeddings.npz" + embedding_path = os.path.join(save_dir, embedding_file_name) + metadata_file_name = "3d_z_umap_metadata.pkl" + metadata_path = os.path.join(save_dir, metadata_file_name) + os.makedirs(save_dir, exist_ok=True) + np.savez(embedding_path, embeddings=embedding) + metadata.to_pickle(metadata_path) + +directory = "/mnt/projects/debruinz_project/tony_boos/MMVAE/lightning_logs/baseline/1e7Beta_256rclt_noadv/" +embedding, metadata = load_embeddings( + npz_path= os.path.join(directory, 'merged/z_embeddings.npz'), + meta_path=os.path.join(directory, 'merged/z_metadata.pkl') +) + +plot_umap(embedding, metadata, save_dir=os.path.join(directory, 'umap')) + +embedding, metadata = load_embeddings( + npz_path= os.path.join(directory, 'umap/3d_z_umap_embeddings.npz'), + meta_path=os.path.join(directory, 'umap/3d_z_umap_metadata.pkl') +) +categories = ['donor_id', 'assay', 'dataset_id', 'cell_type', 'tissue', 'species'] +for col in categories: + plot_3d_umap( + embedding=embedding, + metadata=metadata, + column=col, + output_path=os.path.join(directory, f"umap/3D_UMAP_by_{col}.png") + ) \ No newline at end of file diff --git a/src/cmmvae/runners/cli.py b/src/cmmvae/runners/cli.py index 6d0d38b..d0cb834 100644 --- a/src/cmmvae/runners/cli.py +++ b/src/cmmvae/runners/cli.py @@ -18,7 +18,7 @@ def __init__(self, extra_parser_kwargs: dict = {}, **kwargs): Handles loading trainer, model, and data modules from config file, while linking common arguments for ease of access. 
""" - self.is_run = not kwargs.get("run", False) + self.is_run = not kwargs.get("run", True) super().__init__( parser_kwargs={ @@ -52,7 +52,7 @@ def add_arguments_to_parser(self, parser): "--predict_dir", required=True, help="Where to store predictions after fit" ) - if not self.is_run: + if self.is_run: parser.add_argument( "--ckpt_path", required=True, help="Ckpt path to be passed to model" ) diff --git a/src/cmmvae/runners/correlations.py b/src/cmmvae/runners/correlations.py new file mode 100644 index 0000000..bcc848a --- /dev/null +++ b/src/cmmvae/runners/correlations.py @@ -0,0 +1,168 @@ +""" +Get R^2 correlations between cis and cross species generations +""" +import os +import sys +import click +import torch + +import numpy as np +import pandas as pd +import scipy.sparse as sp + +from torch.utils.data import DataLoader + +from cmmvae.models import CMMVAEModel +from cmmvae.constants import REGISTRY_KEYS as RK +from cmmvae.runners.cli import CMMVAECli +from cmmvae.data.local import SpeciesDataPipe + +FILTERED_BY_CATEGORIES = ["assay", "cell_type", "tissue", "sex"] + +def setup_datapipes(data_dir: str): + human_pipe = SpeciesDataPipe( + directory_path= data_dir, + npz_masks= "human*.npz", + metadata_masks= "human*.pkl", + batch_size= 10000, + allow_partials=True, + shuffle= False, + return_dense= False, + verbose= True, + ) + mouse_pipe = SpeciesDataPipe( + directory_path= data_dir, + npz_masks= "mouse*.npz", + metadata_masks= "mouse*.pkl", + batch_size= 10000, + allow_partials=True, + shuffle= False, + return_dense= False, + verbose= True, + ) + return human_pipe, mouse_pipe + +def setup_dataloaders(data_dir: str): + human_pipe, mouse_pipe = setup_datapipes(data_dir) + + human_dataloader = DataLoader( + dataset=human_pipe, + batch_size=None, + shuffle=False, + collate_fn=lambda x: x, + persistent_workers=False, + num_workers=6, + ) + mouse_dataloader = DataLoader( + dataset=mouse_pipe, + batch_size=None, + shuffle=False, + collate_fn=lambda x: x, + persistent_workers=False, + num_workers=6, + ) + + return human_dataloader, mouse_dataloader + +def convert_to_tensor(batch: sp.csr_matrix): + return torch.sparse_csr_tensor( + crow_indices=batch.indptr, + col_indices=batch.indices, + values=batch.data, + size=batch.shape, + ) + +def calc_correlations(human_out: np.ndarray, mouse_out: np.ndarray, n_samples: int): + with np.errstate(divide="ignore", invalid="ignore"): + human_correlations = np.corrcoef(human_out) + mouse_correlations = np.corrcoef(mouse_out) + + human_cis = np.round(np.nan_to_num(human_correlations[:n_samples, :n_samples]).mean(), 3) + human_cross = np.round(np.nan_to_num(human_correlations[n_samples:, n_samples:]).mean(), 3) + human_comb = np.round(np.nan_to_num(human_correlations[:n_samples, n_samples:]).mean(), 3) + + mouse_cis = np.round(np.nan_to_num(mouse_correlations[:n_samples, :n_samples]).mean(), 3) + mouse_cross = np.round(np.nan_to_num(mouse_correlations[n_samples:, n_samples:]).mean(), 3) + mouse_comb = np.round(np.nan_to_num(mouse_correlations[:n_samples, n_samples:]).mean(), 3) + + return pd.DataFrame( + { + "human_cis": [human_cis], + "human_cross": [human_cross], + "human_comb": [human_comb], + "mouse_cis": [mouse_cis], + "mouse_cross": [mouse_cross], + "mouse_comb": [mouse_comb] + } + ) + +def get_correlations(model: CMMVAEModel, data_dir: str): + + human_dataloader, mouse_dataloader = setup_dataloaders(data_dir) + + correlations = pd.DataFrame( + columns=[ + "tag", "group_id", "num_samples", + "human_cis", "human_cross", "human_comb", + "mouse_cis", 
"mouse_cross", "mouse_comb" + ] + ) + + for (human_batch, human_metadata), (mouse_batch, mouse_metadata) in zip(human_dataloader, mouse_dataloader): + + n_samples = human_metadata["num_samples"].iloc[0] + + model.module.eval() + with torch.no_grad(): + _, _, _, human_out, _ = model.module(human_batch, human_metadata, RK.HUMAN, cross_generate=True) + _, _, _, mouse_out, _ = model.module(mouse_batch, mouse_metadata, RK.MOUSE, cross_generate=True) + + human_stacked_out = np.vstack((human_out[RK.HUMAN].numpy(), mouse_out[RK.HUMAN].numpy())) + mouse_stacked_out = np.vstack((mouse_out[RK.MOUSE].numpy(), human_out[RK.MOUSE].numpy())) + + avg_correlations = calc_correlations(human_stacked_out, mouse_stacked_out, n_samples) + + avg_correlations["tag"] = " ".join([human_metadata[cat].iloc[0] for cat in FILTERED_BY_CATEGORIES]) + avg_correlations["group_id"] = human_metadata["group_id"].iloc[0] + avg_correlations["num_samples"] = n_samples + + correlations = pd.concat([correlations, avg_correlations], ignore_index=True) + + return correlations + +def save_correlations(correlations: pd.DataFrame, save_dir: str): + correlations = correlations.sort_values("group_id") + correlations.to_csv(os.path.join(save_dir, "correlations.csv"), index=False) + correlations.to_pickle(os.path.join(save_dir, "correlations.pkl")) + +@click.command( + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + ) +) +@click.option( + "--correlation_data", + type=click.Path(exists=True), + required=True, + help="Directory where filtered correlation data is saved" +) +@click.option( + "--save_dir", + type=click.Path(exists=True), + required=True, + help="Directory where correlation outputs are saved" +) +@click.pass_context +def correlations(ctx: click.Context, correlation_data: str, save_dir: str): + """Run using the LightningCli.""" + if ctx.args: + # Ensure `args` is passed as the command-line arguments + sys.argv = [sys.argv[0]] + ctx.args + + cli = CMMVAECli(run=False) + correlations = get_correlations(cli.model, correlation_data) + save_correlations(correlations, save_dir) + +if __name__ == "__main__": + correlations() \ No newline at end of file diff --git a/src/cmmvae/runners/workflow.py b/src/cmmvae/runners/workflow.py index 59689dc..6608c1e 100644 --- a/src/cmmvae/runners/workflow.py +++ b/src/cmmvae/runners/workflow.py @@ -1,6 +1,7 @@ import click from cmmvae.runners.cli import cli +from cmmvae.runners.correlations import correlations from cmmvae.runners.umap_predictions import umap_predictions from cmmvae.runners.merge_predictions import merge_predictions from cmmvae.runners.meta_discriminators import meta_discriminator @@ -12,6 +13,7 @@ def workflow(): workflow.add_command(cli) +workflow.add_command(correlations) workflow.add_command(umap_predictions) workflow.add_command(merge_predictions) workflow.add_command(meta_discriminator) diff --git a/workflow/config.yaml b/workflow/config.yaml index 0d6a90a..06ee36a 100644 --- a/workflow/config.yaml +++ b/workflow/config.yaml @@ -8,14 +8,17 @@ run_name: version000 predict_dir: samples train_command: fit --trainer configs/trainer/config.test.yaml --model configs/model/config.yaml --data configs/data/local.yaml +correlation_command: --trainer configs/trainer/config.test.yaml --model configs/model/config.yaml categories: - 'donor_id' - 'assay' - 'dataset_id' - 'cell_type' +- 'tissue' +- 'species' merge_keys: - z -correlations_data: /mnt/projects/debruinz_project/july2024_census_data/filtered \ No newline at end of file +correlation_data: 
/mnt/projects/debruinz_project/july2024_census_data/filtered \ No newline at end of file diff --git a/workflow/profile/slurm/config.yaml b/workflow/profile/slurm/config.yaml index 7d9c6e3..999e16f 100644 --- a/workflow/profile/slurm/config.yaml +++ b/workflow/profile/slurm/config.yaml @@ -39,6 +39,11 @@ set-resources: mem: 179GB gpus_per_node: "" cpus_per_task: 1 + correlations: + partition: gpu + mem: 179GB + gpus_per_node: tesla_v100s:1 + cpus_per_task: 12 umap_predictions: partition: all mem: 179GB
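
A note on the correlation metrics introduced in PATCH 14: calc_correlations() in src/cmmvae/runners/correlations.py stacks the cis and cross reconstructions with np.vstack, computes np.corrcoef over the rows, and reports the mean of three blocks of the resulting matrix (cis-only, cross-only, and the combined cis-vs-cross block). The sketch below only illustrates that block layout and is not code from the patches; the array sizes, random inputs, and the names cis_out/cross_out are hypothetical placeholders, whereas the real inputs are the stacked decoder outputs built in get_correlations().

# Illustration only: block-averaged Pearson correlations, mirroring the slicing
# used by calc_correlations(). Shapes and random data are placeholders.
import numpy as np

n_samples, n_genes = 4, 10  # placeholder sizes; the workflow batches up to 10000 samples per group
rng = np.random.default_rng(42)

cis_out = rng.normal(size=(n_samples, n_genes))    # e.g. human data decoded by the human expert
cross_out = rng.normal(size=(n_samples, n_genes))  # e.g. mouse data decoded by the human expert
stacked = np.vstack((cis_out, cross_out))          # shape (2 * n_samples, n_genes)

with np.errstate(divide="ignore", invalid="ignore"):
    corr = np.corrcoef(stacked)                    # (2n, 2n) row-by-row Pearson correlations

cis = np.round(np.nan_to_num(corr[:n_samples, :n_samples]).mean(), 3)    # within cis reconstructions
cross = np.round(np.nan_to_num(corr[n_samples:, n_samples:]).mean(), 3)  # within cross reconstructions
comb = np.round(np.nan_to_num(corr[:n_samples, n_samples:]).mean(), 3)   # cis vs. cross agreement
print(cis, cross, comb)

Because np.corrcoef returns Pearson r rather than r^2, the reported values are mean Pearson coefficients; the blocks would need to be squared before averaging if R^2 values are intended.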