This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 40593c6  Fix ConcatType backward type inference (#15829)
40593c6 is described below

commit 40593c6f6c20baed98a914d14987db5438c0a5a5
Author: Anirudh Subramanian <[email protected]>
AuthorDate: Wed Aug 14 21:56:44 2019 -0700

    Fix ConcatType backward type inference (#15829)
    
    * Fix ConcatType and add test
    
    * Remove return false
    
    * Change error message
    
    * Run RNN test only when CUDNN enabled
    
    * set default context for test_contrib_amp
---
 src/operator/nn/concat.cc            | 32 ++++++++++++++++++++++----------
 tests/python/gpu/test_contrib_amp.py | 19 +++++++++++++++++--
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 80469b5..9e016bf 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -144,6 +144,7 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   int dtype = -1;
 
+  // checks uniformity of input
   for (int i : *in_type) {
     if (dtype == -1) {
       dtype = i;
@@ -154,18 +155,29 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
     }
   }
 
-  if (dtype == -1) {
-    LOG(FATAL) << "Not enough information to infer type in Concat.";
-    return false;
-  }
-
   size_t nin = param_.num_args;
-  in_type->clear();
-  for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
-
-  out_type->clear();
-  out_type->push_back(dtype);
 
+  // if in types are known out types are unknown
+  if (dtype != -1 && (*out_type)[0] == -1) {
+    (*out_type)[0] = dtype;
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back(dtype);
+    }
+  // if out types are known in types are unknown
+  } else if ((*out_type)[0] != -1 && dtype == -1) {
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) {
+      in_type->push_back((*out_type)[0]);
+    }
+  // if both out_types and in_types are known, and different
+  } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
+    std::ostringstream os;
+    os << "Type inconsistent, Provided output type = "
+       << mxnet::op::type_string((*out_type)[0]) << ','
+       << " inferred type = " << mxnet::op::type_string(dtype);
+    throw mxnet::op::InferTypeError(os.str(), 0);
+  }
   return true;
 }
 
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index 7927cc9..3daab0f 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -26,11 +26,12 @@ import mxnet.contrib.amp as amp
 from nose.tools import assert_raises
 from mxnet.test_utils import set_default_context, download_model, same_symbol_structure
 from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock
+from mxnet.gluon import SymbolBlock, nn, rnn
 from mxnet.contrib.amp import amp
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import with_seed, teardown
+from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
+set_default_context(mx.gpu(0))
 
 def test_amp_coverage():
     conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
@@ -305,6 +306,20 @@ def test_amp_conversion():
         check_amp_convert_model()
         check_amp_convert_hybrid_block()
 
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_amp_conversion_rnn():
+    with mx.Context(mx.gpu(0)):
+        model = nn.HybridSequential()
+        model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
+        model.add(nn.Dense(2))
+        model.initialize()
+        model.hybridize()
+        out = model(mx.nd.ones((2, 3, 4)))
+        new_model = amp.convert_hybrid_block(model)
+        out2 = new_model(mx.nd.ones((2, 3, 4)))
+        mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
+
 
 @with_seed()
 def test_module_backward_compatibility():

Reply via email to the mailing list.