This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1e1b4c8b2 [bugfix][relay] Fix alpha attribute with None in ELU (#14742)
d1e1b4c8b2 is described below
commit d1e1b4c8b2ff1ae7cdfbf8823ecffb3c305f078f
Author: Qingchao Shen <[email protected]>
AuthorDate: Tue May 2 12:54:04 2023 +0800
[bugfix][relay] Fix alpha attribute with None in ELU (#14742)
* fix alpha in elu
* add test case
---
python/tvm/relay/frontend/keras.py | 2 ++
tests/python/frontend/keras/test_forward.py | 19 +++++++++++++------
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 4539c221c9..ef94c74e03 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -160,6 +160,8 @@ def _convert_advanced_activation(inexpr, keras_layer, etab, data_layout, input_s
raise tvm.error.OpAttributeInvalid("The alpha value of a LeakyReLU cannot be None.")
return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
if act_type == "ELU":
+ if np.isnan(keras_layer.alpha).any():
+ raise tvm.error.OpAttributeInvalid("The alpha value of a ELU cannot be None.")
alpha = keras_layer.alpha if hasattr(keras_layer, "alpha") else 1.0
alpha = _expr.const(alpha, dtype="float32")
return _get_elu(inexpr, alpha)
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 1377c180ae..86e88d0764 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -214,19 +214,26 @@ class TestKeras:
def test_forward_activations_except(self, keras_mod):
"""
- test invalid attribute alpha=None for LeakyReLU.
+ test invalid attribute alpha=None for LeakyReLU and ELU.
after version 2.3.1 in keras, checking was added to reject the invalid api call:
- LeakyReLU(alpha=None), (issue: https://github.com/tensorflow/tensorflow/pull/47017)
+ LeakyReLU(alpha=None) and ELU(alpha=None),
+ (see issue: https://github.com/tensorflow/tensorflow/pull/47017)
Thus, it's necessary to check the keras version to avoid crash at LeakyReLU(alpha=None)
+ and ELU(alpha=None)
"""
if package_version.parse(keras_mod.__version__.split("-tf")[0]) <= package_version.parse(
"2.3.1"
):
+ act_funcs = [
+ keras_mod.layers.LeakyReLU(alpha=None),
+ keras_mod.layers.LEU(2, 3, 4),
+ ]
data = keras_mod.layers.Input(shape=(2, 3, 4))
- layer = keras_mod.layers.LeakyReLU(alpha=None)(data)
- keras_model = keras_mod.models.Model(data, layer)
- with pytest.raises(tvm.error.OpAttributeInvalid):
- verify_keras_frontend(keras_model)
+ for act_func in act_funcs:
+ layer = act_func(data)
+ keras_model = keras_mod.models.Model(data, layer)
+ with pytest.raises(tvm.error.OpAttributeInvalid):
+ verify_keras_frontend(keras_model)
def test_forward_dense(self, keras_mod):
"""test_forward_dense"""