Skip to content

Commit

Permalink
create stardist abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jan 16, 2025
1 parent 042f67a commit 66bdd17
Show file tree
Hide file tree
Showing 2 changed files with 387 additions and 321 deletions.
341 changes: 20 additions & 321 deletions src/main/java/io/bioimage/modelrunner/model/Stardist2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,24 @@
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.compress.archivers.ArchiveException;

import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.utils.JSONUtils;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
* Implementation of an API to run Stardist 2D models out of the box with little configuration.
Expand All @@ -66,268 +49,34 @@
*
*@author Carlos Garcia
*/
public class Stardist2D {
public class Stardist2D extends StardistAbstract {

private final String modelDir;

private final String name;

private final String basedir;

private final int nChannels;

private boolean loaded = false;

private SharedMemoryArray shma;

private ModelDescriptor descriptor;

private Service python;

private static final List<String> STARDIST_DEPS = Arrays.asList(new String[] {"python=3.10", "stardist", "numpy", "appose"});

private static final List<String> STARDIST_CHANNELS = Arrays.asList(new String[] {"conda-forge", "default"});


private static final String COORDS_DTYPE_KEY = "coords_dtype";

private static final String COORDS_SHAPE_KEY = "coords_shape";

private static final String POINTS_DTYPE_KEY = "points_dtype";

private static final String POINTS_SHAPE_KEY = "points_shape";

private static final String POINTS_KEY = "points";

private static final String COORDS_KEY = "coords";

private static final String LOAD_MODEL_CODE = ""
+ "if 'StarDist2D' not in globals().keys():" + System.lineSeparator()
+ " from stardist.models import StarDist2D" + System.lineSeparator()
+ " globals()['StarDist2D'] = StarDist2D" + System.lineSeparator()
+ "if 'np' not in globals().keys():" + System.lineSeparator()
+ " import numpy as np" + System.lineSeparator()
+ " globals()['np'] = np" + System.lineSeparator()
+ "if 'os' not in globals().keys():" + System.lineSeparator()
+ " import os" + System.lineSeparator()
+ " globals()['os'] = os" + System.lineSeparator()
+ "if 'shared_memory' not in globals().keys():" + System.lineSeparator()
+ " from multiprocessing import shared_memory" + System.lineSeparator()
+ " globals()['shared_memory'] = shared_memory" + System.lineSeparator()
+ "model = StarDist2D(None, name='%s', basedir='%s')" + System.lineSeparator()
+ "globals()['model'] = model";

private static final String RUN_MODEL_CODE = ""
+ "output = model.predict_instances(im, return_predict=False)" + System.lineSeparator()
+ "im[:] = output[0]" + System.lineSeparator()
+ "if output[1]['points'].nbytes == 0:" + System.lineSeparator()
+ " task.outputs['points_shape'] = None" + System.lineSeparator()
+ "else:" + System.lineSeparator()
+ " task.outputs['points_shape'] = output[1]['points'].shape" + System.lineSeparator()
+ " task.outputs['points_dtype'] = output[1]['points'].dtype" + System.lineSeparator()
+ " points_shm = "
+ " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['points'].nbytes)" + System.lineSeparator()
+ " shared_points = np.ndarray(output[1]['points'].shape, dtype=output[1]['points'].dtype, buffer=points_shm.buf)" + System.lineSeparator()
+ " globals()['shared_points'] = shared_points" + System.lineSeparator()
+ "if output[1]['coord'].nbytes == 0:" + System.lineSeparator()
+ " task.outputs['coords_shape'] = None" + System.lineSeparator()
+ "else:" + System.lineSeparator()
+ " task.outputs['coords_shape'] = output[1]['coord'].shape" + System.lineSeparator()
+ " task.outputs['coords_dtype'] = output[1]['coord'].dtype" + System.lineSeparator()
+ " coords_shm = "
+ " shared_memory.SharedMemory(create=True, name=os.path.basename(shm_points_id), size=output[1]['coord'].nbytes)" + System.lineSeparator()
+ " shared_coords = np.ndarray(output[1]['coord'].shape, dtype=output[1]['coord'].dtype, buffer=coords_shm.buf)" + System.lineSeparator()
+ " globals()['shared_coords'] = shared_coords" + System.lineSeparator()
+ "if os.name == 'nt':" + System.lineSeparator()
+ " im_shm.close()" + System.lineSeparator()
+ " im_shm.unlink()" + System.lineSeparator();

private static final String CLOSE_SHM_CODE = ""
+ "if 'points_shm' in globals().keys():" + System.lineSeparator()
+ " points_shm.close()" + System.lineSeparator()
+ " points_shm.unlink()" + System.lineSeparator()
+ "if 'coords_shm' in globals().keys():" + System.lineSeparator()
+ " coords_shm.close()" + System.lineSeparator()
+ " coords_shm.unlink()" + System.lineSeparator();
private static String MODULE_NAME = "Stardist2D";

private Stardist2D(String modelName, String baseDir) throws IOException, ModelSpecsException {
this.name = modelName;
this.basedir = baseDir;
modelDir = new File(baseDir, modelName).getAbsolutePath();
if (new File(modelDir, "config.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false)
throw new IllegalArgumentException("No 'config.json' file found in the model directory");
else if (new File(modelDir, "config.json").isFile() == false)
createConfigFromBioimageio();
if (new File(modelDir, "thresholds.json").isFile() == false && new File(modelDir, Constants.RDF_FNAME).isFile() == false)
throw new IllegalArgumentException("No 'thresholds.json' file found in the model directory");
else if (new File(modelDir, "thresholds.json").isFile() == false)
createThresholdsFromBioimageio();
this.nChannels = ((Number) JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath()).get("n_channel_in")).intValue();
createPythonService();
super(modelName, baseDir);
}

private Stardist2D(ModelDescriptor descriptor) throws IOException, ModelSpecsException {
this.descriptor = descriptor;
this.name = new File(descriptor.getModelPath()).getName();
this.basedir = new File(descriptor.getModelPath()).getParentFile().getAbsolutePath();
modelDir = descriptor.getModelPath();
if (new File(modelDir, "config.json").isFile() == false)
createConfigFromBioimageio();
if (new File(modelDir, "thresholds.json").isFile() == false)
createThresholdsFromBioimageio();
this.nChannels = ((Number) JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath()).get("n_channel_in")).intValue();
createPythonService();
super(descriptor);
}

private void createConfigFromBioimageio() throws IOException, ModelSpecsException {
if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + Constants.RDF_FNAME);
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistConfig = (Map<String, Object>) stardistMap.get("config");
JSONUtils.writeJSONFile(new File(modelDir, "config.json").getAbsolutePath(), stardistConfig);
}

private void createThresholdsFromBioimageio() throws IOException, ModelSpecsException {
if (descriptor == null)
descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + Constants.RDF_FNAME);
Map<String, Object> stardistMap = (Map<String, Object>) descriptor.getConfig().getSpecMap().get("stardist");
Map<String, Object> stardistThres = (Map<String, Object>) stardistMap.get("thresholds");
JSONUtils.writeJSONFile(new File(modelDir, "thresholds.json").getAbsolutePath(), stardistThres);
}

