Skip to content

Commit

Permalink
Merge pull request #55 from anthonyboos559/correlations_gpu_update
Browse files Browse the repository at this point in the history
Moves model processing to a separate rule from correlation calculations
  • Loading branch information
anthonyboos559 authored Nov 4, 2024
2 parents 715aa5d + d912f1b commit 1527a22
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 97 deletions.
20 changes: 17 additions & 3 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,31 @@ rule predict:
## This rule outputs correlation scores per filtered data group
rule correlations:
input:
PREDICTIONS_PATH,
ckpt_path=CKPT_PATH,
output:
CORRELATION_FILES,
os.path.join(CORRELATION_DIR, "correlations_complete.log")
params:
command=TRAIN_COMMAND.lstrip('fit'),
data=CORRELATION_DATA,
save_dir=CORRELATION_DIR,
shell:
"""
mkdir -p {CORRELATION_DIR}
cmmvae workflow correlations {params.command} --correlation_data {params.data} --save_dir {params.save_dir}
cmmvae workflow correlations {params.command} --ckpt_path {input.ckpt_path} --correlation_data {params.data} --save_dir {params.save_dir}
touch {output}
"""

## This rule consumes the outputs of the `correlations` rule and invokes
## `cmmvae workflow run-correlations` on CORRELATION_DIR; CORRELATION_FILES
## are its declared outputs.
rule run_correlations:
input:
rules.correlations.output
output:
CORRELATION_FILES,
params:
directory=CORRELATION_DIR,
shell:
"""
mkdir -p {CORRELATION_DIR}
cmmvae workflow run-correlations --directory {params.directory}
"""

## Define the rule for generating UMAP visualizations from the merged predictions.
Expand Down
137 changes: 44 additions & 93 deletions src/cmmvae/runners/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
from cmmvae.runners.cli import CMMVAECli
from cmmvae.data.local import SpeciesDataPipe

FILTERED_BY_CATEGORIES = ["assay", "cell_type", "tissue", "sex"]


def setup_datapipes(data_dir: str):
def setup_datapipes(data_dir: str, human_masks: str, mouse_masks: str):
human_pipe = SpeciesDataPipe(
directory_path=data_dir,
npz_masks="human*.npz",
metadata_masks="human*.pkl",
npz_masks=[f"{human_mask}.npz" for human_mask in human_masks],
metadata_masks=[f"{human_mask}.pkl" for human_mask in human_masks],
batch_size=10000,
allow_partials=True,
shuffle=False,
Expand All @@ -33,8 +31,8 @@ def setup_datapipes(data_dir: str):
)
mouse_pipe = SpeciesDataPipe(
directory_path=data_dir,
npz_masks="mouse*.npz",
metadata_masks="mouse*.pkl",
npz_masks=[f"{mouse_mask}.npz" for mouse_mask in mouse_masks],
metadata_masks=[f"{mouse_mask}.pkl" for mouse_mask in mouse_masks],
batch_size=10000,
allow_partials=True,
shuffle=False,
Expand All @@ -45,23 +43,31 @@ def setup_datapipes(data_dir: str):


def setup_dataloaders(data_dir: str):
human_pipe, mouse_pipe = setup_datapipes(data_dir)

gids = np.random.choice(
np.arange(1, 252), size=25, replace=False
)

human_masks = [f"human*_{gid}" for gid in gids]
mouse_masks = [f"mouse*_{gid}" for gid in gids]

human_pipe, mouse_pipe = setup_datapipes(data_dir, human_masks, mouse_masks)

human_dataloader = DataLoader(
dataset=human_pipe,
batch_size=None,
shuffle=False,
collate_fn=lambda x: x,
persistent_workers=False,
num_workers=6,
num_workers=2,
)
mouse_dataloader = DataLoader(
dataset=mouse_pipe,
batch_size=None,
shuffle=False,
collate_fn=lambda x: x,
persistent_workers=False,
num_workers=6,
num_workers=2,
)

return human_dataloader, mouse_dataloader
Expand All @@ -78,69 +84,15 @@ def convert_to_tensor(batch: sp.csr_matrix):
else torch.device("cpu"),
)


def calc_correlations(human_out: np.ndarray, mouse_out: np.ndarray, n_samples: int):
with np.errstate(divide="ignore", invalid="ignore"):
human_correlations = np.corrcoef(human_out)
mouse_correlations = np.corrcoef(mouse_out)

human_cis = np.round(
np.nan_to_num(human_correlations[:n_samples, :n_samples]).mean(), 3
)
human_cross = np.round(
np.nan_to_num(human_correlations[n_samples:, n_samples:]).mean(), 3
)
human_comb = np.round(
np.nan_to_num(human_correlations[:n_samples, n_samples:]).mean(), 3
)

