Merge pull request #50 from anthonyboos559/correlations
Census Data Filtering and Pearson Correlation Snakemake Rule
anthonyboos559 authored Oct 23, 2024
2 parents 080e9f7 + 3dae571 commit 279e396
Showing 11 changed files with 578 additions and 18 deletions.
70 changes: 56 additions & 14 deletions Snakefile
@@ -57,6 +57,12 @@ META_DISC_DIR = os.path.join(RUN_DIR, META_DISC_PATH)
## Define a separate directory for merged outputs to avoid conflicts between different merge operations.
MERGED_DIR = os.path.join(RUN_DIR, "merged")

## Define the directory to store correlation outputs
CORRELATION_PATH = config.get("correlation_dir", "correlations")
CORRELATION_DIR = os.path.join(RUN_DIR, CORRELATION_PATH)

CORRELATION_DATA = config["correlation_data"]

## Generate the paths for the embeddings and metadata files based on the merge keys.
EMBEDDINGS_PATHS = [os.path.join(MERGED_DIR, f"{key}_embeddings.npz") for key in MERGE_KEYS]
METADATA_PATHS = [os.path.join(MERGED_DIR, f"{key}_metadata.pkl") for key in MERGE_KEYS]
@@ -92,6 +98,7 @@ MD_FILES = expand(
## If a configuration directory is provided, it is included in the command; otherwise,
## individual parameters such as trainer, model, and data are passed explicitly.
TRAIN_COMMAND = config["train_command"]
CORRELATION_COMMAND = config["correlation_command"]

TRAIN_COMMAND += str(
f" --default_root_dir {ROOT_DIR} "
@@ -100,13 +107,31 @@ TRAIN_COMMAND += str(
f"--predict_dir {PREDICT_SUBDIR} "
)

CORRELATION_COMMAND += str(
f" --default_root_dir {ROOT_DIR} "
f"--experiment_name {EXPERIMENT_NAME} --run_name {RUN_NAME} "
f"--seed_everything {SEED} "
f"--predict_dir {PREDICT_SUBDIR} "
f"--ckpt_path {CKPT_PATH} "
)

CORRELATION_FILES = expand(
"{correlation_dir}/correlations.csv",
correlation_dir=CORRELATION_DIR,
)

CORRELATION_FILES += expand(
"{correlation_dir}/correlations.pkl",
correlation_dir=CORRELATION_DIR,
)

## Define the final output rule for Snakemake, specifying the target files that should be generated
## by the end of the workflow.
rule all:
input:
EVALUATION_FILES,
MD_FILES
CORRELATION_FILES
# MD_FILES

## Define the rule for training the CMMVAE model.
## The output includes the configuration file, the checkpoint path, and the directory for predictions.
@@ -138,6 +163,23 @@ rule merge_predictions:
cmmvae workflow merge-predictions --directory {input.predict_dir} --keys {params.merge_keys} --save_dir {MERGED_DIR}
"""

## Define the rule for getting R^2 correlations on the filtered data
## This rule outputs correlation scores per filtered data group
rule correlations:
input:
embeddings_path=EMBEDDINGS_PATHS,
output:
CORRELATION_FILES,
params:
command=CORRELATION_COMMAND,
data=CORRELATION_DATA,
save_dir=CORRELATION_DIR,
shell:
"""
mkdir -p {CORRELATION_DIR}
cmmvae workflow correlations {params.command} --correlation_data {params.data} --save_dir {params.save_dir}
"""

## Define the rule for generating UMAP visualizations from the merged predictions.
## This rule produces UMAP images for each combination of category and merge key.
rule umap_predictions:
@@ -156,16 +198,16 @@ rule umap_predictions:
cmmvae workflow umap-predictions --directory {params.predict_dir} {params.categories} {params.merge_keys} --save_dir {params.save_dir}
"""

rule meta_discriminators:
input:
CKPT_PATH
output:
MD_FILES,
params:
log_dir=META_DISC_DIR,
ckpt=CKPT_PATH,
config=TRAIN_CONFIG_FILE
shell:
"""
cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config}
"""
# rule meta_discriminators:
# input:
# CKPT_PATH
# output:
# MD_FILES,
# params:
# log_dir=META_DISC_DIR,
# ckpt=CKPT_PATH,
# config=TRAIN_CONFIG_FILE
# shell:
# """
# cmmvae workflow meta-discriminator --log_dir {params.log_dir} --ckpt {params.ckpt} --config {params.config}
# """
121 changes: 121 additions & 0 deletions scripts/data-preprocessing/census_data_filtering.py
@@ -0,0 +1,121 @@
import os
import re
import glob
import argparse as ap
import multiprocessing as mp
import numpy as np
import pandas as pd
import scipy.sparse as sp
import data_filtering_functions as ff
from collections import defaultdict
from data_processing_functions import extract_file_number

def main(
directory: str,
species: list[str],
train_dir: str,
save_dir: str,
seed: int,
skip_metadata: bool,
):
if not skip_metadata:
np.random.seed(seed)
dataframes = {}
for specie in species:
# Process file paths for full and training data
train_files = tuple(glob.glob(os.path.join(train_dir, f"{specie}*.pkl")))
train_files = ff.filter_and_sort_train_files(train_files)

# Get somaIDs of training samples
train_ids = ff.get_train_data_ids(train_files)

# Prepare main df that is to be filtered
metadata_files = glob.glob(os.path.join(directory, f"{specie}*.pkl"))
metadata_files.sort(key=extract_file_number)
meta_dataframe = ff.load_and_merge_metadata(tuple(metadata_files))

# Filter out training samples from the full data
dataframes[specie] = ff.filter_train_ids(meta_dataframe, train_ids)

# Process groups and subset size if needed
grouped_data = ff.filter_into_groups(dataframes)
valid_groups = ff.validate_and_sample_groups(grouped_data, species[0])

# Save filtered metadata
ff.save_grouped_data(valid_groups, dataframes, save_dir)

# Load data and slice by metadata index
for specie in species:
data_files = glob.glob(os.path.join(directory, f"{specie}*.npz"))
data_files.sort(key=extract_file_number)
metadata_files = glob.glob(os.path.join(save_dir, f"{specie}*.pkl"))
metadata_files.sort(key=extract_file_number)
filtered_data = defaultdict(list)
for chunk_n, data_file in enumerate(data_files, start=1):
current_chunk = sp.load_npz(data_file)
for gid, metadata_file in enumerate(metadata_files, start=1):
current_df = pd.read_pickle(metadata_file)
idxes = current_df[current_df["chunk_source"] == chunk_n]["data_index"]
sliced_chunk = current_chunk[idxes, :]
filtered_data[gid].append(sliced_chunk)

# Save filtered data
for gid, data in filtered_data.items():
chunk_data = sp.vstack(data)
sp.save_npz(os.path.join(save_dir, f"{specie}_filtered_{gid}.npz"), chunk_data)

if __name__ == "__main__":
parser = ap.ArgumentParser()
parser.add_argument(
"--directory",
type=str,
required=True,
help="Directory to load data from."
)
parser.add_argument(
"--species",
type=str,
nargs="+",
required=True,
help="Species to load data for."
)
parser.add_argument(
"--train_data",
type=str,
required=False,
default=None,
help="Directory where the training data is stored. Defaults to '--directory'"
)
parser.add_argument(
"--save_directory",
type=str,
required=False,
default=None,
help="Directory to save filtered data in. Defaults to '--directory'"
)
parser.add_argument(
"--seed",
type=int,
required=False,
default=42,
help="Seed for the random module."
)
parser.add_argument(
"--skip_metadata",
action="store_true",
help="Whether to skip the metadata filtering and just slice the count data."
)

args = parser.parse_args()

train_dir = args.directory if args.train_data is None else args.train_data
save_dir = args.directory if args.save_directory is None else args.save_directory

main(
directory= args.directory,
species= args.species,
train_dir= train_dir,
save_dir= save_dir,
seed= args.seed,
skip_metadata= args.skip_metadata
)
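
The core of the second pass in census_data_filtering.py is slicing each count-matrix chunk by the row indices recorded in the filtered metadata. A toy sketch of that step, with made-up shapes and index values (not data from the actual census files):

import numpy as np
import pandas as pd
import scipy.sparse as sp

chunk = sp.csr_matrix(np.arange(20).reshape(5, 4))  # stand-in for one count-matrix chunk (chunk 1)
meta = pd.DataFrame({
    "chunk_source": [1, 1, 2],  # which chunk each filtered row came from
    "data_index": [0, 3, 1],    # row position within that chunk
})

idxes = meta[meta["chunk_source"] == 1]["data_index"]
sliced = chunk[idxes.to_numpy(), :]  # rows 0 and 3 of chunk 1
print(sliced.toarray())
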
99 changes: 99 additions & 0 deletions scripts/data-preprocessing/data_filtering_functions.py
@@ -0,0 +1,99 @@
import os
import re
import csv
import numpy as np
import pandas as pd
import pandas.core.groupby.generic as gb
from collections import defaultdict
from data_processing_functions import extract_file_number

TRAIN_DATA_FILES = set(range(1, 14))
GROUPING_COLUMNS = ["sex", "tissue", "cell_type", "assay"]
MIN_SAMPLE_SIZE = 100
MAX_SAMPLE_SIZE = 10000
SPECIES_SAMPLE_SIZE = {"human": 60530, "mouse": 52437}

def get_train_data_ids(files: tuple[str]) -> set[int]:

combined_df = pd.DataFrame()

for file in files:
df = pd.read_pickle(file)
combined_df = pd.concat([combined_df, df], ignore_index=True)

return set(combined_df["soma_joinid"])

def filter_train_ids(df: pd.DataFrame, ids: set[int]) -> pd.DataFrame:
filtered_df = df[~(df["soma_joinid"].isin(ids))]
filtered_df = filtered_df.reset_index(drop=True)
return filtered_df

def filter_and_sort_train_files(unfiltered_files: tuple[str]) -> tuple[str]:

filtered_files = [
file for file in unfiltered_files
if (match := re.search(r'(\d+)\.pkl$', file)) and int(match.group(1)) in TRAIN_DATA_FILES
]
filtered_files.sort(key=extract_file_number)
return tuple(filtered_files)

def load_and_merge_metadata(files: tuple[str]) -> pd.DataFrame:

merged_df = pd.DataFrame()
for file in files:
df = pd.read_pickle(file)
df["chunk_source"] = extract_file_number(file)
df = df.reset_index(drop=False)
df.rename(columns={"index": "data_index"}, inplace=True)
merged_df = pd.concat([merged_df, df], ignore_index=True)

return merged_df

def filter_into_groups(dfs: dict[str, pd.DataFrame]):

grouped = {}
for specie, data in dfs.items():
grouped[specie] = data.groupby(GROUPING_COLUMNS)

return grouped

def validate_and_sample_groups(data_groups: dict[str, gb.DataFrameGroupBy], primary_species: str = None):

valid_groups = defaultdict(dict)
if primary_species is not None:
main_df = data_groups.pop(primary_species)
else:
primary_species, main_df = data_groups.popitem()

for gid, idxes in main_df.groups.items():
if len(idxes) < MIN_SAMPLE_SIZE:
continue
elif all(
gid in group.groups.keys() and len(group.groups[gid]) >= MIN_SAMPLE_SIZE
for group in data_groups.values()
):
sample_size = min(
[len(idxes), MAX_SAMPLE_SIZE] + [len(group.groups[gid]) for group in data_groups.values()]
)

valid_groups[gid][primary_species] = np.random.choice(idxes, sample_size, replace= False)
for specie, group in data_groups.items():
valid_groups[gid][specie] = np.random.choice(group.groups[gid], sample_size, replace= False)

return valid_groups

def save_grouped_data(groups: dict[tuple[str], dict[str, np.ndarray]], dfs: dict[str, pd.DataFrame], save_dir: str):

with open(os.path.join(save_dir, "group_references.csv"), "w") as file:
writer = csv.writer(file)
writer.writerow(["group_id", "num_samples"] + GROUPING_COLUMNS)
for i, gid in enumerate(groups.keys(), start=1):
for specie, idx in groups[gid].items():
df = dfs[specie].iloc[idx]
df["group_id"] = i
df["num_samples"] = len(idx)
df = df.sort_values("chunk_source")
df.to_pickle(os.path.join(save_dir, f"{specie}_filtered_{i}.pkl"))
writer.writerow([i, len(idx)] + list(gid))
file = pd.read_csv(os.path.join(save_dir, "group_references.csv"))
file.to_pickle(os.path.join(save_dir, "group_references.pkl"))
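
The sampling logic above keeps a (sex, tissue, cell_type, assay) group only when every species has at least MIN_SAMPLE_SIZE cells, then draws the same number of cells from each species, capped at MAX_SAMPLE_SIZE. A toy illustration with invented group sizes:

import numpy as np

MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE = 100, 10000
group_sizes = {"human": 250, "mouse": 180}  # cells in one (sex, tissue, cell_type, assay) group

if all(size >= MIN_SAMPLE_SIZE for size in group_sizes.values()):
    sample_size = min([MAX_SAMPLE_SIZE] + list(group_sizes.values()))
    rng = np.random.default_rng(42)  # the script seeds numpy's global RNG; a Generator is used here for brevity
    samples = {
        species: rng.choice(size, sample_size, replace=False)
        for species, size in group_sizes.items()
    }
    print({species: len(idx) for species, idx in samples.items()})  # {'human': 180, 'mouse': 180}
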
2 changes: 2 additions & 0 deletions src/cmmvae/data/local/__init__.py
@@ -5,9 +5,11 @@

from cmmvae.data.local.cellxgene_datamodule import SpeciesDataModule
from cmmvae.data.local.cellxgene_manager import SpeciesManager
from cmmvae.data.local.cellxgene_datapipe import SpeciesDataPipe


__all__ = [
"SpeciesManager",
"SpeciesDataModule",
"SpeciesDataPipe",
]
6 changes: 4 additions & 2 deletions src/cmmvae/data/local/cellxgene_datapipe.py
Expand Up @@ -172,7 +172,7 @@ def __iter__(self):
for i in range(0, n_samples, self.batch_size):
data_batch = sparse_matrix[i : i + self.batch_size]

if self.allow_partials and not data_batch.shape[0] == self.batch_size:
if data_batch.shape[0] != self.batch_size and not self.allow_partials:
continue

tensor = torch.sparse_csr_tensor(
@@ -252,6 +252,7 @@ def __init__(
npz_masks: Union[str, list[str]],
metadata_masks: Union[str, list[str]],
batch_size: int,
allow_partials = False,
shuffle: bool = True,
return_dense: bool = False,
verbose: bool = False,
@@ -305,6 +306,7 @@ def __init__(
print(path)

self.batch_size = batch_size
self.allow_partials = allow_partials
self.return_dense = return_dense
self.verbose = verbose
self._shuffle = shuffle
@@ -324,7 +326,7 @@ def __init__(
dp = dp.shuffle_matrix_and_dataframe()

dp = dp.batch_csr_matrix_and_dataframe(
self.batch_size, return_dense=self.return_dense
self.batch_size, return_dense=self.return_dense, allow_partials=self.allow_partials
)

# thought process on removal
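
The one-line change to __iter__ above flips the partial-batch check so that a short final batch is skipped only when allow_partials is False. A minimal standalone sketch of the corrected predicate (toy values; keep_batch is a hypothetical helper, not part of the datapipe):

def keep_batch(batch_rows: int, batch_size: int, allow_partials: bool) -> bool:
    # Mirrors: `if data_batch.shape[0] != self.batch_size and not self.allow_partials: continue`
    return not (batch_rows != batch_size and not allow_partials)

assert keep_batch(64, 64, allow_partials=False)      # a full batch is always yielded
assert not keep_batch(10, 64, allow_partials=False)  # a partial batch is dropped by default
assert keep_batch(10, 64, allow_partials=True)       # a partial batch is kept when allowed
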
