This is an automated email from the ASF dual-hosted git repository.
mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ca8153d502 [AOT]Raise error when input name is not valid (#14322)
ca8153d502 is described below
commit ca8153d502b93e653f7a961309110455b90a0e22
Author: Mehrdad Hessar <[email protected]>
AuthorDate: Fri Mar 17 10:29:06 2023 -0700
[AOT]Raise error when input name is not valid (#14322)
This PR fixes #13013.
---
src/runtime/aot_executor/aot_executor.cc | 2 +-
tests/python/relay/aot/test_cpp_aot.py | 38 ++++++++++++++++++++++++++++++++
2 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/src/runtime/aot_executor/aot_executor.cc b/src/runtime/aot_executor/aot_executor.cc
index 39d5570030..1fed42bf04 100644
--- a/src/runtime/aot_executor/aot_executor.cc
+++ b/src/runtime/aot_executor/aot_executor.cc
@@ -191,7 +191,7 @@ int AotExecutor::GetInputIndex(const std::string& name) {
return i;
}
}
- return -1;
+ ICHECK(false) << "Invalid input name.";
}
std::string AotExecutor::GetInputName(int index) {
diff --git a/tests/python/relay/aot/test_cpp_aot.py b/tests/python/relay/aot/test_cpp_aot.py
index 3c7a3bc0ca..c1b4fd817a 100644
--- a/tests/python/relay/aot/test_cpp_aot.py
+++ b/tests/python/relay/aot/test_cpp_aot.py
@@ -248,5 +248,43 @@ def test_aot_input_name_with_special_character(target_kind: str, input_name: str
assert (runner.get_output(0).asnumpy() == expected_output).all()
+@pytest.mark.parametrize("target_kind", ["c", "llvm"])
+def test_aot_incorrect_input_name(target_kind: str):
+ """Test passing incorrect input name."""
+ dtype = "float32"
+ correct_input_name = "input"
+ incorrect_input_name = "input1"
+ input1 = relay.var(correct_input_name, shape=(10, 5), dtype=dtype)
+ weight = relay.var("weight", shape=(1, 5), dtype=dtype)
+ output = relay.add(input1, weight)
+ func = relay.Function([input1, weight], output)
+
+ input_data = np.random.rand(10, 5).astype(dtype)
+ weight_data = np.random.rand(1, 5).astype(dtype)
+ params = {"weight": weight_data}
+
+ with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+ mod = tvm.relay.build(
+ tvm.IRModule.from_expr(func),
+ target=target_kind,
+ params=params,
+ executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}),
+ )
+ temp_dir = tvm.contrib.utils.TempDirectory()
+ test_so_path = temp_dir / "test.so"
+ mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", "-O0"])
+
+ loaded_mod = tvm.runtime.load_module(test_so_path)
+ runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
+ inputs = {incorrect_input_name: input_data}
+
+ error_regex = r"Invalid input name."
+ with pytest.raises(tvm.TVMError, match=error_regex):
+ runner.set_input(**inputs)
+
+ with pytest.raises(tvm.TVMError, match=error_regex):
+ runner.get_input_index(incorrect_input_name)
+
+
if __name__ == "__main__":
tvm.testing.main()