This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ca8153d502 [AOT]Raise error when input name is not valid (#14322)
ca8153d502 is described below

commit ca8153d502b93e653f7a961309110455b90a0e22
Author: Mehrdad Hessar <[email protected]>
AuthorDate: Fri Mar 17 10:29:06 2023 -0700

    [AOT]Raise error when input name is not valid (#14322)
    
    This PR fixes #13013.
---
 src/runtime/aot_executor/aot_executor.cc |  2 +-
 tests/python/relay/aot/test_cpp_aot.py   | 38 ++++++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc
index 39d5570030..1fed42bf04 100644
--- a/src/runtime/aot_executor/aot_executor.cc
+++ b/src/runtime/aot_executor/aot_executor.cc
@@ -191,7 +191,7 @@ int AotExecutor::GetInputIndex(const std::string& name) {
       return i;
     }
   }
-  return -1;
+  ICHECK(false) << "Invalid input name.";
 }
 
 std::string AotExecutor::GetInputName(int index) {
diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py
index 3c7a3bc0ca..c1b4fd817a 100644
--- a/tests/python/relay/aot/test_cpp_aot.py
+++ b/tests/python/relay/aot/test_cpp_aot.py
@@ -248,5 +248,43 @@ def test_aot_input_name_with_special_character(target_kind: str, input_name: str
         assert (runner.get_output(0).asnumpy() == expected_output).all()
 
 
+@pytest.mark.parametrize("target_kind", ["c", "llvm"])
+def test_aot_incorrect_input_name(target_kind: str):
+    """Test passing incorrect input name."""
+    dtype = "float32"
+    correct_input_name = "input"
+    incorrect_input_name = "input1"
+    input1 = relay.var(correct_input_name, shape=(10, 5), dtype=dtype)
+    weight = relay.var("weight", shape=(1, 5), dtype=dtype)
+    output = relay.add(input1, weight)
+    func = relay.Function([input1, weight], output)
+
+    input_data = np.random.rand(10, 5).astype(dtype)
+    weight_data = np.random.rand(1, 5).astype(dtype)
+    params = {"weight": weight_data}
+
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        mod = tvm.relay.build(
+            tvm.IRModule.from_expr(func),
+            target=target_kind,
+            params=params,
+            executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}),
+        )
+    temp_dir = tvm.contrib.utils.TempDirectory()
+    test_so_path = temp_dir / "test.so"
+    mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", "-O0"])
+
+    loaded_mod = tvm.runtime.load_module(test_so_path)
+    runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
+    inputs = {incorrect_input_name: input_data}
+
+    error_regex = r"Invalid input name."
+    with pytest.raises(tvm.TVMError, match=error_regex):
+        runner.set_input(**inputs)
+
+    with pytest.raises(tvm.TVMError, match=error_regex):
+        runner.get_input_index(incorrect_input_name)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to