Skip to content

Commit

Permalink
Batchnorm Version Conversion Adapters (onnx#1288)
Browse files Browse the repository at this point in the history
* Squashing into 1 commit

* New BatchNormalization_7_6 adapter
  • Loading branch information
Ac2zoom authored and houseroad committed Aug 25, 2018
1 parent 46a00cc commit b15a99e
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 0 deletions.
38 changes: 38 additions & 0 deletions onnx/test/version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,44 @@ def test_relu_7_5(self): # type: () -> None
assert converted_model.graph.node[0].op_type == "Relu"
assert converted_model.opset_import[0].version == 5

# Test BatchNormalization Adapter: 8 -> 5
def test_batch_normalization_8_5(self): # type: () -> None
nodes = [helper.make_node('BatchNormalization', ["X", "scale", "B",
"mean", "var"], ["Y"])]
graph = helper.make_graph(
nodes,
"test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,)),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("B", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("mean", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("var", TensorProto.FLOAT, (1,))],
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
converted_model = self._converted(graph, helper.make_operatorsetid(
"", 8), 5)
# Assert equality of graph and converted_model
assert converted_model.graph.node[0].op_type == "BatchNormalization"
assert converted_model.opset_import[0].version == 5

# Test BatchNormalization Adapter: 5 -> 8
def test_batch_normalization_5_8(self): # type: () -> None
nodes = [helper.make_node('BatchNormalization', ["X", "scale", "B",
"mean", "var"], ["Y"])]
graph = helper.make_graph(
nodes,
"test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,)),
helper.make_tensor_value_info("scale", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("B", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("mean", TensorProto.FLOAT, (1,)),
helper.make_tensor_value_info("var", TensorProto.FLOAT, (1,))],
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (5,))])
converted_model = self._converted(graph, helper.make_operatorsetid(
"", 5), 8)
# Assert equality of graph and converted_model
assert converted_model.graph.node[0].op_type == "BatchNormalization"
assert converted_model.opset_import[0].version == 8


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions onnx/version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
--BatchNorm from Opset 6 to Opset 7
--BatchNorm from Opset 6 to Opset 5
--BatchNorm from Opset 5 to Opset 6
Unsupported adapters:
--Concat from Opset 4 to Opset 3
--Concat from Opset 3 to Opset 4
--MaxPool from Opset 8 to Opset 7
Expand Down
24 changes: 24 additions & 0 deletions onnx/version_converter/adapters/batch_normalization_6_5.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Adapter for BatchNormalization in default domain from version 6 to 5

#pragma once

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE { namespace version_conversion {

class BatchNormalization_6_5 final : public Adapter {
public:
explicit BatchNormalization_6_5()
: Adapter("BatchNormalization", OpSetID(6), OpSetID(5)) {
}

void adapt_batch_normalization_6_5(std::shared_ptr<Graph> graph, Node* node) const {
node->is_(kconsumed_inputs, {0, 0});
}

void adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_batch_normalization_6_5(graph, node);
}
};

}} // namespace ONNX_NAMESPACE::version_conversion
27 changes: 27 additions & 0 deletions onnx/version_converter/adapters/batch_normalization_6_7.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Adapter for BatchNormalization in default domain from version 6 to 7

#pragma once

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE { namespace version_conversion {

struct BatchNormalization_6_7 final : public Adapter {
explicit BatchNormalization_6_7()
: Adapter("BatchNormalization", OpSetID(6), OpSetID(7)) {
}

void adapt_batch_normalization_6_7(std::shared_ptr<Graph> graph, Node* node) const {
if (node->hasAttribute(kis_test)) {
ONNX_ASSERTM(node->i(kis_test) != 0,
"ONNX currently only supports inference, not training.");
node->removeAttribute(kis_test);
}
}

void adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_batch_normalization_6_7(graph, node);
}
};

}} // namespace ONNX_NAMESPACE::version_conversion
23 changes: 23 additions & 0 deletions onnx/version_converter/adapters/batch_normalization_7_6.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Adapter for BatchNormalization in default domain from version 7 to 6

#pragma once

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE { namespace version_conversion {

struct BatchNormalization_7_6 final : public Adapter {
explicit BatchNormalization_7_6()
: Adapter("BatchNormalization", OpSetID(7), OpSetID(6)) {
}

void adapt_batch_normalization_7_6(std::shared_ptr<Graph> graph, Node* node) const {
node->i_(kis_test, 1);
}

void adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_batch_normalization_7_6(graph, node);
}
};

}} // namespace ONNX_NAMESPACE::version_conversion
8 changes: 8 additions & 0 deletions onnx/version_converter/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "onnx/version_converter/adapters/remove_consumed_inputs.h"
#include "onnx/version_converter/adapters/gemm_7_6.h"
#include "onnx/version_converter/adapters/gemm_6_7.h"
#include "onnx/version_converter/adapters/batch_normalization_6_5.h"
#include "onnx/version_converter/adapters/batch_normalization_6_7.h"
#include "onnx/version_converter/adapters/batch_normalization_7_6.h"

namespace ONNX_NAMESPACE { namespace version_conversion {

Expand Down Expand Up @@ -114,6 +117,11 @@ class DefaultVersionConverter : public BaseVersionConverter {
OpSetID(5), OpSetID(6)));
registerAdapter(make_unique<CompatibleAdapter>("Relu",
OpSetID(6), OpSetID(5)));
registerAdapter(make_unique<BatchNormalization_7_6>());
registerAdapter(make_unique<BatchNormalization_6_7>());
registerAdapter(make_unique<BatchNormalization_6_5>());
registerAdapter(make_unique<RemoveConsumedInputs>("BatchNormalization",
OpSetID(5), OpSetID(6)));
}

ModelProto convert_version(
Expand Down

0 comments on commit b15a99e

Please sign in to comment.