diff --git a/README.md b/README.md new file mode 100644 index 0000000..7b52553 --- /dev/null +++ b/README.md @@ -0,0 +1,88 @@ +This is an example Maven project implementing an ImageJ 1.x plugin. + +It is intended as an ideal starting point to develop new ImageJ 1.x plugins +in an IDE of your choice. You can even collaborate with developers using a +different IDE than you. + +* In [Eclipse](http://eclipse.org), for example, it is as simple as + _File>Import...>Existing Maven Project_. + +* In [NetBeans](http://netbeans.org), it is even simpler: + _File>Open Project_. + +* The same works in [IntelliJ](http://jetbrains.net). + +* If [jEdit](http://jedit.org) is your preferred IDE, you will need the + [Maven Plugin](http://plugins.jedit.org/plugins/?MavenPlugin). + +Die-hard command-line developers can use Maven directly by calling `mvn` +in the project root. + +However you build the project, in the end you will have the `.jar` file +(called *artifact* in Maven speak) in the `target/` subdirectory. + +To copy the artifact into the correct place, you can call +`mvn -Dimagej.app.directory=/path/to/ImageJ.app/`. +This will not only copy your artifact, but also all the dependencies. Restart +your ImageJ or call *Help>Refresh Menus* to see your plugin in the menus. + +Developing plugins in an IDE is convenient, especially for debugging. To +that end, the plugin contains a `main` method which sets the `plugins.dir` +system property (so that the plugin is added to the Plugins menu), starts +ImageJ, loads an image and runs the plugin. See also +[this page](https://imagej.net/Debugging#Debugging_plugins_in_an_IDE_.28Netbeans.2C_IntelliJ.2C_Eclipse.2C_etc.29) +for information how ImageJ makes it easier to debug in IDEs. + +Since this project is intended as a starting point for your own +developments, it is in the public domain. + +How to use this project as a starting point +=========================================== + +Either + +* `git clone git://github.com/imagej/example-legacy-plugin`, or +* unpack https://github.com/imagej/example-legacy-plugin/archive/master.zip + +Then: + +1. Edit the `pom.xml` file. Every entry should be pretty self-explanatory. + In particular, change + 1. the *artifactId* (**NOTE**: should contain a '_' character) + 2. the *groupId*, ideally to a reverse domain name your organization owns + 3. the *version* (note that you typically want to use a version number + ending in *-SNAPSHOT* to mark it as a work in progress rather than a + final version) + 4. the *dependencies* (read how to specify the correct + *groupId/artifactId/version* triplet + [here](https://imagej.net/Maven#How_to_find_a_dependency.27s_groupId.2FartifactId.2Fversion_.28GAV.29.3F)) + 5. the *developer* information + 6. the *scm* information +2. Remove the `Process_Pixels.java` file and add your own `.java` files + to `src/main/java//` (if you need supporting files -- like icons + -- in the resulting `.jar` file, put them into `src/main/resources/`) +3. Edit `src/main/resources/plugins.config` +4. Replace the contents of `README.md` with information about your project. + +If you cloned the `example-legacy-plugin` repository, you probably want to +publish the result in your own repository: + +1. Call `git status` to verify .gitignore lists all the files (or file + patterns) that should be ignored +2. Call `git add .` and `git add -u` to stage the current files for + commit +3. Call `git commit` or `git gui` to commit the changes +4. [Create a new GitHub repository](https://github.com/new) +5. `git remote set-url origin git@github.com:/` +6. `git push origin HEAD` + +### Eclipse: To ensure that Maven copies the plugin to your ImageJ folder + +1. Go to _Run Configurations..._ +2. Choose _Maven Build_ +3. Add the following parameter: + - name: `imagej.app.directory` + - value: `/path/to/ImageJ.app/` + +This ensures that the final `.jar` file will also be copied to your ImageJ +plugins folder everytime you run the Maven Build diff --git a/compile.sh b/compile.sh new file mode 100644 index 0000000..276f251 --- /dev/null +++ b/compile.sh @@ -0,0 +1 @@ +mvn -Dimagej.app.directory=/Applications/Fiji.app/ diff --git a/config.json b/config.json new file mode 100644 index 0000000..8a28308 --- /dev/null +++ b/config.json @@ -0,0 +1,54 @@ +{ + "model_name": "ANNA-PALM_npc_tubulin_946b", + "label":"ANNA-PALM(v1)", + "url":"https://s3.eu-west-2.amazonaws.com/anna-palm-model/anet_npc_tubulin_946b_tensorflow_model.pb", + "inputs": [{ + "name": "input", + "key": "input", + "type": "image", + "channels": ["SR", "LR"], + "size": 512, + "shape": [1, 512, 512, 2], + "default": 0.0, + "required": true + }, + { + "name": "mode", + "key": "control", + "type": "choice", + "options": { + "tubulin": 0.0, + "nuclear_pore": 1.0, + "actin": 2.0 + }, + "shape": [1, 1, 1, 1], + "default": 0.0, + "required": true + }, + { + "name": "channel mask", + "key": "channel_mask", + "type": "check_list", + "length": 2, + "shape": [1, 1, 1, 2], + "default": 1.0, + "required": false + }, + { + "name": "dropout probability", + "key": "dropout_prob", + "type": "float", + "shape": [], + "default": 0.0, + "required": false + } + ], + "outputs": [{ + "name": "output", + "key": "output", + "type": "image", + "channels": ["SR"], + "size": 512, + "shape": [1, 512, 512, 1] + }] +} diff --git a/make_package.sh b/make_package.sh new file mode 100644 index 0000000..1a9db02 --- /dev/null +++ b/make_package.sh @@ -0,0 +1,2 @@ +mvn clean compile assembly:single +cp ./target/ANNA_PALM_Process-0.2.0-SNAPSHOT-jar-with-dependencies.jar /Applications/Fiji.app/plugins/ diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..46f95d7 --- /dev/null +++ b/pom.xml @@ -0,0 +1,120 @@ + + + 4.0.0 + + + + maven-assembly-plugin + + + + org.imod.anet.AnetPlugin + + + + jar-with-dependencies + + + + + + + org.scijava + pom-scijava + 14.0.0 + + + + org.imod + ANNA_PALM_Process + 0.2.0-SNAPSHOT + + ANNA-PALM Process + ANNA-PALM imagej client. + https://github.com/impdpasteur/ANNA-PALM + 2012 + + Imod Pasteur + http://www.pasteur.fr/ + + + + Pasteur License + repo + + + + + Wei OUYANG + http://imagej.net/User:oeway + + lead + developer + debugger + reviewer + support + maintainer + + + + + + Wei OUYANG + http://imagej.net/User:oeway + + + + + + ImageJ Forum + http://forum.imagej.net/ + + + + scm:git:git://github.com/impdpasteur/ANNA-PALM + scm:git:git@github.com:impdpasteur/ANNA-PALM + HEAD + https://github.com/impdpasteur/ANNA-PALM + + + GitHub Issues + https://github.com/impdpasteur/ANNA-PALM + + + None + + + + org.imod.anet + org.imod.anet.ANNA_PALM_Process + Pasteur license + Institut Pasteur + + + + + net.imagej + ij + + + org.tensorflow + tensorflow + 1.4.0 + + + + com.googlecode.json-simple + json-simple + 1.1.1 + + + + io.crossbar.autobahn + autobahn-java + 17.10.5 + + + diff --git a/src/main/java/org/imod/anet/AnetPlugin.java b/src/main/java/org/imod/anet/AnetPlugin.java new file mode 100644 index 0000000..deb1ff8 --- /dev/null +++ b/src/main/java/org/imod/anet/AnetPlugin.java @@ -0,0 +1,695 @@ +/* + * To the extent possible under law, the ImageJ developers have waived + * all copyright and related or neighboring rights to this tutorial code. + * + * See the CC0 1.0 Universal license for details: + * http://creativecommons.org/publicdomain/zero/1.0/ + */ + +package org.imod.anet; + +import ij.IJ; +import ij.Prefs; +import ij.ImageJ; +import ij.ImagePlus; +import ij.gui.GenericDialog; +import ij.plugin.PlugIn; +import ij.process.ImageProcessor; +import ij.process.FloatProcessor; +import ij.WindowManager; +import org.imod.anet.AnetPredict; + +import java.awt.Button; +import java.awt.event.ActionListener; +import java.awt.event.ActionEvent; + +import java.io.File; +import java.io.FileWriter; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.URL; + +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Paths; + +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.Map; +import java.util.HashMap; + +import io.crossbar.autobahn.wamp.Client; +import io.crossbar.autobahn.wamp.Session; +import io.crossbar.autobahn.wamp.types.CallResult; +import io.crossbar.autobahn.wamp.types.CloseDetails; +import io.crossbar.autobahn.wamp.types.ExitInfo; +import io.crossbar.autobahn.wamp.types.InvocationDetails; +import io.crossbar.autobahn.wamp.types.Publication; +import io.crossbar.autobahn.wamp.types.PublishOptions; +import io.crossbar.autobahn.wamp.types.Registration; +import io.crossbar.autobahn.wamp.types.SessionDetails; +import io.crossbar.autobahn.wamp.types.Subscription; + +import org.json.simple.JSONArray; +import org.json.simple.JSONObject; +import org.json.simple.parser.JSONParser; +import java.util.Iterator; +import java.io.FileReader; +/** + * A template for processing each pixel of either + * GRAY8, GRAY16, GRAY32 or COLOR_RGB images. + * + * @author Johannes Schindelin + */ +public class AnetPlugin implements PlugIn { + protected ImagePlus image; + + // image property members + private int width; + private int height; + + // plugin parameters + public int input_size; + public String mode; + private AnetPredict ap; + private Session session; + public String model_path; + public String model_name; + Map inputs; + Map outputs; + + String[] image_window_titles; + Map image_window_map; + + @Override + public void run(String arg) { + // get width and height + if (arg.equals("about")) { + showAbout(); + return; + } + ap = new AnetPredict(); + if(arg.equals("setup")){ + showMainDialog(false); + } + else if(arg.equals("download")){ + session = connect("wss://dai.pasteur.fr/ws", "realm1"); + } + else if(arg == null || arg.equals("run")){ + String _model = Prefs.get("ANET.default_model", "none"); + if(_model.equals("none")){ + showMainDialog(true); + } + else{ + showModelDialog(_model); + } + } + else{ + showModelDialog(arg); + } + } + + private boolean showMainDialog(boolean start) { + GenericDialog gd = new GenericDialog("A-Net Process"); + // Create download button + Button btDownload = new Button("download model"); + btDownload.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) + { + session = connect("wss://dai.pasteur.fr/ws", "realm1"); + } + }); + // Add and show button + gd.add(btDownload); + + List modelList = new ArrayList(); + File folder = new File("models"); + File[] listOfFiles = folder.listFiles(); + if(listOfFiles == null || listOfFiles.length == 0){ + modelList.add("no model available."); + } + else{ + for (int i = 0; i < listOfFiles.length; i++) { + if (listOfFiles[i].isFile()) { + + } else if (listOfFiles[i].isDirectory()) { + modelList.add(listOfFiles[i].getName()); + } + } + } + + String[] items = (String[]) modelList.toArray(new String[modelList.size()]); + gd.addRadioButtonGroup("models", items, items.length, 1, items[0]); + gd.showDialog(); + if (gd.wasCanceled()) + return false; + // get entered values + model_name = gd.getNextRadioButton(); + if(!model_name.equals("no model available.")){ + Prefs.set("ANET.default_model", model_name); + IJ.log("selected model:" + model_name); + } + if(start) + return showModelDialog(model_name); + else + return true; + } + + private boolean showModelDialog(String model_name){ + model_path = Paths.get(IJ.getDirectory("imagej"), "models/" + model_name).toString(); + File fm = new File(model_path+"/tensorflow_model.pb"); + File fc = new File(model_path+"/config.json"); + if(fm.exists() && !fm.isDirectory() && fc.exists() && !fc.isDirectory()) { + IJ.log("loading model from "+model_path); + ap.loadModel(model_path); + } + else{ + IJ.showMessage("A-Net", + "invalid model file." + ); + } + + GenericDialog gd = new GenericDialog(String.format("A-Net (%s)", model_name)); + // gd.hideCancelButton(); + inputs = new HashMap(); + outputs = new HashMap(); + model_path = Paths.get(IJ.getDirectory("imagej"), "models/" + model_name).toString(); + // build image window map + int[] list = WindowManager.getIDList(); + if(list != null){ + image_window_titles = new String[list.length+1]; + image_window_map = new HashMap(); + for (int i=0; i(); + } + image_window_titles[image_window_titles.length-1] = ""; + + if(parseConfig("build_gui", inputs, outputs, gd)){ + gd.showDialog(); + if (gd.wasCanceled()) + return false; + else{ + if(parseConfig("fetch_inputs", inputs, outputs, gd)) + process(); + } + return true; + } + else + return false; + } + + private boolean showDialog2() { + GenericDialog gd = new GenericDialog("A-Net Process"); + + List modelList = new ArrayList(); + File folder = new File("models"); + File[] listOfFiles = folder.listFiles(); + for (int i = 0; i < listOfFiles.length; i++) { + if (listOfFiles[i].isFile()) { + + } else if (listOfFiles[i].isDirectory()) { + IJ.log(listOfFiles[i].getName()); + modelList.add(listOfFiles[i].getName()); + } + } + if(modelList.size()<=0){ + modelList.add("no model available."); + } + String[] items = (String[]) modelList.toArray(new String[modelList.size()]); + gd.addRadioButtonGroup("models", items, items.length, 1, items[0]); + + String[] structure_types = {"tubulin", "nuclear_pore", "actin", "mitochondria","unknown"}; + // gd.addStringField("mode", "tubulin"); + gd.addChoice("mode:", structure_types, "tubulin"); + + // default value is 0.00, 2 digits right of the decimal point + gd.addNumericField("input_size", 512, 0); + + // Create custom button + Button bt = new Button("download model"); + bt.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) + { + session = connect("wss://dai.pasteur.fr/ws", "realm1"); + } + }); + // Add and show button + gd.add(bt); + + gd.showDialog(); + if (gd.wasCanceled()) + return false; + + if (modelList.size()<=0){ + // check if there is any new file + for (int i = 0; i < listOfFiles.length; i++) { + if (listOfFiles[i].isFile()) { + } else if (listOfFiles[i].isDirectory()) { + modelList.add(listOfFiles[i].getName()); + } + } + if(modelList.size()>0){ + IJ.showMessage("A-Net", + "Please try again." + ); + } + return false; + } + // get entered values + model_name = gd.getNextRadioButton(); + mode = gd.getNextChoice();//gd.getNextString(); + input_size = (int) gd.getNextNumber(); + IJ.log("using model " + model_name); + IJ.log("using mode " + mode); + model_path = Paths.get(IJ.getDirectory("imagej"), "models/" + model_name).toString(); + File f = new File(model_path+"/tensorflow_model.pb"); + if(f.exists() && !f.isDirectory()) { + IJ.log(model_path); + // modelPath = "/Users/weiouyang/workspace/tensorflow-java/A-NET-npc-tubulin-946b/"; + ap.loadModel(model_path); + return true; // + } + else{ + IJ.showMessage("A-Net", + "invalid model file." + ); + return false; + } + + } + + private boolean showDownloadDialog(List models) { + GenericDialog gd = new GenericDialog("A-net model download"); + + + for (String model_name : models) { + gd.addCheckbox(model_name, false); + } + gd.showDialog(); + if (gd.wasCanceled()) + return false; + + String selected_models_str = ""; + for (String model_name : models) { + if(gd.getNextBoolean()) + selected_models_str = model_name + "," + selected_models_str; + } + List selected_model_names = Arrays.asList(selected_models_str.split(",")); + for(String model_name: selected_model_names){ + IJ.log("download " + model_name); + // here we CALL every second + CompletableFuture fGet = session.call("org.imod.public.models.get_download_url", model_name); + fGet.whenComplete((callResultGet, throwableGet) -> { + if (throwableGet == null) { + IJ.log(String.format("model url: %s, ", callResultGet.results.get(0))); + Map model_dict = (Map) callResultGet.results.get(0); + String model_url = (String) model_dict.get("url"); + + + IJ.log("dowloading model from: " + model_url); + String fileRoot = Paths.get(IJ.getDirectory("imagej"), "models", model_name).toString(); + try { + // returns pathnames for files and directory + File folder = new File(fileRoot); + // create + folder.mkdirs(); + + } catch(Exception e3) { + // if any error occurs + e3.printStackTrace(); + } + try { + String model_path_str = Paths.get(fileRoot, "tensorflow_model.pb").toString(); + // save config.json + JSONObject obj = new JSONObject(); + obj.putAll(model_dict); + try (FileWriter file = new FileWriter(Paths.get(fileRoot, "config.json").toString())) { + file.write(obj.toJSONString()); + } + // save model file + saveFile(model_url, model_path_str); + + IJ.log("model and config has been saved to " + fileRoot); + } + catch (IOException e2) { + e2.printStackTrace(); + } + } else { + IJ.log(String.format("ERROR - call failed: %s", throwableGet.getMessage())); + } + }); + } + return true; + } + + public boolean parseConfig(String stage, Map inputs, Map outputs, GenericDialog gd) { + JSONParser jsonParser = new JSONParser(); + try{ + Object obj = jsonParser.parse(new FileReader(Paths.get(model_path, "config.json").toString())); + JSONObject config = (JSONObject) obj; + String model_label = (String) config.get("label"); + if(stage.equals("show_outputs")){ + // fetch outputs dictionary + JSONArray outputs_c = (JSONArray) config.get("outputs"); + Iterator iterator = outputs_c.iterator(); + while (iterator.hasNext()) { + JSONObject output = iterator.next(); + JSONArray shape_o = (JSONArray) output.get("shape"); + int s = shape_o.size(); + int[] shape = null; + if(s>0){ + shape = new int[s]; + for(int i=0;i iterator = inputs_c.iterator(); + while (iterator.hasNext()) { + JSONObject input = iterator.next(); + JSONArray shape_i = (JSONArray) input.get("shape"); + int s = shape_i.size(); + int[] shape = null; + if(s>0){ + shape = new int[s]; + for(int i=0;i value_map = new HashMap(); + int c = 0; + for(Object k: opts.keySet()){ + double v = (double) opts.get((String)k); + options[c] = (String)k; + value_map.put((String)k, (float)v); + c++; + } + + if(stage.equals("build_gui")) + gd.addChoice((String)input.get("name"), options, options[0]); + else if(stage.equals("fetch_inputs")){ + String chv = gd.getNextChoice(); + float v = value_map.get(chv); + if(shape.length == 4){ + float[][][][] _input_ = (float[][][][])inputs.get(key); + _input_[0][0][0][0] = v; + } + else if(shape.length == 4){ + float[][][] _input_ = (float[][][])inputs.get(key); + _input_[0][0][0] = v; + } + else if(shape.length == 4){ + float[][] _input_ = (float[][])inputs.get(key); + _input_[0][0] = v; + } + else if(shape.length == 4){ + float[] _input_ = (float[])inputs.get(key); + _input_[0] = v; + } + else{ + float _input_ = (float)inputs.get(key); + _input_ = v; + } + } + } + + + } + + // construct outputs dictionary + JSONArray outputs_c = (JSONArray) config.get("outputs"); + iterator = outputs_c.iterator(); + while (iterator.hasNext()) { + JSONObject output = iterator.next(); + JSONArray shape_o = (JSONArray) output.get("shape"); + int s = shape_o.size(); + int[] shape = null; + if(s>0){ + shape = new int[s]; + for(int i=0;i clazz = AnetPlugin.class; + String url = clazz.getResource("/" + clazz.getName().replace('.', '/') + ".class").toString(); + String pluginsDir = url.substring("file:".length(), url.length() - clazz.getName().length() - ".class".length()); + System.setProperty("plugins.dir", pluginsDir); + + // start ImageJ + new ImageJ(); + + // open the Clown sample + ImagePlus image = IJ.openImage("http://imagej.net/images/clown.jpg"); + image.show(); + + // run the plugin + IJ.runPlugIn(clazz.getName(), ""); + } + + public Session connect(String websocketURL, String realm) { + Session session = new Session(); + session.addOnConnectListener(this::onConnectCallback); + session.addOnJoinListener(this::onJoinCallback); + session.addOnLeaveListener(this::onLeaveCallback); + session.addOnDisconnectListener(this::onDisconnectCallback); + + // finally, provide everything to a Client instance and connect + Client client = new Client(session, websocketURL, realm); + client.connect(); + return session; + } + + private void onConnectCallback(Session session) { + IJ.log("Session connected, ID=" + session.getID()); + } + + private void onJoinCallback(Session session, SessionDetails details) { + IJ.log("Joined, ID=" + session.getID()); + CompletableFuture fList = session.call("org.imod.public.models.list", "*"); + fList.whenComplete((callResultList, throwableList) -> { + if (throwableList == null) { + String modelstr = (String) callResultList.results.get(0); + IJ.log("models: " + modelstr); + IJ.log(modelstr); + List model_names = Arrays.asList(modelstr.split(",")); + showDownloadDialog(model_names); + } + else{ + IJ.showMessage("A-Net error", + "can't get model list" + ); + } + }); + + + } + + private void onLeaveCallback(Session session, CloseDetails detail) { + IJ.log(String.format("Left reason=%s, message=%s", detail.reason, detail.message)); + } + + private void onDisconnectCallback(Session session, boolean wasClean) { + IJ.log(String.format("Session with ID=%s, disconnected.", session.getID())); + } + + public static void saveFile (String url, String file) throws IOException{ + URL website = new URL(url); + ReadableByteChannel rbc = Channels.newChannel(website.openStream()); + FileOutputStream fos = new FileOutputStream(file); + fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + } + + +} diff --git a/src/main/java/org/imod/anet/AnetPredict.java b/src/main/java/org/imod/anet/AnetPredict.java new file mode 100644 index 0000000..d4b8c61 --- /dev/null +++ b/src/main/java/org/imod/anet/AnetPredict.java @@ -0,0 +1,265 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.imod.anet; + +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Session.Runner; +import org.tensorflow.Tensor; +import org.tensorflow.TensorFlow; +import org.tensorflow.types.UInt8; + +import java.util.Map; +import java.util.HashMap; + +import ij.IJ; +/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */ +public class AnetPredict { + private Graph graph=null; + public boolean loadModel(String modelDir){ + byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_model.pb")); + this.graph = new Graph(); + this.graph.importGraphDef(graphDef); + // this.session = new Session(g); + return true; + } + public float[] predict(float[] sr_input, float[] wf_input, int size, String type) { + float[][][][] _input_ = new float[1][size][size][2]; + float[][][][] _control_ = new float[1][1][1][1]; + float[][][][] _channel_mask_ = new float[1][1][1][2]; + + if(type.equals("tubulin")) _control_[0][0][0][0] = 0.0f; + else if(type.equals("nuclear_pore")) _control_[0][0][0][0] = 1.0f; + else if(type.equals("actin")) _control_[0][0][0][0] = 2.0f; + else if(type.equals("mitochondria")) _control_[0][0][0][0] = 3.0f; + else _control_[0][0][0][0] = -1.0f; + + _channel_mask_[0][0][0][0] = 1.0f; + _channel_mask_[0][0][0][1] = 1.0f; + + float _dropout_prob_ = 0.0f; + for(int i=0;i _dropout_prob = Tensor.create(dropout_prob, Float.class); + Tensor _input = Tensor.create(input, Float.class); + Tensor _control = Tensor.create(control, Float.class); + Tensor _channel_mask = Tensor.create(channel_mask, Float.class); + int size = input[0].length; + System.out.println(size); + float[][][][] output = executeGraph(_input, _channel_mask, _control, _dropout_prob, size); + return output; + } + + private float[][][][] executeGraph(Tensor input, Tensor channel_mask, Tensor control, Tensor dropout_prob, int size) { + + try ( + Session s = new Session(this.graph); + Tensor result = s.runner().feed("input", input).feed("control", control) + .feed("channel_mask", channel_mask).feed("dropout_prob", dropout_prob) + .fetch("output").run().get(0).expect(Float.class)) { + final long[] rshape = result.shape(); + if (result.numDimensions() != 4) { + throw new RuntimeException( + String.format( + "Expected model to produce a [1 H W C] shaped tensor, instead it produced one with shape %s", + Arrays.toString(rshape))); + } + return result.copyTo(new float[1][size][size][1]); + } + + } + + public void predict(Map inputs, Map outputs) { + try ( + Session s = new Session(this.graph); + ) { + Runner runner = s.runner(); + for(String key: inputs.keySet()){ + String [] tmp = key.split(":"); + String pkey = tmp[0]; + int tn = Integer.parseInt(tmp[1]); + Tensor t; + if(tn == 4) t = Tensor.create((float[][][][]) inputs.get(key), Float.class); + else if(tn == 3) t = Tensor.create((float[][][]) inputs.get(key), Float.class); + else if(tn == 2) t = Tensor.create((float[][]) inputs.get(key), Float.class); + else if(tn == 1) t = Tensor.create((float[]) inputs.get(key), Float.class); + else if(tn == 0) t = Tensor.create((float) inputs.get(key), Float.class); + else t = Tensor.create((float[][][][]) inputs.get(key), Float.class); + + runner.feed(pkey, t); + } + for(String key: outputs.keySet()){ + String [] tmp = key.split(":"); + String pkey = tmp[0]; + int tn = Integer.parseInt(tmp[1]); + Tensor result = runner.fetch(pkey).run().get(0).expect(Float.class); + final long[] rshape = result.shape(); + if (result.numDimensions() != tn) { + throw new RuntimeException( + String.format( + "Expected model to produce a tensor with %d dimension(s), instead it produced one with shape %s", + tn, Arrays.toString(rshape))); + } + result.copyTo(outputs.get(key)); + } + } + + } + + private static int maxIndex(float[] probabilities) { + int best = 0; + for (int i = 1; i < probabilities.length; ++i) { + if (probabilities[i] > probabilities[best]) { + best = i; + } + } + return best; + } + + private static byte[] readAllBytesOrExit(Path path) { + try { + return Files.readAllBytes(path); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(1); + } + return null; + } + + private static List readAllLinesOrExit(Path path) { + try { + return Files.readAllLines(path, Charset.forName("UTF-8")); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(0); + } + return null; + } + + // In the fullness of time, equivalents of the methods of this class should be auto-generated from + // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages + // like Python, C++ and Go. + static class GraphBuilder { + GraphBuilder(Graph g) { + this.g = g; + } + + Output div(Output x, Output y) { + return binaryOp("Div", x, y); + } + + Output sub(Output x, Output y) { + return binaryOp("Sub", x, y); + } + + Output resizeBilinear(Output images, Output size) { + return binaryOp3("ResizeBilinear", images, size); + } + + Output expandDims(Output input, Output dim) { + return binaryOp3("ExpandDims", input, dim); + } + + Output cast(Output value, Class type) { + DataType dtype = DataType.fromClass(type); + return g.opBuilder("Cast", "Cast") + .addInput(value) + .setAttr("DstT", dtype) + .build() + .output(0); + } + + Output decodePng(Output contents, long channels) { + return g.opBuilder("DecodePng", "DecodePng") + .addInput(contents) + .setAttr("channels", channels) + .build() + .output(0); + } + + Output decodeJpeg(Output contents, long channels) { + return g.opBuilder("DecodeJpeg", "DecodeJpeg") + .addInput(contents) + .setAttr("channels", channels) + .build() + .output(0); + } + + Output constant(String name, Object value, Class type) { + try (Tensor t = Tensor.create(value, type)) { + return g.opBuilder("Const", name) + .setAttr("dtype", DataType.fromClass(type)) + .setAttr("value", t) + .build() + .output(0); + } + } + Output constant(String name, byte[] value) { + return this.constant(name, value, String.class); + } + + Output constant(String name, int value) { + return this.constant(name, value, Integer.class); + } + + Output constant(String name, int[] value) { + return this.constant(name, value, Integer.class); + } + + Output constant(String name, float value) { + return this.constant(name, value, Float.class); + } + + private Output binaryOp(String type, Output in1, Output in2) { + return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0); + } + + private Output binaryOp3(String type, Output in1, Output in2) { + return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0); + } + private Graph g; + } +} diff --git a/src/main/resources/plugins.config b/src/main/resources/plugins.config new file mode 100644 index 0000000..40049bc --- /dev/null +++ b/src/main/resources/plugins.config @@ -0,0 +1,11 @@ +# A single .jar file can contain multiple plugins, specified in separate lines. +# +# The format is: , "", +# +# If something like ("") is appended to the class name, the setup() method +# will get that as arg parameter; otherwise arg is simply the empty string. + +Process, "A-Net Process", org.imod.anet.AnetPlugin("start") +Plugins>A-Net, "Run A-net", org.imod.anet.AnetPlugin("run") +Plugins>A-Net, "Setup A-net", org.imod.anet.AnetPlugin("setup") +Plugins>A-Net, "About A-net", org.imod.anet.AnetPlugin("about")