Roshrini closed pull request #13500: [MXNET-898] ONNX import/export: Sample_multinomial, ONNX export: GlobalLpPool, LpPool
URL: https://github.com/apache/incubator-mxnet/pull/13500
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub will not display its
diff after the merge, so the full diff is supplied below:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py 
b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 3baf10a10d3..d24865d9dcb 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -586,6 +586,7 @@ def convert_pooling(node, **kwargs):
     pool_type = attrs["pool_type"]
     stride = eval(attrs["stride"]) if attrs.get("stride") else None
     global_pool = get_boolean_attribute_value(attrs, "global_pool")
+    p_value = attrs.get('p_value', 'None')
 
     pooling_convention = attrs.get('pooling_convention', 'valid')
 
@@ -598,26 +599,51 @@ def convert_pooling(node, **kwargs):
 
     pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
     pad_dims = pad_dims + pad_dims
-    pool_types = {"max": "MaxPool", "avg": "AveragePool"}
-    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool"}
+    pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
+    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
+                         "lp": "GlobalLpPool"}
+
+    if pool_type == 'lp' and p_value == 'None':
+        raise AttributeError('ONNX requires a p value for LpPool and 
GlobalLpPool')
 
     if global_pool:
-        node = onnx.helper.make_node(
-            global_pool_types[pool_type],
-            input_nodes,  # input
-            [name],
-            name=name
-        )
+        if pool_type == 'lp':
+            node = onnx.helper.make_node(
+                global_pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                p=int(p_value),
+                name=name
+            )
+        else:
+            node = onnx.helper.make_node(
+                global_pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                name=name
+            )
     else:
-        node = onnx.helper.make_node(
-            pool_types[pool_type],
-            input_nodes,  # input
-            [name],
-            kernel_shape=kernel,
-            pads=pad_dims,
-            strides=stride,
-            name=name
-        )
+        if pool_type == 'lp':
+            node = onnx.helper.make_node(
+                pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                p=int(p_value),
+                kernel_shape=kernel,
+                pads=pad_dims,
+                strides=stride,
+                name=name
+            )
+        else:
+            node = onnx.helper.make_node(
+                pool_types[pool_type],
+                input_nodes,  # input
+                [name],
+                kernel_shape=kernel,
+                pads=pad_dims,
+                strides=stride,
+                name=name
+            )
 
     return [node]
 
@@ -1689,3 +1715,26 @@ def convert_logsoftmax(node, **kwargs):
         name=name
     )
     return [node]