private void createPythonService() throws IOException {
Environment env = new Environment() {
@Override public String base() { return new Mamba().getEnvsDir() + File.separator + "stardist"; }
};
python = env.python();
python.debug(System.err::println);
}

protected String createEncodeImageScript() {
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
long nElems = 1;
for (long elem : shma.getOriginalShape()) nElems *= elem;
code += "im = np.ndarray(" + nElems + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (int i = 0; i < shma.getOriginalShape().length; i ++)
code += shma.getOriginalShape()[i] + ", ";
code += "])" + System.lineSeparator();
return code;
}

public void close() {
if (!loaded)
return;
python.close();
}

public <T extends RealType<T> & NativeType<T>>
Map<String, RandomAccessibleInterval<T>> predict(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {

shma = SharedMemoryArray.createSHMAFromRAI(img, false, false);
String code = "";
if (!loaded) {
code += String.format(LOAD_MODEL_CODE, this.name, this.basedir) + System.lineSeparator();
}

code += createEncodeImageScript() + System.lineSeparator();
code += RUN_MODEL_CODE + System.lineSeparator();


Map<String, Object> inputs = new HashMap<String, Object>();
String shm_coords_id = SharedMemoryArray.createShmName();
String shm_points_id = SharedMemoryArray.createShmName();
inputs.put("shm_coords_id", shm_coords_id);
inputs.put("shm_points_id", shm_points_id);

Task task = python.task(code, inputs);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException("Task canceled");
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException(task.error);
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException(task.error);
loaded = true;


return reconstructOutputs(task, shm_coords_id, shm_points_id);
}

private <T extends RealType<T> & NativeType<T>>
Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Task task, String shm_coords_id, String shm_points_id)
throws IOException, InterruptedException {

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> maskCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(shma.getSharedRAI()),
Util.getTypeFromInterval(Cast.unchecked(shma.getSharedRAI())));
outs.put("mask", maskCopy);
outs.put("points", reconstructPoints(task, shm_points_id));
outs.put("coord", reconstructCoord(task, shm_coords_id));

shma.close();

if (PlatformDetection.isWindows()) {
Task closeSHMTask = python.task(CLOSE_SHM_CODE);
closeSHMTask.waitFor();
}
return outs;
@Override
protected String createImportsCode() {
return String.format(LOAD_MODEL_CODE_ABSTRACT, MODULE_NAME, MODULE_NAME,
MODULE_NAME, MODULE_NAME, MODULE_NAME, this.name, this.basedir);
}

