Skip to content

Commit

Permalink
fix the onnxprocess for the empty input and name (#104)
Browse files Browse the repository at this point in the history
* fixing the onnxprocess for the empty input and name

* fix the crash on onnxruntime 1.8
  • Loading branch information
wenbingl authored Jun 4, 2021
1 parent 0851eac commit 88a3c0e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif()
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
set(CPACK_PACKAGE_VERSION_MAJOR "0")
set(CPACK_PACKAGE_VERSION_MINOR "3")
set(CPACK_PACKAGE_VERSION_PATCH "0")
set(CPACK_PACKAGE_VERSION_PATCH "1")
set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH})


Expand Down
6 changes: 3 additions & 3 deletions includes/onnxruntime/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <string.h>

// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
#define ORT_API_VERSION 5
#define ORT_API_VERSION 6

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -1120,7 +1120,7 @@ struct OrtApi {
ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options);

/**
* Use this API to create the configuration of an arena that can eventually be used to define
* Use this API to create the configuration of an arena that can eventually be used to define
* an arena based allocator's behavior
* \param max_mem - use 0 to allow ORT to choose the default
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
Expand Down Expand Up @@ -1178,4 +1178,4 @@ struct OrtCustomOp {

#ifdef __cplusplus
}
#endif
#endif
2 changes: 1 addition & 1 deletion onnxruntime_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
The entry point to onnxruntime custom op library
"""

__version__ = "0.3.0"
__version__ = "0.3.1"
__author__ = "Microsoft"


Expand Down
32 changes: 21 additions & 11 deletions onnxruntime_extensions/onnxprocess/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,21 @@ def _rename_iter(iterables, prefix_name, inplace=False):

@classmethod
def _rename_graph(cls, graph, prefix, graph_or_container):
def io_rename(node, prefix_name):
def io_rename(node, prefix_name, idx):
new_node = copy.deepcopy(node)
if not node.name:
new_node.name = "{}_op{}".format(prefix_name, idx)

del new_node.input[:]
new_node.input.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.input)
new_node.input.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
del new_node.output[:]
new_node.output.extend("{}_{}".format(prefix_name, nm_) for nm_ in node.output)
new_node.output.extend("{}_{}".format(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
return new_node

assert prefix is not None, 'The graph prefix could not be None'
graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
return list(io_rename(nd_, prefix) for nd_ in graph.node)
return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))

@classmethod
def _process_node_body(cls, node, prefix):
Expand Down Expand Up @@ -97,6 +100,8 @@ def topological_sort(cls, container, nodes, inputs, outputs):
edges = {}
for op in nodes:
for x in op.input:
if x == '':
continue
try:
predecessor = op_output_map[x]
except KeyError:
Expand Down Expand Up @@ -125,6 +130,7 @@ def recursive_helper(node):

unfinished_nodes.add(node.name)
if node.name in edges: # if the node's output is not in the Graph output.
assert node.name != '', 'this topological-sort depends on the unique node name.'
for successor in edges[node.name]:
recursive_helper(successor)

Expand Down Expand Up @@ -152,7 +158,7 @@ def model_from_ops(container, ops, ts_from, ts_to):
iz_set = set(iz_.name for iz_ in container.initializer)
for op in ops:
iz_needed.update(it_ for it_ in op.input if it_ in iz_set)
all_inputs.extend(it_ for it_ in op.input if it_ not in iz_set)
all_inputs.extend(it_ for it_ in op.input if (it_ != '') and it_ not in iz_set)
all_outputs.extend(ot_ for ot_ in op.output)

intersections = set(all_inputs).intersection(set(all_outputs))
Expand Down Expand Up @@ -205,12 +211,16 @@ def trace_for_onnx(cls, *inputs, names=None, target_opset=11) -> 'ONNXTraceSessi
else torch.tensor(x) for x in np_inputs]
itensors = [tensor_from_torch(i_, None) if isinstance(i_, torch.Tensor)
else tensor_from_onnx(i_, None, None) for i_ in np_inputs]
if names is not None:
if len(inputs) != len(names):
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
num = min(len(itensors), len(names))
for idx_ in range(num):
itensors[idx_].name = names[idx_]
if names is None:
names = []
if len(inputs) != len(names):
warnings.warn("the name number doesn't match the inputs', assign to the ones in the front.")
names.extend([''] * (len(inputs) - len(names)))
for idx_ in range(len(inputs)):
names[idx_] = names[idx_] if names[idx_] else "input{}".format(idx_)
num = min(len(itensors), len(names))
for idx_ in range(num):
itensors[idx_].name = names[idx_]
self.inputs = itensors
return self

Expand Down
10 changes: 7 additions & 3 deletions pyop/pykernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@ struct PyCustomOpKernel {
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
size = 0;
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
if (api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
std::string error_message(api_.GetErrorMessage(status));
api_.ReleaseStatus(status);
throw std::runtime_error(MakeString(
"Unable to find attribute '", *it, "' due to '",
error_message, "'."));
}
api_.ReleaseStatus(status);
if (status != nullptr) {
api_.ReleaseStatus(status);
}
attrs_values_[*it] = "";
attrs_values_[*it].resize(size);
status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
Expand All @@ -63,7 +65,9 @@ struct PyCustomOpKernel {
api_.GetErrorMessage(status), "'."));
}
attrs_values_[*it].resize(size - 1);
api_.ReleaseStatus(status);
if (status != nullptr) {
api_.ReleaseStatus(status);
}
}
}

Expand Down

0 comments on commit 88a3c0e

Please sign in to comment.