This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 2f41a39 [FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
2f41a39 is described below
commit 2f41a39688bf5fe2f18d8481f9ae012fb6a05614
Author: MORITA Kazutaka <[email protected]>
AuthorDate: Thu Apr 2 07:49:37 2020 +0900
[FRONTEND][MXNET] Use leaky by default for LeakyReLU (#5192)
---
python/tvm/relay/frontend/mxnet.py | 2 +-
tests/python/frontend/mxnet/test_forward.py | 11 ++++++++++-
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index b918f9b..5c8e726 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -510,7 +510,7 @@ def _mx_pad(inputs, attrs):
pad_mode=pad_mode)
def _mx_leaky_relu(inputs, attrs):
- act_type = attrs.get_str("act_type")
+ act_type = attrs.get_str("act_type", "leaky")
if act_type == "leaky":
return _op.nn.leaky_relu(inputs[0], alpha=attrs.get_float("slope", 0.25))
if act_type == "prelu":
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 102905a..f015447 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -107,6 +107,14 @@ def test_forward_resnet():
mx_sym = model_zoo.mx_resnet(18)
verify_mxnet_frontend_impl(mx_sym)
+def test_forward_leaky_relu():
+ data = mx.sym.var('data')
+ data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
+ mx_sym = mx.sym.LeakyReLU(data)
+ verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+ mx_sym = mx.sym.LeakyReLU(data, act_type='leaky')
+ verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
def test_forward_elu():
data = mx.sym.var('data')
data = mx.sym.concat(data, -data, dim=1) # negative part explicitly
@@ -979,6 +987,7 @@ if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
test_forward_resnet()
+ test_forward_leaky_relu()
test_forward_elu()
test_forward_rrelu()
test_forward_prelu()
@@ -1030,4 +1039,4 @@ if __name__ == '__main__':
test_forward_deconvolution()
test_forward_cond()
test_forward_make_loss()
- test_forward_unravel_index()
\ No newline at end of file
+ test_forward_unravel_index()