private <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> reconstructCoord(Task task, String shm_coords_id) throws IOException {

String coords_dtype = (String) task.outputs.get("coords_dtype");
List<Number> coords_shape = (List<Number>) task.outputs.get("coords_shape");
if (coords_shape == null)
return null;

long[] coordsSh = new long[coords_shape.size()];
for (int i = 0; i < coordsSh.length; i ++)
coordsSh[i] = coords_shape.get(i).longValue();
SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_coords_id, coordsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false);

Map<String, RandomAccessibleInterval<T>> outs = new HashMap<String, RandomAccessibleInterval<T>>();
// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> coordsRAI = shmCoords.getSharedRAI();
RandomAccessibleInterval<T> coordsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(coordsRAI),
Util.getTypeFromInterval(Cast.unchecked(shmCoords)));
outs.put("coords", coordsCopy);

shmCoords.close();

return coordsCopy;
}

private <T extends RealType<T> & NativeType<T>>
RandomAccessibleInterval<T> reconstructPoints(Task task, String shm_points_id) throws IOException {

String points_dtype = (String) task.outputs.get("points_dtype");
List<Number> points_shape = (List<Number>) task.outputs.get("points_shape");
if (points_shape == null)
return null;


long[] pointsSh = new long[points_shape.size()];
for (int i = 0; i < pointsSh.length; i ++)
pointsSh[i] = points_shape.get(i).longValue();
SharedMemoryArray shmPoints = SharedMemoryArray.readOrCreate(shm_points_id, pointsSh,
Cast.unchecked(CommonUtils.getImgLib2DataType(points_dtype)), false, false);

// TODO I do not understand why is complaining when the types align perfectly
RandomAccessibleInterval<T> pointsRAI = shmPoints.getSharedRAI();
RandomAccessibleInterval<T> pointsCopy = Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(pointsRAI),
Util.getTypeFromInterval(Cast.unchecked(pointsRAI)));
shmPoints.close();
return pointsCopy;

@Override
protected <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> image) {
if (image.dimensionsAsLongArray().length == 2 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 3 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 2 && image.dimensionsAsLongArray()[2] != nChannels)
throw new IllegalArgumentException("This Stardist2D model requires " + nChannels + " channels.");
else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray().length < 2)
throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC.");
}

/**
Expand Down Expand Up @@ -404,58 +153,7 @@ public static Stardist2D fromPretained(String pretrainedModel, String installDir
}
}

private <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> image) {
if (image.dimensionsAsLongArray().length == 2 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 3 && this.nChannels != 1)
throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
else if (image.dimensionsAsLongArray().length != 2 && image.dimensionsAsLongArray()[2] != nChannels)
throw new IllegalArgumentException("This Stardist2D model requires " + nChannels + " channels.");
else if (image.dimensionsAsLongArray().length > 3 || image.dimensionsAsLongArray().length < 2)
throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC.");
}

/**
* Check whether everything that is needed for Stardist 2D is installed or not
*/
public void checkRequirementsInstalled() {
// TODO
}

/**
* Check whether the requirements needed to run Stardist 2D are satisfied or not.
* First checks if the corresponding Java DL engine is installed or not, then checks
* if the Python environment needed for Stardist2D post processing is fine too.
*
* If anything is not installed, this method also installs it
*
* @throws IOException if there is any error downloading the DL engine or installing the micromamba environment
* @throws InterruptedException if the installation is stopped
* @throws RuntimeException if there is any unexpected error in the micromamba environment installation
* @throws MambaInstallException if there is any error downloading or installing micromamba
* @throws ArchiveException if there is any error decompressing the micromamba installer
* @throws URISyntaxException if the URL to the micromamba installation is not correct
*/
public static void installRequirements() throws IOException, InterruptedException,
RuntimeException, MambaInstallException,
ArchiveException, URISyntaxException {
boolean installed = InstalledEngines.buildEnginesFinder()
.checkEngineWithArgsInstalledForOS("tensorflow", "1.15.0", null, null).size() != 0;
if (!installed)
EngineInstall.installEngineWithArgs("tensorflow", "1.15.0", true, true);

Mamba mamba = new Mamba();
boolean stardistPythonInstalled = false;
try {
stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS);
} catch (MambaInstallException e) {
mamba.installMicromamba();
}
if (!stardistPythonInstalled) {
// TODO add logging for environment installation
mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS);
};
}

/**
* Main method to check functionality
Expand Down Expand Up @@ -483,6 +181,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
RandomAccessibleInterval<FloatType> img = ArrayImgs.floats(new long[] {512, 512});

Map<String, RandomAccessibleInterval<FloatType>> res = model.predict(img);
model.close();
System.out.println(true);
}
}
Loading

0 comments on commit 66bdd17

Please sign in to comment.