This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch v1.2.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.2.0 by this push:
new 5ab574f Merge asymmetric padding(#10676)
5ab574f is described below
commit 5ab574f853e4d6c54a1d617065cb13d73907ab1b
Author: Anirudh <[email protected]>
AuthorDate: Wed Apr 25 13:43:48 2018 -0700
Merge asymmetric padding(#10676)
---
.../mxnet/contrib/onnx/_import/op_translations.py | 46 ++++++++++++++++++----
tests/python-pytest/onnx/onnx_test.py | 5 ++-
2 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/python/mxnet/contrib/onnx/_import/op_translations.py b/python/mxnet/contrib/onnx/_import/op_translations.py
index de34132..2fa517a 100644
--- a/python/mxnet/contrib/onnx/_import/op_translations.py
+++ b/python/mxnet/contrib/onnx/_import/op_translations.py
@@ -214,12 +214,28 @@ def conv(attrs, inputs, cls):
new_attrs = translation_utils._fix_bias('Convolution', new_attrs,
len(inputs))
new_attrs = translation_utils._fix_channels('Convolution', new_attrs,
inputs, cls)
-
- return 'Convolution', new_attrs, inputs
-
+ kernel = new_attrs['kernel']
+ stride = new_attrs['stride'] if 'stride' in new_attrs else []
+ padding = new_attrs['pad'] if 'pad' in new_attrs else []
+ dilations = new_attrs['dilate'] if 'dilate' in new_attrs else []
+ num_filter = new_attrs['num_filter']
+ num_group = new_attrs['num_group']
+ no_bias = new_attrs['no_bias'] if 'no_bias' in new_attrs else 0
+ bias = None if no_bias is True else inputs[2]
+
+ # Unlike ONNX, MXNet's convolution operator does not support asymmetric padding, so we first
+ # use 'Pad' operator, which supports asymmetric padding. Then use the convolution operator.
+ pad_width = (0, 0, 0, 0) + translation_utils._pad_sequence_fix(padding,
kernel_dim=len(kernel))
+ pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
+
+ conv_op = symbol.Convolution(pad_op, inputs[1], bias,
+ kernel=kernel, stride=stride,
dilate=dilations,
+ num_filter=num_filter, num_group=num_group,
no_bias=no_bias)
+
+ return conv_op, new_attrs, inputs
def deconv(attrs, inputs, cls):
- """Compute N-D convolution on (N+2)-D input."""
+ """Computes transposed convolution of the input tensor."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'kernel_shape'
: 'kernel',
'strides' :
'stride',
'pads': 'pad',
@@ -229,9 +245,25 @@ def deconv(attrs, inputs, cls):
new_attrs = translation_utils._fix_bias('Deconvolution', new_attrs,
len(inputs))
new_attrs = translation_utils._fix_channels('Deconvolution', new_attrs,
inputs, cls)
-
- return 'Convolution', new_attrs, inputs
-
+ kernel = new_attrs['kernel']
+ stride = new_attrs['stride'] if 'stride' in new_attrs else []
+ padding = new_attrs['pad'] if 'pad' in new_attrs else []
+ dilations = new_attrs['dilate'] if 'dilate' in new_attrs else []
+ num_filter = new_attrs['num_filter']
+ num_group = new_attrs['num_group']
+ no_bias = new_attrs['no_bias'] if 'no_bias' in new_attrs else False
+ bias = None if no_bias is True else inputs[2]
+
+ # Unlike ONNX, MXNet's deconvolution operator does not support asymmetric padding, so we first
+ # use 'Pad' operator, which supports asymmetric padding. Then use the deconvolution operator.
+ pad_width = (0, 0, 0, 0) + translation_utils._pad_sequence_fix(padding,
kernel_dim=len(kernel))
+ pad_op = symbol.pad(inputs[0], mode='constant', pad_width=pad_width)
+
+ deconv_op = symbol.Deconvolution(pad_op, inputs[1], bias,
+ kernel=kernel, stride=stride,
dilate=dilations,
+ num_filter=num_filter,
num_group=num_group, no_bias=no_bias)
+
+ return deconv_op, new_attrs, inputs
def fully_connected(attrs, inputs, cls):
"""Applies a linear transformation: Y=XWT+b."""
diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py
index 36cb9ab..e75ef69 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -124,12 +124,13 @@ def test_super_resolution_example():
assert sym.list_outputs()[0] == 'reshape5_output'
attrs_keys = sym.attr_dict().keys()
- assert len(attrs_keys) == 19
+ assert len(attrs_keys) == 23
for i, key_item in enumerate(['reshape4', 'convolution2', 'convolution0',
'transpose0', '6', 'reshape0', 'reshape2',
'reshape3', '3', 'reshape1', '5', '4', '7',
'convolution1', '9', '2', 'convolution3',
- 'reshape5', '8']):
+ 'reshape5', '8', 'pad1', 'pad0', 'pad3',
+ 'pad2']):
assert key_item in attrs_keys
param_keys = arg_params.keys()
--
To stop receiving notification emails like this one, please contact
[email protected].