+
+
+@mx_op.register("_sample_multinomial")
+def convert_multinomial(node, **kwargs):
+    """Map MXNet's multinomial operator attributes to onnx's
+    Multinomial operator and return the created node.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+    dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(attrs.get("dtype", 
'int32'))]
+    sample_size = convert_string_to_list(attrs.get("shape", '1'))
+    if len(sample_size) < 2:
+        sample_size = sample_size[-1]
+    else:
+        raise AttributeError("ONNX currently supports integer sample_size 
only")
+    node = onnx.helper.make_node(
+        "Multinomial",
+        input_nodes,
+        [name],
+        dtype=dtype,
+        sample_size=sample_size,
+        name=name,
+    )
+    return [node]
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py 
b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
index 5b33f9faac1..2a668dc84be 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -18,7 +18,7 @@
 # coding: utf-8_
 # pylint: disable=invalid-name
 """Operator attributes conversion"""
-from ._op_translations import identity, random_uniform, random_normal
+from ._op_translations import identity, random_uniform, random_normal, 
sample_multinomial
 from ._op_translations import add, subtract, multiply, divide, absolute, 
negative, add_n
 from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
 from ._op_translations import softplus, shape, gather, lp_pooling, size
@@ -48,6 +48,7 @@
     'RandomNormal'      : random_normal,
     'RandomUniformLike' : random_uniform,
     'RandomNormalLike'  : random_normal,
+    'Multinomial'       : sample_multinomial,
     # Arithmetic Operators
     'Add'               : add,
     'Sub'               : subtract,
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py 
b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
index ce0e0e51ef7..a061a7ef002 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
@@ -38,6 +38,19 @@ def random_normal(attrs, inputs, proto_obj):
     new_attr = translation_utils._fix_attribute_names(new_attr, {'mean' : 
'loc'})
     return 'random_uniform', new_attr, inputs
 
+def sample_multinomial(attrs, inputs, proto_obj):
+    """Draw random samples from a multinomial distribution."""
+    try:
+        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          + "Instructions to install - 
https://github.com/onnx/onnx")
+    new_attrs = translation_utils._remove_attributes(attrs, ['seed'])
+    new_attrs = translation_utils._fix_attribute_names(new_attrs, 
{'sample_size': 'shape'})
+    new_attrs['dtype'] = TENSOR_TYPE_TO_NP_TYPE[int(attrs.get('dtype', 6))]
+    return 'sample_multinomial', new_attrs, inputs
+
+
 # Arithmetic Operations
 def add(attrs, inputs, proto_obj):
     """Adding two tensors"""
@@ -382,6 +395,7 @@ def global_lppooling(attrs, inputs, proto_obj):
                                                                 'kernel': (1, 
1),
                                                                 'pool_type': 
'lp',
                                                                 'p_value': 
p_value})
+    new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
     return 'Pooling', new_attrs, inputs
 
 def linalg_gemm(attrs, inputs, proto_obj):
@@ -671,11 +685,12 @@ def lp_pooling(attrs, inputs, proto_obj):
     new_attrs = translation_utils._fix_attribute_names(attrs,
                                                        {'kernel_shape': 
'kernel',
                                                         'strides': 'stride',
-                                                        'pads': 'pad',
-                                                        'p_value': p_value
+                                                        'pads': 'pad'
                                                        })
+    new_attrs = translation_utils._remove_attributes(new_attrs, ['p'])
     new_attrs = translation_utils._add_extra_attributes(new_attrs,
-                                                        {'pooling_convention': 
'valid'
+                                                        {'pooling_convention': 
'valid',
+                                                         'p_value': p_value
                                                         })
     new_op = translation_utils._fix_pooling('lp', inputs, new_attrs)
     return new_op, new_attrs, inputs
diff --git a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py 
b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
index f63c1e9e8e6..6fd52665ca3 100644
--- a/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
+++ b/python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py
@@ -94,6 +94,7 @@ def _fix_pooling(pool_type, inputs, new_attr):
     stride = new_attr.get('stride')
     kernel = new_attr.get('kernel')
     padding = new_attr.get('pad')
+    p_value = new_attr.get('p_value')
 
     # Adding default stride.
     if stride is None:
@@ -138,7 +139,10 @@ def _fix_pooling(pool_type, inputs, new_attr):
             new_pad_op = symbol.pad(curr_sym, mode='constant', 
pad_width=pad_width)
 
     # Apply pooling without pads.
-    new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, 
stride=stride, kernel=kernel)
+    if pool_type == 'lp':
+        new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, 
stride=stride, kernel=kernel, p_value=p_value)
+    else:
+        new_pooling_op = symbol.Pooling(new_pad_op, pool_type=pool_type, 
stride=stride, kernel=kernel)
     return new_pooling_op
 
 def _fix_bias(op_name, attrs, num_inputs):
diff --git a/tests/python-pytest/onnx/test_cases.py 
b/tests/python-pytest/onnx/test_cases.py
index 6a189b62492..64aaab0f6d4 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -79,7 +79,6 @@
              'test_softplus'
              ],
     'import': ['test_gather',
-               'test_global_lppooling',
                'test_softsign',
                'test_reduce_',
                'test_mean',
@@ -89,7 +88,6 @@
                'test_averagepool_2d_precomputed_strides',
                'test_averagepool_2d_strides',
                'test_averagepool_3d',
-               'test_LpPool_',
                'test_split_equal',
                'test_hardmax'
                ],
diff --git a/tests/python-pytest/onnx/test_node.py 
b/tests/python-pytest/onnx/test_node.py
index 07ae866b96c..6a0f8bcd73c 100644
--- a/tests/python-pytest/onnx/test_node.py
+++ b/tests/python-pytest/onnx/test_node.py
@@ -56,6 +56,24 @@ def get_rnd(shape, low=-1.0, high=1.0, dtype=np.float32):
         return np.random.choice(a=[False, True], size=shape).astype(np.float32)
 
 
+def _fix_attributes(attrs, attribute_mapping):
+    new_attrs = attrs
+    attr_modify = attribute_mapping.get('modify', {})
+    for k, v in attr_modify.items():
+        new_attrs[v] = new_attrs.pop(k, None)
+
+    attr_add = attribute_mapping.get('add', {})
+    for k, v in attr_add.items():
+        new_attrs[k] = v
+
+    attr_remove = attribute_mapping.get('remove', [])
+    for k in attr_remove:
+        if k in new_attrs:
+            del new_attrs[k]
+
+    return new_attrs
+
+
 def forward_pass(sym, arg, aux, data_names, input_data):
     """ Perform forward pass on given data
     :param sym: Symbol
