Skip to content

Commit

Permalink
Updated scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
znavidi committed Feb 28, 2025
1 parent bc54fe9 commit d1a4637
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 204 deletions.
3 changes: 1 addition & 2 deletions morphodiff/evaluation/generate_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def load_model_and_generate_images(pipeline, model_checkpoint, prompts_df,

# read perturbations
prompts = []
perturbation_file = '../code/required_file/' +\
args.perturbation_list_address
perturbation_file = args.perturbation_list_address

prompt_df = pd.read_csv(perturbation_file)

Expand Down
5 changes: 3 additions & 2 deletions morphodiff/evaluation/perturbation_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np
import torch
import csv
import os


class PerturbationEncoder:

def __init__(self, dataset_id, model_type, model_name):
self.dataset_id = dataset_id
self.model_type = model_type
self.root = "../required_file/"
self.root = os.path.dirname(os.path.abspath(__file__))+"/../required_file/"

if 'HUVEC' in self.dataset_id:
self.sirna_to_gene_df = pd.read_csv(
Expand Down Expand Up @@ -177,7 +178,7 @@ def __init__(self, dataset_id, model_type, model_name):
self.dataset_id = dataset_id
self.model_type = model_type
self.model_name = model_name
self.root = "../code/required_file/"
self.root = os.path.dirname(os.path.abspath(__file__))+"/../required_file/"

if 'HUVEC' in self.dataset_id:
self.sirna_to_gene_df = pd.read_csv(
Expand Down
18 changes: 9 additions & 9 deletions morphodiff/scripts/generate_img.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ trap 'handler' SIGUSR1
## Load the environment
# source /home/env/morphodiff/bin/activate

## Define/adjust the parameters
EXPERIMENT="BBBC021-experiment-01-resized"
## Define/adjust the parameters ##
# Set the experiment name
EXPERIMENT="BBBC021_experiment"
# you can download pretrained checkpoints from https://huggingface.co/navidi/MorphoDiff_checkpoints/tree/main
CKPT_PATH="/model/BBBC021-MorphoDiff/checkpoint-0"
# For VAE, you can set path to the downloaded stable-diffusion-v1-4, or one of the pretrained chekcpoints
VAE_PATH="/stable-diffusion-v1-4/"
# Set path to the directory where you want to save the generated images
GEN_IMG_PATH="/datasets/${EXPERIMENT}/generated_imgs/"
# Set the number of images you want to generate
Expand All @@ -35,18 +34,19 @@ OOD=False
MODEL_NAME="SD" # this is fixed
MODEL_TYPE="conditional" # set "conditional" for MorphoDiff, and "naive" for unconditional SD

# this PERTURBATION_LIST_PATH should be address of a .csv file with the following columns: perturbation, ood (including header)
# sample file can be found in morphodiff/required_file/BBBC021_experiment_pert_ood_info.csv
PERTURBATION_LIST_PATH="${EXPERIMENT}_pert_ood_info.csv"
# The PERTURBATION_LIST_PATH variable should be address of a .csv file with the following columns: perturbation, ood (including header)
# sample file can be found in morphodiff/required_file/BBBC021_experiment_pert_ood_info.csv for the BBBC021 experiment sample, and
# morphodiff/required_file/HUVEC_01_pert_ood_info.csv for the HUVEC experiment sample
PERTURBATION_LIST_PATH="../required_file/${EXPERIMENT}_pert_ood_info.csv"


## Generate images
python evaluation/generate_img.py \
python ../evaluation/generate_img.py \
--experiment $EXPERIMENT \
--model_checkpoint $CKPT_PATH \
--model_name $MODEL_NAME \
--model_type $MODEL_TYPE \
--vae_path $VAE_PATH \
--vae_path $CKPT_PATH \
--perturbation_list_address $PERTURBATION_LIST_PATH \
--gen_img_path $GEN_IMG_PATH \
--num_imgs $NUM_GEN_IMG \
Expand Down
266 changes: 75 additions & 191 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,201 +1,85 @@
absl-py==2.1.0
accelerate==0.27.2
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
appdirs==1.4.4
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
awscli==1.33.16
botocore==1.34.134
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.85
cjm-pandas-utils==0.0.3
cjm-pil-utils==0.0.9
cjm-psl-utils==0.0.4
cjm-pytorch-utils==0.0.6
cjm-torchvision-tfms==0.0.11
click==8.1.7
colorama==0.4.6
comm==0.2.2
contourpy==1.2.0
cycler==0.12.1
datasets==2.18.0
debugpy==1.8.1
decorator==5.1.1
deepspeed==0.13.5
descriptastorus==2.6.1
diffusers @ file:///fs01/home/znavidi/project/cell_painting/cell_painting_SD/code/diffusers
accelerate==1.4.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
annotated-types==0.7.0
async-timeout==5.0.1
attrs==25.1.0
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
datasets==3.3.2
dill==0.3.8
docker-pycreds==0.4.0
docutils==0.16
dominate==2.9.1
et-xmlfile==1.1.0
etils==1.7.0
exceptiongroup==1.2.0
executing==2.0.1
filelock==3.13.1
flax==0.8.1
fonttools==4.50.0
frozenlist==1.4.1
fsspec==2023.10.0
ftfy==6.1.3
gitdb==4.0.11
GitPython==3.1.42
greenlet==3.0.3
grpcio==1.60.1
h5py==3.10.0
hjson==3.1.0
huggingface-hub==0.20.3
idna==3.6
imageio==2.34.0
imageio-ffmpeg==0.4.3
importlib-metadata==7.0.1
importlib_resources==6.1.3
ipykernel==6.29.4
ipython==8.23.0
ipywidgets==8.1.2
jax==0.4.25
jaxlib==0.4.25
jedi==0.19.1
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.4.0
jsonpatch==1.33
jsonpointer==3.0.0
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyterlab_widgets==3.0.10
kaleido==0.2.1
kiwisolver==1.4.5
lazy_loader==0.3
llvmlite==0.43.0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.3
matplotlib-inline==0.1.6
mdurl==0.1.2
ml-dtypes==0.3.2
filelock==3.17.0
frozenlist==1.5.0
fsspec==2024.12.0
ftfy==6.3.1
gitdb==4.0.12
GitPython==3.1.44
grpcio==1.70.0
huggingface-hub==0.25.0
idna==3.10
importlib_metadata==8.6.1
Jinja2==3.1.5
Markdown==3.7
MarkupSafe==3.0.2
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multipledispatch==1.0.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
natsort==8.4.0
nest-asyncio==1.6.0
networkx==3.2.1
ninja==1.11.1.1
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
opencv-python==4.9.0.80
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.0
orbax-checkpoint==0.5.3
packaging==23.2
pandas==2.2.0
pandas-flavor==0.6.0
parso==0.8.4
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
pandas==2.2.3
peft==0.7.0
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.2.0
plotly==5.24.1
prompt-toolkit==3.0.43
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==15.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pybbbc @ git+https://github.com/giacomodeodato/pybbbc.git@27d6289d231f869561e0515ecce4934a8e744f39
pycytominer==1.1.0
pydantic==2.6.3
pydantic_core==2.16.3
Pygments==2.17.2
pyjanitor==0.27.0
pynndescent==0.5.13
pynvml==11.5.0
pyparsing==3.1.2
pyre-extensions==0.0.30
pyspng==0.1.2
python-dateutil==2.8.2
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
rdkit==2023.9.5
regex==2023.12.25
requests==2.31.0
rich==13.7.1
rsa==4.7.2
s3transfer==0.10.2
safetensors==0.4.2
scikit-image==0.22.0
scikit-learn==1.4.2
scipy==1.11.1
seaborn==0.13.2
sentry-sdk==1.44.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
SQLAlchemy==1.4.53
stack-data==0.6.3
swd==1.0.0
sympy==1.12
tabulate==0.9.0
tenacity==9.0.0
tensorboard==2.16.2
pillow==11.1.0
platformdirs==4.3.6
propcache==0.3.0
protobuf==5.29.3
psutil==7.0.0
pyarrow==19.0.1
pydantic==2.10.6
pydantic_core==2.27.2
python-dateutil==2.9.0.post0
pytz==2025.1
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
sentry-sdk==2.22.0
setproctitle==1.3.5
six==1.17.0
smmap==5.0.2
sympy==1.13.1
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorstore==0.1.54
thop==0.1.1.post2209072238
threadpoolctl==3.5.0
tifffile==2024.2.12
timm==1.0.8
tokenizers==0.15.2
toolz==0.12.1
torch==2.2.1
torch-fidelity==0.3.0
torcheval==0.0.7
torchtnt==0.2.4
torchvision==0.17.1
tornado==6.4
tqdm==4.66.2
traitlets==5.14.2
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.38.2
triton==2.2.0
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2024.1
ultralytics==8.1.38
umap-learn==0.5.6
urllib3==2.2.1
visdom==0.2.4
wandb==0.16.5
triton==3.2.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
wandb==0.19.7
wcwidth==0.2.13
websocket-client==1.8.0
Werkzeug==3.0.1
widgetsnbextension==4.0.10
xarray==2024.3.0
xformers==0.0.24
XlsxWriter==3.2.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0
Werkzeug==3.1.3
xformers @ git+https://github.com/facebookresearch/xformers.git@91b8f8d19d02d4bc643c89031905caf0c3e73382
xxhash==3.5.0
yarl==1.18.3
zipp==3.21.0

0 comments on commit d1a4637

Please sign in to comment.