Skip to content

Commit

Permalink
correct more errors related to pre and post processing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 20, 2024
1 parent 49ff319 commit 2c04835
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ private void buildPreprocessing() throws IllegalArgumentException, RuntimeExcept
for (TransformSpec transformation : preprocessing) {
list.add(TransformationInstance.create(transformation));
}
preMap.put(tt.getName(), list);
}
}

Expand All @@ -86,6 +87,7 @@ private void buildPostprocessing() throws IllegalArgumentException, RuntimeExcep
for (TransformSpec transformation : preprocessing) {
list.add(TransformationInstance.create(transformation));
}
postMap.put(tt.getName(), list);
}
}

Expand Down Expand Up @@ -139,6 +141,8 @@ List<Tensor<R>> preprocess(List<Tensor<T>> tensorList, boolean inplace) {
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
if (tt == null)
continue;
if (ee.getValue().size() == 0)
outputs.add(Cast.unchecked(tt));
for (TransformationInstance trans : ee.getValue()) {
List<Tensor<R>> outList = trans.run(tt, inplace);
outputs.addAll(outList);
Expand Down Expand Up @@ -184,6 +188,8 @@ List<Tensor<R>> postprocess(List<Tensor<T>> tensorList, boolean inplace) {
Tensor<T> tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null);
if (tt == null)
continue;
if (ee.getValue().size() == 0)
outputs.add(Cast.unchecked(tt));
for (TransformationInstance trans : ee.getValue()) {
List<Tensor<R>> outList = trans.run(tt, inplace);
outputs.addAll(outList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand All @@ -34,6 +36,7 @@
import io.bioimage.modelrunner.transformations.BinarizeTransformation;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;

/**
* Class that creates an instance able to run the corresponding Bioimage.io processing routine
Expand Down Expand Up @@ -116,13 +119,43 @@ List<Tensor<R>> run(Tensor<T> tensor) throws RuntimeException {
List<Tensor<R>> run(Tensor<T> tensor, boolean inplace) throws RuntimeException {
Method m;
try {
if (inplace)
m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class);
else
m = cls.getMethod(RUN_NAME, Tensor.class);

m.invoke(this.instance, tensor);
return null;
if (inplace) {
m = cls.getMethod(RUN_INPLACE_NAME, Tensor.class);
m.invoke(this.instance, tensor);
return Collections.singletonList(Cast.unchecked(tensor));
} else {
m = cls.getMethod(RUN_NAME, Tensor.class);
Object result = m.invoke(this.instance, tensor);

// Handle different possible return types
if (result == null) {
return null;
} else if (result instanceof List<?>) {
// Cast and verify each element is a Tensor<R>
List<?> resultList = (List<?>) result;
List<Tensor<R>> outputList = new ArrayList<>();

for (Object item : resultList) {
if (item instanceof Tensor<?>) {
@SuppressWarnings("unchecked")
Tensor<R> tensorItem = (Tensor<R>) item;
outputList.add(tensorItem);
} else {
throw new RuntimeException("Invalid return type: Expected Tensor but got " +
(item != null ? item.getClass().getName() : "null"));
}
}
return outputList;
} else if (result instanceof Tensor<?>) {
// Single Tensor result
@SuppressWarnings("unchecked")
Tensor<R> tensorResult = (Tensor<R>) result;
return Collections.singletonList(tensorResult);
} else {
throw new RuntimeException("Unexpected return type: " +
(result != null ? result.getClass().getName() : "null"));
}
}
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException
| NoSuchMethodException | SecurityException e) {
throw new RuntimeException(Types.stackTrace(e));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ protected < R extends RealType< R > & NativeType< R > > Tensor< FloatType > make
final ImgFactory< FloatType > factory = Util.getArrayOrCellImgFactory( input.getData(), new FloatType() );
final Img< FloatType > outputImg = factory.create( input.getData() );
RealTypeConverters.copyFromTo(input.getData(), outputImg);
final Tensor< FloatType > output = Tensor.build( getName() + '_' + input.getName(), input.getAxesOrderString(), outputImg );
// TODO what name final Tensor< FloatType > output = Tensor.build( getName() + '_' + input.getName(), input.getAxesOrderString(), outputImg );
final Tensor< FloatType > output = Tensor.build( input.getName(), input.getAxesOrderString(), outputImg );
return output;
}
}

0 comments on commit 2c04835

Please sign in to comment.