Skip to content

Commit

Permalink
DRAFT CFE fix S64 paddings in Pad
Browse files Browse the repository at this point in the history
on-going draft to fix S64 paddings in Pad.

Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Jul 29, 2024
1 parent 3a8110a commit e89f6bd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions compiler/luci-interpreter/src/kernels/Pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void Pad::configure()
const int32_t padding_after = paddings_data[i * 2 + 1];
assert(padding_before >= 0 && padding_after >= 0);
output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
printf("!!! Pad %d %d %d\r\n", i, input_shape.dim(i), output_shape.dim(i));
}

output()->resize(output_shape);
Expand Down
21 changes: 18 additions & 3 deletions compiler/luci-pass-value-py-test/test_luci_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def luci_eval_verify(test_name,
atolint = int(atolf32)

# Build TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
interpreter = tf.lite.Interpreter(
tflite_model, experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

# Read SignatureDef and get output tensor id orders for remapping
Expand Down Expand Up @@ -87,16 +88,30 @@ def luci_eval_verify(test_name,
output_shape = [int(i) for i in shape_file.read().split(',')]
luci_output_data = np.reshape(output_data, output_shape)
output_tensor = output_details["index"]
print("!!! output_tensor 1", output_tensor)
if full_signatures_outputs_remap != None:
output_tensor = full_signatures_outputs_remap[idx]
print("!!! output_tensor 2", idx, output_tensor)
intp_output_data = interpreter.get_tensor(output_tensor)

print("!!! ", tflite_model, ":", output_tensor, intp_output_data.shape)
print("!!! ", circle_model, ":", output_shape)

err_msg = "Execution result of " + tflite_model + " does not match with " + circle_model
if output_details["dtype"] == np.uint8:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
elif output_details["dtype"] == np.float32:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32), err_msg
print("!!! float32")
print(intp_output_data.shape)
print(luci_output_data.shape)
res = np.allclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32)
if not res:
diff = np.isclose(
luci_output_data, intp_output_data, rtol=rtolf32, atol=atolf32)
print(diff)
assert res, err_msg
elif output_details["dtype"] == np.int64:
assert np.allclose(
luci_output_data, intp_output_data, rtol=rtolint, atol=atolint), err_msg
Expand Down

0 comments on commit e89f6bd

Please sign in to comment.