This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 03fecba6c8 [Bugfix][Frontend][Keras] Add a check to reject the invalid input shape (#15335)
03fecba6c8 is described below

commit 03fecba6c8b55c215dc09631cf8b55db88041691
Author: Qingchao Shen <[email protected]>
AuthorDate: Fri Jul 21 06:17:44 2023 +0800

    [Bugfix][Frontend][Keras] Add a check to reject the invalid input shape (#15335)
    
    * reject invalid input_shape
    
    * Update test_forward.py
    
    * Update keras.py
    
    * Update keras.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update keras.py
    
    * Update test_forward.py
    
    * Update test_forward.py
    
    * Update keras.py
    
    * Update test_forward.py
---
 python/tvm/relay/frontend/keras.py          | 6 ++++++
 tests/python/frontend/keras/test_forward.py | 8 ++++++++
 2 files changed, 14 insertions(+)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 63938c9e42..16764cd581 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1436,6 +1436,12 @@ def from_keras(model, shape=None, layout="NCHW"):
     def _convert_input_layer(keras_layer):
         input_name = keras_layer.name
         input_shape = shape[input_name] if shape is not None and input_name in 
shape else None
+        if input_shape and len(input_shape) > 1 and any(dim <= 0 for dim in 
input_shape[1:]):
+            msg = (
+                "Expected input's non-batch dimensions to have positive 
length, "
+                f"but the input has a shape of {input_shape}"
+            )
+            raise ValueError(msg)
         etab.set_expr(input_name, new_var(input_name, shape=input_shape))
 
     def _convert_layer(keras_layer, etab, scope=""):
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 44767712d0..cdc253bc5d 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -32,6 +32,7 @@ import tvm
 from tvm import relay
 from tvm.contrib import graph_executor
 import tvm.testing
+import pytest
 
 if tf.executing_eagerly():
     GPUS = tf.config.experimental.list_physical_devices("GPU")
@@ -295,6 +296,7 @@ class TestKeras:
         verify_keras_frontend(keras_model)
 
     def test_forward_pool(self, keras_mod):
+        """test_forward_pool"""
         data = keras_mod.layers.Input(shape=(32, 32, 1))
         # maxpool
         x = keras_mod.layers.MaxPooling2D((3, 3), strides=(1, 1), 
padding="same")(data)
@@ -304,6 +306,12 @@ class TestKeras:
         y = keras_mod.layers.AveragePooling2D((3, 3), strides=(1, 1), 
padding="same")(data)
         keras_model = keras_mod.models.Model(data, y)
         verify_keras_frontend(keras_model)
+        # reject the invalid input shape
+        data = keras_mod.layers.Input(shape=(0, 3, 6, 4))
+        x = keras_mod.layers.GlobalAveragePooling3D()(data)
+        keras_model = keras_mod.models.Model(data, x)
+        with pytest.raises(ValueError):
+            verify_keras_frontend(keras_model)
 
     def test_forward_conv1d(self, keras_mod):
         """test_forward_conv1d"""

Reply via email to