@@ -118,7 +136,7 @@ def get_onnx_graph(testname, input_names, inputs, 
output_name, output_shape, att
             return model
 
         for test in test_cases:
-            test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific = 
test
+            test_name, mxnet_op, onnx_name, inputs, attrs, mxnet_specific, 
fix_attrs, check_value, check_shape = test
             with self.subTest(test_name):
                 names, input_tensors, inputsym = get_input_tensors(inputs)
                 test_op = mxnet_op(*inputsym, **attrs)
@@ -131,33 +149,66 @@ def get_onnx_graph(testname, input_names, inputs, 
output_name, output_shape, att
                                                             onnx_name + 
".onnx")
                     onnxmodel = load_model(onnxmodelfile)
                 else:
-                    onnxmodel = get_onnx_graph(test_name, names, 
input_tensors, onnx_name, outputshape, attrs)
+                    onnx_attrs = _fix_attributes(attrs, fix_attrs)
+                    onnxmodel = get_onnx_graph(test_name, names, 
input_tensors, onnx_name, outputshape, onnx_attrs)
 
                 bkd_rep = backend.prepare(onnxmodel, operation='export')
                 output = bkd_rep.run(inputs)
 
-                npt.assert_almost_equal(output[0], mxnet_output)
+                if check_value:
+                    npt.assert_almost_equal(output[0], mxnet_output)
+
+                if check_shape:
+                    npt.assert_equal(output[0].shape, outputshape)
 
 
-# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], 
attribute map, MXNet_specific=True/False)
+# test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], 
attribute map, MXNet_specific=True/False,
+# fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name},
+#                   'remove': [attr_name],
+#                   'add': {attr_name: value},
+# check_value=True/False, check_shape=True/False)
 test_cases = [
-    ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), 
get_rnd((1, 5))], {}, False),
-    ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 
5)), get_rnd((1, 5))], {}, False),
-    ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), 
get_rnd((1, 5))], {}, False),
+    ("test_equal", mx.sym.broadcast_equal, "Equal", [get_rnd((1, 3, 4, 5)), 
get_rnd((1, 5))], {}, False, {}, True,
+     False),
+    ("test_greater", mx.sym.broadcast_greater, "Greater", [get_rnd((1, 3, 4, 
5)), get_rnd((1, 5))], {}, False, {}, True,
+     False),
+    ("test_less", mx.sym.broadcast_lesser, "Less", [get_rnd((1, 3, 4, 5)), 
get_rnd((1, 5))], {}, False, {}, True,
+     False),
     ("test_and", mx.sym.broadcast_logical_and, "And",
-     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False),
+     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False, {}, True, False),
     ("test_xor", mx.sym.broadcast_logical_xor, "Xor",
-     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False),
+     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False, {}, True, False),
     ("test_or", mx.sym.broadcast_logical_or, "Or",
-     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False),
-    ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), 
dtype=np.bool_)], {}, False),
-    ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], 
{}, True),
+     [get_rnd((3, 4, 5), dtype=np.bool_), get_rnd((3, 4, 5), dtype=np.bool_)], 
{}, False, {}, True, False),
+    ("test_not", mx.sym.logical_not, "Not", [get_rnd((3, 4, 5), 
dtype=np.bool_)], {}, False, {}, True, False),
+    ("test_square", mx.sym.square, "Pow", [get_rnd((2, 3), dtype=np.int32)], 
{}, True, {}, True, False),
     ("test_spacetodepth", mx.sym.space_to_depth, "SpaceToDepth", [get_rnd((1, 
1, 4, 6))],
-     {'block_size': 2}, False),
+     {'block_size': 2}, False, {}, True, False),
     ("test_softmax", mx.sym.SoftmaxOutput, "Softmax", [get_rnd((1000, 1000)), 
get_rnd(1000)],
-     {'ignore_label': 0, 'use_ignore': False}, True),
-    ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4,3)), 
get_rnd((4, 3)), get_rnd(4)],
-     {'num_hidden': 4, 'name': 'FC'}, True)
+     {'ignore_label': 0, 'use_ignore': False}, True, {}, True, False),
+    ("test_fullyconnected", mx.sym.FullyConnected, "Gemm", [get_rnd((4, 3)), 
get_rnd((4, 3)), get_rnd(4)],
+     {'num_hidden': 4, 'name': 'FC'}, True, {}, True, False),
+    ("test_lppool1", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
+     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 
'pool_type': 'lp'}, False,
+     {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 
'p_value': 'p'},
+      'remove': ['pool_type']}, True, False),
+    ("test_lppool2", mx.sym.Pooling, "LpPool", [get_rnd((2, 3, 20, 20))],
+     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 
'pool_type': 'lp'}, False,
+     {'modify': {'kernel': 'kernel_shape', 'pad': 'pads', 'stride': 'strides', 
'p_value': 'p'},
+      'remove': ['pool_type']}, True, False),
+    ("test_globallppool1", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 
20))],
+     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 1, 
'pool_type': 'lp', 'global_pool': True}, False,
+     {'modify': {'p_value': 'p'},
+      'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, 
True, False),
+    ("test_globallppool2", mx.sym.Pooling, "GlobalLpPool", [get_rnd((2, 3, 20, 
20))],
+     {'kernel': (4, 5), 'pad': (0, 0), 'stride': (1, 1), 'p_value': 2, 
'pool_type': 'lp', 'global_pool': True}, False,
+     {'modify': {'p_value': 'p'},
+      'remove': ['pool_type', 'kernel', 'pad', 'stride', 'global_pool']}, 
True, False),
+
+    # since results would be random, checking for shape alone
+    ("test_multinomial", mx.sym.sample_multinomial, "Multinomial",
+     [np.array([0, 0.1, 0.2, 0.3, 0.4]).astype("float32")],
+     {'shape': (10,)}, False, {'modify': {'shape': 'sample_size'}}, False, 
True)
 ]
 
 if __name__ == '__main__':


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to