From 81af7d320083fe159093678cf209770810d78431 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 25 Sep 2024 18:58:06 +0200 Subject: [PATCH] keep improving the native support for transformations --- .../model/processing/Processing.java | 236 ++---------------- .../processing/TransformationInstance.java | 150 +++++++++++ 2 files changed, 177 insertions(+), 209 deletions(-) create mode 100644 src/main/java/io/bioimage/modelrunner/model/processing/TransformationInstance.java diff --git a/src/main/java/io/bioimage/modelrunner/model/processing/Processing.java b/src/main/java/io/bioimage/modelrunner/model/processing/Processing.java index 077494de..5cff57b8 100644 --- a/src/main/java/io/bioimage/modelrunner/model/processing/Processing.java +++ b/src/main/java/io/bioimage/modelrunner/model/processing/Processing.java @@ -9,6 +9,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor; import io.bioimage.modelrunner.bioimageio.description.TensorSpec; @@ -38,21 +39,14 @@ public class Processing { * It has to contain the tensor of interest. */ private LinkedHashMap inputsMap; - /** - * List containing the names of the processings that need to be applied - * to the tensor image - */ - private List processing; + private Map> preMap; + private Map> postMap; // TODO when adding python //private static BioImageIoPython interp; private static String BIOIMAGEIO_PYTHON_TRANSFORMATIONS_WEB = "https://github.com/bioimage-io/core-bioimage-io-python/blob/b0cea" + "c8fa5b412b1ea811c442697de2150fa1b90/bioimageio/core/prediction_pipeline" + "/_processing.py#L105"; - /** - * Package where the BioImage.io transformations are. - */ - private final static String TRANSFORMATIONS_PACKAGE = BinarizeTransformation.class.getPackage().getName(); /** * The object that is going to execute processing on the given image @@ -63,225 +57,49 @@ public class Processing { */ private Processing(ModelDescriptor descriptor) { this.descriptor = descriptor; + buildPreprocessing(); + buildPostprocessing(); } private void buildPreprocessing() throws ClassNotFoundException { - Map>> preMap = new HashMap>>(); + preMap = new HashMap>(); for (TensorSpec tt : this.descriptor.getInputTensors()) { List preprocessing = tt.getPreprocessing(); - List> list = new ArrayList>(); + List list = new ArrayList(); for (TransformSpec transformation : preprocessing) { - Map map = new HashMap(); - String clsName = findMethodInBioImageIo(transformation.getName()); + list.add(TransformationInstance.create(transformation)); } } } - public static Processing init(ModelDescriptor descriptor) { - return new Processing(descriptor); - } - - public & NativeType, R extends RealType & NativeType> - List> process(List> tensorList){ - return process(tensorList, false); - } - - public & NativeType, R extends RealType & NativeType> - List> process(List> tensorList, boolean inplace) { - - return null; - } - - /** - * Execute processing defined with Java. - * @param transformation - * the name of the class that is going to be executed - * @param kwargs - * the args of the transformation to be executed - * @throws ClassNotFoundException if the Java processing class is not found in the loaded classes - * @throws IOException if there is any issue initializing the Python interpreter - * @throws InterruptedException - * @throws IllegalArgumentException - */ - public < T extends RealType< T > & NativeType< T > > void executeJavaProcessing(String transformation, Map kwargs) throws ClassNotFoundException, IOException, IllegalArgumentException, InterruptedException { - try { - JavaProcessing preproc = JavaProcessing.definePreprocessing(transformation, kwargs); - inputsMap = preproc.execute(tensorSpec, inputsMap); - return; - } catch (ClassNotFoundException ex) { - // TODO when adding python - //System.out.println("Executing processing transformation '" + transformation + "' with Python."); - throw new IOException("Error running processing transformation: " + transformation); - } - // If class not found, execute in Python - // TODO when adding python - //executePythonProcessing(transformation, kwargs); - } - - /** - * Method used to convert Strings in using snake case (snake_case) into camel - * case with the first letter as upper case (CamelCase) - * @param str - * the String to be converted - * @return String converted into camel case with first upper - */ - public static String snakeCaseToCamelCaseFirstCap(String str) { - while(str.contains("_")) { - str = str.replaceFirst("_[a-z]", String.valueOf(Character.toUpperCase(str.charAt(str.indexOf("_") + 1)))); - } - str = str.substring(0, 1).toUpperCase() + str.substring(1); - return str; - } - - /** - * Tries to find a given class in the classpath - * @throws ClassNotFoundException if the class does not exist in the classpath - */ - private void findClassInClassPath(String clsName) throws ClassNotFoundException { - Class.forName(clsName, false, JavaProcessing.class.getClassLoader()); - } - - /** - * Find of the transformation exists in the BioImage.io Java Core - * @throws ClassNotFoundException if the BioImage.io transformation does not exist - */ - private String findMethodInBioImageIo(String methodName) throws ClassNotFoundException { - String javaMethodName = snakeCaseToCamelCaseFirstCap(methodName) + "Transformation"; - String clsName = TRANSFORMATIONS_PACKAGE + "." + javaMethodName; - findClassInClassPath(clsName); - return clsName; - } - private LinkedHashMap runJavaTransformationWithArgs(String clsName, Map args) throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, ClassNotFoundException { - Class transformationClass = getClass().getClassLoader().loadClass(clsName); - Object transformationObject = transformationClass.getConstructor().newInstance(); - - for (String arg : args.keySet()) { - setArg(transformationObject, arg); - } - Method[] publicMethods = transformationClass.getMethods(); - Method transformationMethod = null; - for (Method mm : publicMethods) { - if (mm.getName().equals(this.javaMethodName)) { - transformationMethod = mm; - break; + private void buildPostprocessing() throws ClassNotFoundException { + postMap = new HashMap>(); + for (TensorSpec tt : this.descriptor.getInputTensors()) { + List preprocessing = tt.getPreprocessing(); + List list = new ArrayList(); + for (TransformSpec transformation : preprocessing) { + list.add(TransformationInstance.create(transformation)); } } - if (transformationMethod == null) - throw new IllegalArgumentException("The pre-processing transformation class does not contain" - + "the method '" + this.javaMethodName + "' needed to call the transformation."); - // Check that the arguments specified in the rdf.yaml are of the corect type - return null; } - /** - * Set the argument in the processing trasnformation instance - * @param instance - * instance of the processing trasnformation - * @param argName - * name of the argument - * @throws IllegalArgumentException if no method is found for the given argument - * @throws InvocationTargetExceptionif there is any error invoking the method - * @throws IllegalAccessException if it is illegal to access the method - */ - public void setArg(Object instance, String argName) throws IllegalArgumentException, IllegalAccessException, InvocationTargetException { - String mName = getArgumentSetterName(argName); - Method mm = checkArgType(argName, mName); - mm.invoke(instance, args.get(argName)); + public static Processing init(ModelDescriptor descriptor) { + return new Processing(descriptor); } - /** - * Get the setter that the Java transformation class uses to set the argument of the - * pre-processing. The setter has to be named as the argument but in CamelCase with the - * first letter in upper case and preceded by set. For example: min_distance -> setMinDistance - * @param argName - * the name of the argument - * @return the method name - * @throws IllegalArgumentException if no method is found for the given argument - */ - public String getArgumentSetterName(String argName) throws IllegalArgumentException { - String mName = "set" + snakeCaseToCamelCaseFirstCap(argName); - // Check that the method exists - Method[] methods = transformationClass.getMethods(); - for (Method mm : methods) { - if (mm.getName().equals(mName)) - return mName; - } - throw new IllegalArgumentException("Setter for argument '" + argName + "' of the processing " - + "transformation '" + rdfSpec + "' of tensor '" + tensorName - + "' not found in the Java transformation class '" + this.javaClassName + "'. " - + "A method called '" + mName + "' should be present."); + public & NativeType, R extends RealType & NativeType> + List> preprocess(List> tensorList){ + return preprocess(tensorList, false); } - /** - * Method that checks that the type of the arguments provided in the rdf.yaml is correct. - * It also returns the setter method to set the argument - * - * @param mm - * the method that executes the pre-processing transformation - * @return the method used to provide the argument to the instance - * @throws IllegalArgumentException if any of the arguments' type is not correct - */ - private Method checkArgType(String argName, String mName) throws IllegalArgumentException { - Object arg = this.args.get(argName); - Method[] methods = this.transformationClass.getMethods(); - List possibleMethods = new ArrayList(); - for (Method mm : methods) { - if (mm.getName().equals(mName)) - possibleMethods.add(mm); - } - if (possibleMethods.size() == 0) - getArgumentSetterName(argName); - for (Method mm : possibleMethods) { - Parameter[] pps = mm.getParameters(); - if (pps.length != 1) { + public & NativeType, R extends RealType & NativeType> + List> preprocess(List> tensorList, boolean inplace) { + for (Entry> ee : this.preMap.entrySet()) { + Tensor tt = tensorList.stream().filter(t -> t.getName().equals(ee.getKey())).findFirst().orElse(null); + if (tt == null) continue; - } - if (pps[0].getType() == Object.class) - return mm; + ee.getValue().forEach(trans -> trans.run(tt)); } - throw new IllegalArgumentException("Setter '" + mName + "' should have only one input parameter with type Object.class."); - } -} - - // TODO when adding python - /** - public < T extends RealType< T > & NativeType< T > > void executePythonProcessing(String transformation, Map kwargs) throws IOException, IllegalArgumentException, InterruptedException { - Objects.requireNonNull(transformation, "The Python transformation needs to be a 'bioimageio.core' transformatio at " + BIOIMAGEIO_PYTHON_TRANSFORMATIONS_WEB); - Objects.requireNonNull(kwargs); - PythonUtils pUtils = getPythonConfiguration(transformation); - try (BioImageIoPython python = BioImageIoPython.activate(JepUtils.createNewPythonInstance(pUtils))){ - Tensor javaTensor; - if (inputsMap.get(tensorSpec.getName()) instanceof Tensor) { - javaTensor = (Tensor) inputsMap.get(tensorSpec.getName()); - } else if (inputsMap.get(tensorSpec.getName()) instanceof Sequence) { - Sequence seq = (Sequence) inputsMap.get(tensorSpec.getName()); - javaTensor = - Tensor.build(tensorSpec.getName(), tensorSpec.getAxesOrder(), - (RandomAccessibleInterval) SequenceToImgLib2.build(seq, tensorSpec.getAxesOrder())); - } else { - throw new IllegalArgumentException("Every BioImage.io core transformation requires a Tensor, or at least a " - + "Sequence as input."); - } - HashMap pythonKwargs = new HashMap(); - pythonKwargs.put("tensor_name", tensorSpec.getName()); - pythonKwargs.putAll(kwargs); - Map trans = new HashMap(); - trans.put(TransformSpec.getTransformationNameKey(), transformation); - trans.put(TransformSpec.getKwargsKey(), pythonKwargs); - inputsMap.put(tensorSpec.getName(), - python.applyTransformationToTensorInPython(trans, javaTensor)); - } - } - - private static PythonUtils getPythonConfiguration(String transformation) throws IOException { - PythonUtils pythonUtils = JepUtils.getPythonJepConfiguration(); - if (pythonUtils == null) { - JepUtils.openPythonConfigurationIfPythonNotInstalled(); - throw new IOException("Transformation '" + transformation.toUpperCase() + "' seems to be only " - + "avaialble in Python. And Python is not configured in your Icy installation. In order " - + "to use it please configure Python using the Jep Plugin."); - } - return pythonUtils; + return null; } - */ } diff --git a/src/main/java/io/bioimage/modelrunner/model/processing/TransformationInstance.java b/src/main/java/io/bioimage/modelrunner/model/processing/TransformationInstance.java new file mode 100644 index 00000000..a3fe84de --- /dev/null +++ b/src/main/java/io/bioimage/modelrunner/model/processing/TransformationInstance.java @@ -0,0 +1,150 @@ +package io.bioimage.modelrunner.model.processing; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.List; +import java.util.Map; + +import io.bioimage.modelrunner.bioimageio.description.TransformSpec; +import io.bioimage.modelrunner.tensor.Tensor; +import io.bioimage.modelrunner.transformations.BinarizeTransformation; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; + +public class TransformationInstance { + private final String name; + private final Map args; + private Class cls; + private Object instance; + /** + * Package where the BioImage.io transformations are. + */ + private final static String TRANSFORMATIONS_PACKAGE = BinarizeTransformation.class.getPackage().getName(); + + private final static String RUN_NAME = "apply"; + + protected TransformationInstance(TransformSpec transform) { + this.name = transform.getName(); + this.args = transform.getKwargs(); + this.build(); + } + + public static TransformationInstance create(TransformSpec transform) { + return new TransformationInstance(transform); + } + + public & NativeType, R extends RealType & NativeType> + List> run(Tensor tensor){ + return run(tensor, false); + } + + public & NativeType, R extends RealType & NativeType> + List> run(Tensor tensor, boolean inplace) { + Method m = cls.getMethod(RUN_NAME, List.class); + m.invoke(this.instance, tensor); + return null; + } + + private void build() { + getTransformationClass(); + createInstanceWithArgs(); + } + + /** + * Find of the transformation exists in the BioImage.io Java Core + * @throws ClassNotFoundException if the BioImage.io transformation does not exist + */ + private void getTransformationClass() throws ClassNotFoundException { + String javaMethodName = snakeCaseToCamelCaseFirstCap(this.name) + "Transformation"; + String clsName = TRANSFORMATIONS_PACKAGE + "." + javaMethodName; + findClassInClassPath(clsName); + this.cls = getClass().getClassLoader().loadClass(clsName); + } + + /** + * Tries to find a given class in the classpath + * @throws ClassNotFoundException if the class does not exist in the classpath + */ + private void findClassInClassPath(String clsName) throws ClassNotFoundException { + Class.forName(clsName, false, JavaProcessing.class.getClassLoader()); + } + + /** + * Method used to convert Strings in using snake case (snake_case) into camel + * case with the first letter as upper case (CamelCase) + * @param str + * the String to be converted + * @return String converted into camel case with first upper + */ + public static String snakeCaseToCamelCaseFirstCap(String str) { + while(str.contains("_")) { + str = str.replaceFirst("_[a-z]", String.valueOf(Character.toUpperCase(str.charAt(str.indexOf("_") + 1)))); + } + str = str.substring(0, 1).toUpperCase() + str.substring(1); + return str; + } + + private void createInstanceWithArgs() throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, ClassNotFoundException { + this.instance = this.cls.getConstructor().newInstance(); + + for (String kk : args.keySet()) { + setArg(kk); + } + } + + /** + * Set the argument in the processing trasnformation instance + * @param instance + * instance of the processing trasnformation + * @param argName + * name of the argument + * @throws IllegalArgumentException if no method is found for the given argument + * @throws InvocationTargetExceptionif there is any error invoking the method + * @throws IllegalAccessException if it is illegal to access the method + */ + public void setArg(String argName) throws IllegalArgumentException, IllegalAccessException, InvocationTargetException { + Method mm = getMethodForArgument(argName); + checkArgType(mm); + mm.invoke(instance, this.args.get(argName)); + } + + /** + * Get the setter that the Java transformation class uses to set the argument of the + * pre-processing. The setter has to be named as the argument but in CamelCase with the + * first letter in upper case and preceded by set. For example: min_distance -> setMinDistance + * @param argName + * the name of the argument + * @return the method name + * @throws IllegalArgumentException if no method is found for the given argument + */ + public Method getMethodForArgument(String argName) throws IllegalArgumentException { + String mName = "set" + snakeCaseToCamelCaseFirstCap(argName); + // Check that the method exists + Method[] methods = this.cls.getMethods(); + for (Method mm : methods) { + if (mm.getName().equals(mName)) + return mm; + } + throw new IllegalArgumentException("Setter for argument '" + argName + "' of the processing " + + "transformation '" + name + "' not found in the Java transformation class '" + this.cls + "'. " + + "A method called '" + mName + "' should be present."); + } + + /** + * Method that checks that the type of the arguments provided in the rdf.yaml is correct. + * It also returns the setter method to set the argument + * + * @param mm + * the method that executes the pre-processing transformation + * @return the method used to provide the argument to the instance + * @throws IllegalArgumentException if any of the arguments' type is not correct + */ + private void checkArgType(Method mm) throws IllegalArgumentException { + Parameter[] pps = mm.getParameters(); + if (pps.length == 1 && pps[0].getType() == Object.class) + return; + throw new IllegalArgumentException("Setter '" + mm.getName() + "' should have only one input parameter with type Object.class."); + } + +}