RuRo commented on pull request #17827:
URL: https://github.com/apache/incubator-mxnet/pull/17827#issuecomment-642109768


   @QueensGambit I am interested in getting this fix accepted.
   
   Could you resolve the merge conflict and rebase on top of master? The
   conflict is in the unit test, but it's pretty easy to fix. Just add a
   `@with_seed()` decorator to the test function and update the diff context.
   
   A patch like this should work:
   
   <details>
   
   ```diff
   diff --git a/tests/python/unittest/onnx/mxnet_export_test.py b/tests/python/unittest/onnx/mxnet_export_test.py
   index 40d7d4e3e..3b039b426 100644
   --- a/tests/python/unittest/onnx/mxnet_export_test.py
   +++ b/tests/python/unittest/onnx/mxnet_export_test.py
   @@ -28,6 +28,7 @@ from common import setup_module, teardown_module, with_seed
    from mxnet import nd, sym
    from mxnet.test_utils import set_default_context
    from mxnet.gluon import nn
   +from mxnet.gluon import HybridBlock
    from mxnet.contrib import onnx as onnx_mxnet
    import mxnet as mx
    
   @@ -80,6 +81,16 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
                mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5)
    
    
   +class SplitConcatBlock(HybridBlock):
   +    """Block which creates two splits and later concatenates them"""
   +    def __init__(self, name):
   +        super(SplitConcatBlock, self).__init__(name)
   +
   +    def hybrid_forward(self, F, x):
   +        splits = F.split(x, axis=1, num_outputs=2)
   +        return F.concat(*splits)
   +
   +
    class TestExport(unittest.TestCase):
        """ Tests ONNX export.
        """
   @@ -126,3 +137,10 @@ class TestExport(unittest.TestCase):
                net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
            _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
    
   +    @with_seed()
   +    def test_onnx_export_slice(self):
   +        net = nn.HybridSequential(prefix='slice_net')
   +        with net.name_scope():
   +            net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10))
   +        _check_onnx_export(net)
   +
   ```
   
   </details>


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to