Skip to content

Commit

Permalink
Add shape inference function of Upsample operator. (onnx#1320)
Browse files Browse the repository at this point in the history
* Add shape inference function of Upsample operator.

* Fix type conversion warning.

* Fix flake8 format errors.

* Add input number checking.

* * Add {} to body of if.

* Change to use propagateElemTypeFromInputToOutput().
  • Loading branch information
rockindy authored and bddppq committed Aug 29, 2018
1 parent 6f91908 commit 79027b3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
20 changes: 19 additions & 1 deletion onnx/defs/tensor/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,25 @@ ONNX_OPERATOR_SET_SCHEMA(
"T",
OpSchema::all_tensor_types(),
"Constrain input and output types to all tensor types.")
.SetDoc(Upsample_ver7_doc));
.SetDoc(Upsample_ver7_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateElemTypeFromInputToOutput(ctx, 0, 0);
auto& input_shape = getInputShape(ctx, 0);
auto* output_shape = getOutputShape(ctx, 0);
auto* scale_attr = ctx.getAttribute("scales");
if (input_shape.dim_size() != scale_attr->floats_size()) {
fail_shape_inference(
"Upsample: Input dims != attribute 'scales' dims");
}
for (int i=0; i<input_shape.dim_size(); ++i) {
float dim_value = static_cast<float>(input_shape.dim(i).dim_value());
output_shape->add_dim()->set_dim_value(
static_cast<int64_t>(std::floor(dim_value * scale_attr->floats(i))));
}
}));

ONNX_OPERATOR_SET_SCHEMA(
Identity,
Expand Down
7 changes: 7 additions & 0 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ def test_reshape_static_shape_inferred(self): # type: () -> None
initializer=[make_tensor('shape', TensorProto.INT64, (3,), (0, 3, -1))])
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (2, 3, 4))])

def test_upsample(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.INT32, (2, 4, 3, 5))],
[make_node("Upsample", ['x'], ['y'], scales=[1.0, 1.1, 1.3, 1.9])],
[])
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.INT32, (2, 4, 3, 9))])

def test_shape(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (2, 4, 3))],
Expand Down

0 comments on commit 79027b3

Please sign in to comment.