Skip to content

Commit

Permalink
Implement inference test
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Dec 17, 2024
1 parent fba6486 commit f6853d5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
16 changes: 5 additions & 11 deletions test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from synapse_net.sample_data import get_sample_data


@unittest.skipIf(platform.system() == "Windows", "CLI does not work on Windows")
class TestCLI(unittest.TestCase):
tmp_dir = "tmp"

Expand Down Expand Up @@ -72,17 +73,10 @@ def test_segmentation_cli_with_scale(self):
def test_segmentation_cli_with_checkpoint(self):
cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
model_path = os.path.join(cache_dir, "models", "vesicles_2d")
if platform.system() == "Windows":
cmd = [
sys.executable, "-m", "synapse_net.run_segmentation",
"-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
"-c", model_path,
]
else:
cmd = [
"synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
"-c", model_path,
]
cmd = [
"synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
"-c", model_path,
]
run(cmd)
self.check_segmentation_result()

Expand Down
51 changes: 51 additions & 0 deletions test/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import unittest
from functools import partial
from shutil import rmtree

import imageio.v3 as imageio
from synapse_net.file_utils import read_mrc
from synapse_net.sample_data import get_sample_data


class TestInference(unittest.TestCase):
tmp_dir = "tmp"
model_type = "vesicles_2d"
tiling = {"tile": {"z": 1, "y": 512, "x": 512}, "halo": {"z": 0, "y": 32, "x": 32}}

def setUp(self):
self.data_path = get_sample_data("tem_2d")
os.makedirs(self.tmp_dir, exist_ok=True)

def tearDown(self):
try:
rmtree(self.tmp_dir)
except OSError:
pass

def test_run_segmentation(self):
from synapse_net.inference import run_segmentation, get_model

image, _ = read_mrc(self.data_path)
model = get_model(self.model_type)
seg = run_segmentation(image, model, model_type=self.model_type, tiling=self.tiling)
self.assertEqual(image.shape, seg.shape)

def test_segmentation_with_inference_helper(self):
from synapse_net.inference import run_segmentation, get_model
from synapse_net.inference.util import inference_helper

model = get_model(self.model_type)
segmentation_function = partial(
run_segmentation, model=model, model_type=self.model_type, verbose=False, tiling=self.tiling,
)
inference_helper(self.data_path, self.tmp_dir, segmentation_function, data_ext=".mrc")
expected_output_path = os.path.join(self.tmp_dir, "tem_2d_prediction.tif")
self.assertTrue(os.path.exists(expected_output_path))
seg = imageio.imread(expected_output_path)
image, _ = read_mrc(self.data_path)
self.assertEqual(image.shape, seg.shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit f6853d5

Please sign in to comment.