mouse_cis = np.round(
np.nan_to_num(mouse_correlations[:n_samples, :n_samples]).mean(), 3
)
mouse_cross = np.round(
np.nan_to_num(mouse_correlations[n_samples:, n_samples:]).mean(), 3
)
mouse_comb = np.round(
np.nan_to_num(mouse_correlations[:n_samples, n_samples:]).mean(), 3
)

return pd.DataFrame(
{
"human_cis": [human_cis],
"human_cross": [human_cross],
"human_comb": [human_comb],
"mouse_cis": [mouse_cis],
"mouse_cross": [mouse_cross],
"mouse_comb": [mouse_comb],
}
)


def get_correlations(model: CMMVAEModel, data_dir: str):
def get_correlations(model: CMMVAEModel, data_dir: str, save_dir: str):
human_dataloader, mouse_dataloader = setup_dataloaders(data_dir)

correlations = pd.DataFrame(
columns=[
"group_id",
"num_samples",
"human_cis",
"human_cross",
"human_comb",
"mouse_cis",
"mouse_cross",
"mouse_comb",
"tag",
]
)

for (human_batch, human_metadata), (mouse_batch, mouse_metadata) in zip(
human_dataloader, mouse_dataloader
):
human_batch = human_batch.cuda()
mouse_batch = mouse_batch.cuda()

n_samples = human_metadata["num_samples"].iloc[0]

model.module.eval()
with torch.no_grad():
_, _, _, human_out, _ = model.module(
Expand All @@ -149,35 +101,35 @@ def get_correlations(model: CMMVAEModel, data_dir: str):
_, _, _, mouse_out, _ = model.module(
mouse_batch, mouse_metadata, RK.MOUSE, cross_generate=True
)

human_stacked_out = np.vstack(
(human_out[RK.HUMAN].cpu().numpy(), mouse_out[RK.HUMAN].cpu().numpy())
save_data = convert_to_csr(human_out, mouse_out)
metadata = {RK.HUMAN: human_metadata, RK.MOUSE: mouse_metadata}
save_correlations(save_data, metadata, save_dir, gid=human_metadata["group_id"].iloc[0])

def convert_to_csr(
    human_xhats: dict[str, torch.Tensor],
    mouse_xhats: dict[str, torch.Tensor],
) -> dict[str, sp.csr_matrix]:
    """Convert per-species output tensors into SciPy CSR matrices.

    Args:
        human_xhats: Mapping from output-species key to the tensor produced
            for the human input batch.
        mouse_xhats: Mapping from output-species key to the tensor produced
            for the mouse input batch.

    Returns:
        A dict keyed ``"human_to_<output_species>"`` and
        ``"mouse_to_<output_species>"`` whose values are CSR matrices built
        from the corresponding tensors. Human entries are inserted before
        mouse entries.
    """
    converted: dict[str, sp.csr_matrix] = {}
    for input_species, xhats in (("human", human_xhats), ("mouse", mouse_xhats)):
        for output_species, xhat in xhats.items():
            # Copy to host memory before handing the data to NumPy/SciPy.
            converted[f"{input_species}_to_{output_species}"] = sp.csr_matrix(
                xhat.cpu().numpy()
            )
    return converted


def save_correlations(data: dict[str, sp.csr_matrix], metadata: pd.DataFrame, save_dir: str, gid: int):
for file_name, data in data.items():
sp.save_npz(
os.path.join(save_dir, f"{file_name}_{gid}.npz"),
data
)
mouse_stacked_out = np.vstack(
(mouse_out[RK.MOUSE].cpu().numpy(), human_out[RK.MOUSE].cpu().numpy())
)

avg_correlations = calc_correlations(
human_stacked_out, mouse_stacked_out, n_samples
for species, md in metadata.items():
md.to_pickle(
os.path.join(save_dir, f"{species}_metadata_{gid}.pkl")
)

avg_correlations["group_id"] = human_metadata["group_id"].iloc[0]
avg_correlations["num_samples"] = n_samples
avg_correlations["tag"] = " ".join(
[human_metadata[cat].iloc[0] for cat in FILTERED_BY_CATEGORIES]
)

correlations = pd.concat([correlations, avg_correlations], ignore_index=True)

return correlations


def save_correlations(correlations: pd.DataFrame, save_dir: str):
correlations = correlations.sort_values("group_id")
correlations.to_csv(os.path.join(save_dir, "correlations.csv"), index=False)
correlations.to_pickle(os.path.join(save_dir, "correlations.pkl"))


@click.command(
context_settings=dict(
ignore_unknown_options=True,
Expand Down Expand Up @@ -207,8 +159,7 @@ def correlations(ctx: click.Context, correlation_data: str, save_dir: str):
model = type(cli.model).load_from_checkpoint(
cli.config["ckpt_path"], module=cli.model.module
)
correlations = get_correlations(model, correlation_data)
save_correlations(correlations, save_dir)
get_correlations(model, correlation_data, save_dir)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 1527a22

Please sign in to comment.