junrushao commented on code in PR #230:
URL: https://github.com/apache/tvm-ffi/pull/230#discussion_r2535421726


##########
tests/python/test_build.py:
##########
@@ -31,11 +34,293 @@ def test_build_cpp() -> None:
 
     mod: Module = tvm_ffi.load_module(output_lib_path)
 
+    metadata = mod.get_function_metadata("add_one_cpu")
+    assert metadata is not None, "add_one_cpu should have metadata"
+    assert "type_schema" in metadata, f"{'add_one_cpu'}: {metadata}"
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[Tensor, Tensor], None]", 
f"{'add_one_cpu'}: {schema}"
+    assert "arg_const" in metadata
+    arg_const = metadata["arg_const"]
+    assert len(arg_const) == 2, "Should have 2 arguments"
+    assert arg_const[0] is False and arg_const[1] is False, f"{'add_one_cpu'}: 
{arg_const}"
+    doc = mod.get_function_doc("add_one_cpu")
+    assert doc is None
+
     x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
     y = numpy.empty_like(x)
     mod.add_one_cpu(x, y)
     numpy.testing.assert_equal(x + 1, y)
 
 
+def test_build_inline_with_metadata() -> None:  # noqa: PLR0915
+    """Test functions with various input and output types."""
+    # Keep module alive until all returned objects are destroyed
+    output_lib_path = tvm_ffi.cpp.build_inline(
+        name="test_io_types",
+        cpp_sources=r"""
+            // int input -> int output
+            int square(int x) {
+                return x * x;
+            }
+
+            // float input -> float output
+            float reciprocal(float x) {
+                return 1.0f / x;
+            }
+
+            // bool input -> bool output
+            bool negate(bool x) {
+                return !x;
+            }
+
+            // String input -> String output
+            tvm::ffi::String uppercase_first(tvm::ffi::String s) {
+                std::string result(s.c_str());
+                if (!result.empty()) {
+                    result[0] = std::toupper(result[0]);
+                }
+                return tvm::ffi::String(result);
+            }
+
+            // Multiple inputs: int, float -> float
+            float weighted_sum(int count, float weight) {
+                return static_cast<float>(count) * weight;
+            }
+
+            // Multiple inputs: String, int -> String
+            tvm::ffi::String repeat_string(tvm::ffi::String s, int times) {
+                std::string result;
+                for (int i = 0; i < times; ++i) {
+                    result += s.c_str();
+                }
+                return tvm::ffi::String(result);
+            }
+
+            // Mixed types: bool, int, float, String -> String
+            tvm::ffi::String format_data(bool flag, int count, float value, 
tvm::ffi::String label) {
+                std::ostringstream oss;
+                oss << label.c_str() << ": flag=" << (flag ? "true" : "false")
+                    << ", count=" << count << ", value=" << value;
+                return tvm::ffi::String(oss.str());
+            }
+
+            // Tensor input/output
+            void double_tensor(tvm::ffi::TensorView input, 
tvm::ffi::TensorView output) {
+                TVM_FFI_ICHECK(input.ndim() == 1);
+                TVM_FFI_ICHECK(output.ndim() == 1);
+                TVM_FFI_ICHECK(input.size(0) == output.size(0));
+                DLDataType f32_dtype{kDLFloat, 32, 1};
+                TVM_FFI_ICHECK(input.dtype() == f32_dtype);
+                TVM_FFI_ICHECK(output.dtype() == f32_dtype);
+
+                for (int i = 0; i < input.size(0); ++i) {
+                    static_cast<float*>(output.data_ptr())[i] =
+                        static_cast<const float*>(input.data_ptr())[i] * 2.0f;
+                }
+            }
+        """,
+        functions=[
+            "square",
+            "reciprocal",
+            "negate",
+            "uppercase_first",
+            "weighted_sum",
+            "repeat_string",
+            "format_data",
+            "double_tensor",
+        ],
+        extra_cflags=["-DTVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA"],
+    )
+
+    mod: Module = tvm_ffi.load_module(output_lib_path)
+
+    # Test square: int -> int
+    assert mod.square(5) == 25
+    metadata = mod.get_function_metadata("square")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[int], int]"
+
+    # Test reciprocal: float -> float
+    result = mod.reciprocal(2.0)
+    assert abs(result - 0.5) < 0.001
+    metadata = mod.get_function_metadata("reciprocal")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[float], float]"
+
+    # Test negate: bool -> bool
+    assert mod.negate(True) is False
+    assert mod.negate(False) is True
+    metadata = mod.get_function_metadata("negate")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[bool], bool]"
+
+    # Test uppercase_first: String -> String
+    result = mod.uppercase_first("hello")
+    assert result == "Hello"
+    metadata = mod.get_function_metadata("uppercase_first")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[str], str]"
+
+    # Test weighted_sum: int, float -> float
+    result = mod.weighted_sum(10, 2.5)
+    assert abs(result - 25.0) < 0.001
+    metadata = mod.get_function_metadata("weighted_sum")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[int, float], float]"
+
+    # Test repeat_string: String, int -> String
+    result = mod.repeat_string("ab", 3)
+    assert result == "ababab"
+    metadata = mod.get_function_metadata("repeat_string")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[str, int], str]"
+
+    # Test format_data: bool, int, float, String -> String
+    result = mod.format_data(True, 42, 3.14, "test")
+    assert "test:" in result
+    assert "flag=true" in result
+    assert "count=42" in result
+    assert "value=3.14" in result
+    metadata = mod.get_function_metadata("format_data")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[bool, int, float, str], str]"
+
+    # Test double_tensor: Tensor, Tensor -> None
+    x = numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32)
+    y = numpy.empty_like(x)
+    mod.double_tensor(x, y)
+    numpy.testing.assert_allclose(y, x * 2.0)
+    metadata = mod.get_function_metadata("double_tensor")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[Tensor, Tensor], None]"
+
+    # Explicitly cleanup all objects before module unload to avoid 
use-after-free
+    del metadata, schema, result, x, y, mod
+    gc.collect()

Review Comment:
   you are absolutely right kathy!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to