This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 40593c6 Fix ConcatType backward type inference (#15829)
40593c6 is described below
commit 40593c6f6c20baed98a914d14987db5438c0a5a5
Author: Anirudh Subramanian <[email protected]>
AuthorDate: Wed Aug 14 21:56:44 2019 -0700
Fix ConcatType backward type inference (#15829)
* Fix ConcatType and add test
* Remove return false
* Change error message
* Run RNN test only when CUDNN enabled
* set default context for test_contrib_amp
---
src/operator/nn/concat.cc | 32 ++++++++++++++++++++++----------
tests/python/gpu/test_contrib_amp.py | 19 +++++++++++++++++--
2 files changed, 39 insertions(+), 12 deletions(-)
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 80469b5..9e016bf 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -144,6 +144,7 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
int dtype = -1;
+ // checks uniformity of input
for (int i : *in_type) {
if (dtype == -1) {
dtype = i;
@@ -154,18 +155,29 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
}
}
- if (dtype == -1) {
- LOG(FATAL) << "Not enough information to infer type in Concat.";
- return false;
- }
-
size_t nin = param_.num_args;
- in_type->clear();
- for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
-
- out_type->clear();
- out_type->push_back(dtype);
+ // if in types are known out types are unknown
+ if (dtype != -1 && (*out_type)[0] == -1) {
+ (*out_type)[0] = dtype;
+ in_type->clear();
+ for (size_t i = 0; i < nin; ++i) {
+ in_type->push_back(dtype);
+ }
+ // if out types are known in types are unknown
+ } else if ((*out_type)[0] != -1 && dtype == -1) {
+ in_type->clear();
+ for (size_t i = 0; i < nin; ++i) {
+ in_type->push_back((*out_type)[0]);
+ }
+ // if both out_types and in_types are known, and different
+ } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
+ std::ostringstream os;
+ os << "Type inconsistent, Provided output type = "
+ << mxnet::op::type_string((*out_type)[0]) << ','
+ << " inferred type = " << mxnet::op::type_string(dtype);
+ throw mxnet::op::InferTypeError(os.str(), 0);
+ }
return true;
}
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index 7927cc9..3daab0f 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -26,11 +26,12 @@ import mxnet.contrib.amp as amp
from nose.tools import assert_raises
from mxnet.test_utils import set_default_context, download_model, same_symbol_structure
from mxnet.gluon.model_zoo.vision import get_model
-from mxnet.gluon import SymbolBlock
+from mxnet.gluon import SymbolBlock, nn, rnn
from mxnet.contrib.amp import amp
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import with_seed, teardown
+from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
+set_default_context(mx.gpu(0))
def test_amp_coverage():
conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
@@ -305,6 +306,20 @@ def test_amp_conversion():
check_amp_convert_model()
check_amp_convert_hybrid_block()
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_amp_conversion_rnn():
+ with mx.Context(mx.gpu(0)):
+ model = nn.HybridSequential()
+ model.add(rnn.LSTM(hidden_size=10, num_layers=2, bidirectional=True))
+ model.add(nn.Dense(2))
+ model.initialize()
+ model.hybridize()
+ out = model(mx.nd.ones((2, 3, 4)))
+ new_model = amp.convert_hybrid_block(model)
+ out2 = new_model(mx.nd.ones((2, 3, 4)))
+ mx.test_utils.assert_almost_equal(out.asnumpy(), out2.asnumpy(), atol=1e-2, rtol=1e-2)
+
@with_seed()
def test_module_backward_compatibility():