KellenSunderland closed pull request #11893: WIP: Add deconvolutions to onnx exporter
URL: https://github.com/apache/incubator-mxnet/pull/11893
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index b2c93670bb4..c920595ae44 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -88,6 +88,37 @@ def parse_helper(attrs, attrs_name, alt_value=None):
             raise AttributeError("Malformed %s dimensions: %s" % (attrs_name, 
str(attrs_str)))
     return alt_value
 
+
+def parse_padding(attrs):
+    """Convert MXNet's symmetric ``pad`` attribute to the ONNX ``pads`` layout.
+
+    MXNet stores one padding value per spatial dimension and applies it to
+    both the start and the end of that dimension. ONNX expects
+    ``[x1_begin, x2_begin, ..., x1_end, x2_end, ...]``. The per-dimension
+    values are therefore emitted once for the begin half and once again for
+    the end half.
+
+    When ``pad`` is absent, one zero per ``kernel`` dimension is used. If
+    ``kernel`` is also absent, two dimensions are assumed.
+
+    Parameters
+    ----------
+    attrs : dict or None
+        Operator attributes as strings, e.g. ``{"pad": "(1, 1)"}``.
+
+    Returns
+    -------
+    list
+        Padding values in ONNX begin/end order.
+
+    Raises
+    ------
+    AttributeError
+        If ``pad`` or ``kernel`` is present but malformed (raised by
+        ``parse_helper``).
+    """
+    if attrs is None:
+        return [0, 0, 0, 0]
+
+    # Reuse the shared attribute parser instead of a private copy of it.
+    kernel = parse_helper(attrs, "kernel")
+    default_pads = [0] * len(kernel) if kernel else [0, 0]
+    pads = list(parse_helper(attrs, "pad", default_pads))
+
+    # MXNet padding is symmetric: begin paddings followed by end paddings.
+    return pads + pads
+
+
 def transform_padding(pad_width):
     """Helper function to convert padding format for pad operator.
     """
@@ -160,6 +191,44 @@ def convert_weights_and_inputs(node, **kwargs):
         return [tval_node]
 
 
+@mx_op.register("Deconvolution")
+def convert_deconvolution(node, **kwargs):
+    """Map MXNet's Deconvolution operator to ONNX's ConvTranspose operator.
+
+    Attributes are translated as follows:
+
+    - ``kernel`` becomes ``kernel_shape``.
+    - ``stride`` becomes ``strides``.
+    - ``dilate`` becomes ``dilations``.
+    - ``pad`` becomes ``pads`` (via ``parse_padding``).
+    - ``adj`` becomes ``output_padding``.
+    - ``num_group`` becomes ``group``.
+
+    When ``stride`` or ``dilate`` is absent, each defaults to ones, one per
+    kernel dimension.
+
+    Parameters
+    ----------
+    node : dict
+        MXNet graph node with ``name``, ``inputs`` and optional ``attrs``.
+    **kwargs
+        Must contain ``proc_nodes`` and ``index_lookup``. The first element
+        of each input entry is looked up in ``index_lookup``. The result
+        indexes ``proc_nodes``, whose ``.name`` becomes the ONNX input name.
+
+    Returns
+    -------
+    onnx NodeProto
+        The ConvTranspose node built by ``helper.make_node``.
+
+    Raises
+    ------
+    AttributeError
+        If the required ``kernel`` attribute is missing or malformed.
+    """
+    helper, _, _ = import_onnx_modules()
+    name = node["name"]
+    proc_nodes = kwargs["proc_nodes"]
+    index_lookup = kwargs["index_lookup"]
+
+    # Data and weight inputs, plus the bias input only when a third input
+    # exists.
+    input_nodes = [proc_nodes[index_lookup[inp[0]]].name
+                   for inp in node["inputs"][:3]]
+
+    attrs = node.get("attrs") or {}
+    kernel = parse_helper(attrs, "kernel")
+    if kernel is None:
+        raise AttributeError("Deconvolution %s is missing the 'kernel' "
+                             "attribute" % name)
+    kernel_dims = list(kernel)
+    # Defaults follow the kernel rank, so 1-D and 3-D deconvolutions work too.
+    spatial_ones = [1] * len(kernel_dims)
+
+    onnx_attrs = {
+        "kernel_shape": kernel_dims,
+        "strides": list(parse_helper(attrs, "stride", spatial_ones)),
+        "dilations": list(parse_helper(attrs, "dilate", spatial_ones)),
+        "pads": parse_padding(attrs),
+        "group": int(attrs.get("num_group", 1)),
+    }
+
+    # Previously dropped silently. ONNX output_padding is the counterpart
+    # of MXNet adj.
+    adj = parse_helper(attrs, "adj")
+    if adj is not None:
+        onnx_attrs["output_padding"] = list(adj)
+    # NOTE(review): MXNet ``target_shape`` is not translated. Confirm whether
+    # it should map to ONNX ``output_shape``.
+
+    return helper.make_node(
+        "ConvTranspose",
+        inputs=input_nodes,
+        outputs=[name],
+        name=name,
+        **onnx_attrs
+    )
+
 @mx_op.register("Convolution")
 def convert_convolution(node, **kwargs):
     """Map MXNet's convolution operator attributes to onnx's Conv operator


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to