Skip to content

Commit

Permalink
support tensors with incorrect name
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 16, 2023
1 parent 00509d1 commit 645bd99
Showing 1 changed file with 15 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,17 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors)
Session.Runner runner = session.runner();
List<String> inputListNames = new ArrayList<String>();
List<TType> inTensors = new ArrayList<TType>();
int c = 0;
for (Tensor tt : inputTensors) {
inputListNames.add(tt.getName());
TType inT = TensorBuilder.build(tt);
inTensors.add(inT);
runner.feed(getModelInputName(tt.getName()), inT);
String inputName = getModelInputName(tt.getName(), c ++);
runner.feed(inputName, inT);
}

c = 0;
for (Tensor tt : outputTensors)
runner = runner.fetch(getModelOutputName(tt.getName()));
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
// Run runner
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();

Expand Down Expand Up @@ -420,10 +422,14 @@ public void closeModel() {
* the signature input name.
*
* @param inputName Signature input name.
* @param i position of the input of interest in the list of inputs
* @return The readable input name.
*/
public static String getModelInputName(String inputName) {
public static String getModelInputName(String inputName, int i) {
TensorInfo inputInfo = sig.getInputsMap().getOrDefault(inputName, null);
if (inputInfo == null) {
inputInfo = sig.getInputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (inputInfo != null) {
String modelInputName = inputInfo.getName();
if (modelInputName != null) {
Expand All @@ -446,10 +452,14 @@ public static String getModelInputName(String inputName) {
* given the signature output name.
*
* @param outputName Signature output name.
* @param i position of the input of interest in the list of inputs
* @return The readable output name.
*/
public static String getModelOutputName(String outputName) {
public static String getModelOutputName(String outputName, int i) {
TensorInfo outputInfo = sig.getOutputsMap().getOrDefault(outputName, null);
if (outputInfo == null) {
outputInfo = sig.getOutputsMap().values().stream().collect(Collectors.toList()).get(i);
}
if (outputInfo != null) {
String modelOutputName = outputInfo.getName();
if (modelOutputName.endsWith(":0")) {
Expand Down

0 comments on commit 645bd99

Please sign in to comment.