diff --git a/src/main/java/net/imglib2/labkit/labeling/LabelingSerializer.java b/src/main/java/net/imglib2/labkit/labeling/LabelingSerializer.java index f3107b81..6093516e 100644 --- a/src/main/java/net/imglib2/labkit/labeling/LabelingSerializer.java +++ b/src/main/java/net/imglib2/labkit/labeling/LabelingSerializer.java @@ -153,7 +153,7 @@ private > void saveAsTiff(Labeling labeling, io.save(ds.create(imgPlus), filename); } - private static class LabelsMetaData { + public static class LabelsMetaData { List> labelSets; diff --git a/src/main/java/net/imglib2/labkit/models/DefaultSegmentationModel.java b/src/main/java/net/imglib2/labkit/models/DefaultSegmentationModel.java index 41bdfd22..bfb4c19c 100644 --- a/src/main/java/net/imglib2/labkit/models/DefaultSegmentationModel.java +++ b/src/main/java/net/imglib2/labkit/models/DefaultSegmentationModel.java @@ -127,10 +127,10 @@ public SegmentationItem addSegmenter() { @Override public void train(SegmentationItem item) { - ParallelUtils.runInOtherThread(() -> internTrain(item)); + ParallelUtils.runInOtherThread(() -> trainAndWait(item)); } - private void internTrain(SegmentationItem item) { + public void trainAndWait(SegmentationItem item) { SwingProgressWriter progressWriter = new SwingProgressWriter(null, "Training in Progress"); progressWriter.setVisible(true); diff --git a/src/main/java/net/imglib2/labkit/plugin/SimpleLabkitPlugin.java b/src/main/java/net/imglib2/labkit/plugin/SimpleLabkitPlugin.java new file mode 100644 index 00000000..57dd4404 --- /dev/null +++ b/src/main/java/net/imglib2/labkit/plugin/SimpleLabkitPlugin.java @@ -0,0 +1,83 @@ +package net.imglib2.labkit.plugin; + +import ij.ImagePlus; +import net.imglib2.img.Img; +import net.imglib2.img.display.imagej.ImageJFunctions; +import net.imglib2.labkit.BatchSegmenter; +import net.imglib2.labkit.inputimage.DefaultInputImage; +import net.imglib2.labkit.labeling.Labeling; +import net.imglib2.labkit.labeling.LabelingSerializer; +import net.imglib2.labkit.models.DefaultSegmentationModel; +import net.imglib2.labkit.models.ImageLabelingModel; +import net.imglib2.labkit.utils.progress.StatusServiceProgressWriter; +import net.imglib2.roi.labeling.ImgLabeling; +import net.imglib2.type.numeric.ARGBType; +import net.imglib2.type.numeric.IntegerType; +import net.imglib2.util.Intervals; +import org.scijava.Context; +import org.scijava.ItemIO; +import org.scijava.app.StatusService; +import org.scijava.command.Command; +import org.scijava.plugin.Parameter; +import org.scijava.plugin.Plugin; + +import static net.imglib2.labkit.labeling.LabelingSerializer.fromImageAndLabelSets; + +@Plugin( type = Command.class, menuPath = "Plugins>Segmentation>Label (simple mode)" ) +public class SimpleLabkitPlugin implements Command { + + @Parameter + Img input; + + @Parameter + Img> labeling; + + @Parameter(type = ItemIO.OUTPUT) + Img output; + + @Parameter + Context context; + + @Parameter + StatusService statusService; + + @Override + public void run() { + + // init segmentation model, serializer, labeling model + DefaultSegmentationModel segmentationModel = new DefaultSegmentationModel( new DefaultInputImage( + input ), context ); + final ImageLabelingModel labelingModel = segmentationModel + .imageLabelingModel(); + + // load labeling from labeling img + LabelingSerializer.LabelsMetaData meta = new LabelingSerializer.LabelsMetaData(labeling); + ImgLabeling imgLabeling = fromImageAndLabelSets(labeling, meta.asLabelSets()); + labelingModel.labeling().set( Labeling.fromImgLabeling(imgLabeling) ); + if ( labelingModel.labeling().get().getLabels().size() == 0 ) + { + System.out.println( "no labels" ); + return; + } + + // train + segmentationModel.trainAndWait( segmentationModel + .selectedSegmenter().get() ); + + // run segmentation + final ImagePlus segImgImagePlus = ImageJFunctions.wrap( input, "seginput" ); + final Img segImg = ImageJFunctions.wrap(segImgImagePlus); + try + { + output = BatchSegmenter.segment( segImg, + segmentationModel.selectedSegmenter().get(), + Intervals.dimensionsAsIntArray( segImg ), + new StatusServiceProgressWriter( statusService ) ); + } + catch ( InterruptedException e ) + { + e.printStackTrace(); + } + + } +} diff --git a/src/test/java/net/imglib2/labkit/plugin/SimpleLabkitPluginDemo.java b/src/test/java/net/imglib2/labkit/plugin/SimpleLabkitPluginDemo.java new file mode 100644 index 00000000..0aed3376 --- /dev/null +++ b/src/test/java/net/imglib2/labkit/plugin/SimpleLabkitPluginDemo.java @@ -0,0 +1,74 @@ +package net.imglib2.labkit.plugin; + +import net.imagej.ImageJ; +import net.imglib2.RandomAccess; +import net.imglib2.img.Img; +import net.imglib2.type.numeric.integer.ByteType; +import net.imglib2.type.numeric.integer.UnsignedByteType; +import net.imglib2.type.numeric.real.DoubleType; +import org.junit.Test; +import org.scijava.command.CommandModule; + +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertEquals; + +public class SimpleLabkitPluginDemo { + + @Test + public void run() throws ExecutionException, InterruptedException { + ImageJ ij = new ImageJ(); + + long[] dims = new long[]{10,10}; + + //create input image, set values to zero, left top quarter to one + Img input = ij.op().create().img(dims); + RandomAccess inputRA = input.randomAccess(); + for (int i = 0; i < dims[0]; i++) { + for (int j = 0; j < dims[1]; j++) { + inputRA.setPosition(new int[]{i,j}); + if(i < 5 && j < 5) { + inputRA.get().setOne(); + } else { + inputRA.get().setZero(); + } + } + } + + //create labeling with first pixel labeled one, rest zero + Img labeling = ij.op().convert().int8(ij.op().create().img(dims)); + labeling.forEach(pixel -> pixel.setZero()); + labeling.firstElement().setReal(2.0); + RandomAccess labelingRA = labeling.randomAccess(); + labelingRA.setPosition(new int[]{(int) (dims[0]-1), (int) (dims[1]-1)}); + labelingRA.get().setOne(); + + ij.ui().show(input); + ij.ui().show(labeling); + + // run labkit command + CommandModule result = ij.command().run(SimpleLabkitPlugin.class, true, "input", input, "labeling", labeling).get(); + Img output = (Img) result.getOutput("output"); + + ij.ui().show("output", output); + // result control + assertEquals(dims[0], output.dimension(0)); + assertEquals(dims[1], output.dimension(1)); + RandomAccess outputRA = output.randomAccess(); + for (int i = 0; i < dims[0]; i++) { + for (int j = 0; j < dims[1]; j++) { + outputRA.setPosition(new int[]{i,j}); + if(i < 5 && j < 5) { + assertEquals(1, outputRA.get().get()); + }else { + assertEquals(0, outputRA.get().get()); + } + } + } + } + + public static void main(String... args) throws ExecutionException, InterruptedException { + new SimpleLabkitPluginDemo().run(); + } + +}