junrushao commented on code in PR #230:
URL: https://github.com/apache/tvm-ffi/pull/230#discussion_r2529520652


##########
python/tvm_ffi/module.py:
##########
@@ -170,6 +180,93 @@ def get_function(self, name: str, query_imports: bool = False) -> core.Function:
             raise AttributeError(f"Module has no function '{name}'")
         return func
 
+    def get_function_metadata(
+        self, name: str, query_imports: bool = False
+    ) -> dict[str, Any] | None:
+        """Get metadata for a function exported from the module.
+
+        This retrieves metadata for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC
+        and when TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA is on, which includes type schema
+        and const-ness information.
+
+        Parameters
+        ----------
+        name
+            The name of the function
+
+        query_imports
+            Whether to also query modules imported by this module.
+
+        Returns
+        -------
+        metadata
+            A dictionary containing function metadata. The ``type_schema`` field
+            encodes the callable signature.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            import tvm_ffi
+            from tvm_ffi.core import TypeSchema
+            import json
+
+            mod = tvm_ffi.load_module("add_one_cpu.so")
+            metadata = mod.get_function_metadata("add_one_cpu")
+            schema = TypeSchema.from_json_str(metadata["type_schema"])
+            print(schema)  # Shows function signature
+
+        See Also
+        --------
+        :py:func:`tvm_ffi.get_global_func_metadata`
+            Get metadata for global registry functions.
+
+        """
+        metadata_str = _ffi_api.ModuleGetFunctionMetadata(self, name, query_imports)
+        if metadata_str is None:
+            return None
+        return json.loads(metadata_str)
+
+    def get_function_doc(self, name: str, query_imports: bool = False) -> str | None:
+        """Get documentation string for a function exported from the module.
+
+        This retrieves documentation for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC.
+        If the function was exported with TVM_FFI_DLL_EXPORT_TYPED_FUNC, this function will

Review Comment:
   ditto



##########
tests/cpp/test_function.cc:
##########
@@ -16,9 +16,13 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA

Review Comment:
   I'm wondering whether we should make `TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` a compile definition in CMake. CC: @tqchen



##########
docs/get_started/quickstart.rst:
##########
@@ -76,7 +76,8 @@ Suppose we implement a C++ function ``AddOne`` that performs elementwise ``y = x
 
 
 The macro :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` exports the C++ function ``AddOne``
-as a TVM FFI compatible symbol with the name ``__tvm_ffi_add_one_cpu/cuda`` in the resulting library.
+as a TVM FFI compatible symbol ``__tvm_ffi_add_one_cpu/cuda``. If ``TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA`` is on,

Review Comment:
   IIRC, writing it this way gives us a hyperlink to the C macro doc:
   
   ```suggestion
   as a TVM FFI compatible symbol ``__tvm_ffi_add_one_cpu/cuda``. If :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` is on,
   ```



##########
tests/python/test_build.py:
##########
@@ -31,11 +34,293 @@ def test_build_cpp() -> None:
 
     mod: Module = tvm_ffi.load_module(output_lib_path)
 
+    metadata = mod.get_function_metadata("add_one_cpu")
+    assert metadata is not None, "add_one_cpu should have metadata"
+    assert "type_schema" in metadata, f"{'add_one_cpu'}: {metadata}"
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[Tensor, Tensor], None]", f"{'add_one_cpu'}: {schema}"
+    assert "arg_const" in metadata
+    arg_const = metadata["arg_const"]
+    assert len(arg_const) == 2, "Should have 2 arguments"
+    assert arg_const[0] is False and arg_const[1] is False, f"{'add_one_cpu'}: {arg_const}"
+    doc = mod.get_function_doc("add_one_cpu")
+    assert doc is None
+
     x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
     y = numpy.empty_like(x)
     mod.add_one_cpu(x, y)
     numpy.testing.assert_equal(x + 1, y)
 
 
+def test_build_inline_with_metadata() -> None:  # noqa: PLR0915
+    """Test functions with various input and output types."""
+    # Keep module alive until all returned objects are destroyed
+    output_lib_path = tvm_ffi.cpp.build_inline(
+        name="test_io_types",
+        cpp_sources=r"""
+            // int input -> int output
+            int square(int x) {
+                return x * x;
+            }
+
+            // float input -> float output
+            float reciprocal(float x) {
+                return 1.0f / x;
+            }
+
+            // bool input -> bool output
+            bool negate(bool x) {
+                return !x;
+            }
+
+            // String input -> String output
+            tvm::ffi::String uppercase_first(tvm::ffi::String s) {
+                std::string result(s.c_str());
+                if (!result.empty()) {
+                    result[0] = std::toupper(result[0]);
+                }
+                return tvm::ffi::String(result);
+            }
+
+            // Multiple inputs: int, float -> float
+            float weighted_sum(int count, float weight) {
+                return static_cast<float>(count) * weight;
+            }
+
+            // Multiple inputs: String, int -> String
+            tvm::ffi::String repeat_string(tvm::ffi::String s, int times) {
+                std::string result;
+                for (int i = 0; i < times; ++i) {
+                    result += s.c_str();
+                }
+                return tvm::ffi::String(result);
+            }
+
+            // Mixed types: bool, int, float, String -> String
+            tvm::ffi::String format_data(bool flag, int count, float value, tvm::ffi::String label) {
+                std::ostringstream oss;
+                oss << label.c_str() << ": flag=" << (flag ? "true" : "false")
+                    << ", count=" << count << ", value=" << value;
+                return tvm::ffi::String(oss.str());
+            }
+
+            // Tensor input/output
+            void double_tensor(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
+                TVM_FFI_ICHECK(input.ndim() == 1);
+                TVM_FFI_ICHECK(output.ndim() == 1);
+                TVM_FFI_ICHECK(input.size(0) == output.size(0));
+                DLDataType f32_dtype{kDLFloat, 32, 1};
+                TVM_FFI_ICHECK(input.dtype() == f32_dtype);
+                TVM_FFI_ICHECK(output.dtype() == f32_dtype);
+
+                for (int i = 0; i < input.size(0); ++i) {
+                    static_cast<float*>(output.data_ptr())[i] =
+                        static_cast<const float*>(input.data_ptr())[i] * 2.0f;
+                }
+            }
+        """,
+        functions=[
+            "square",
+            "reciprocal",
+            "negate",
+            "uppercase_first",
+            "weighted_sum",
+            "repeat_string",
+            "format_data",
+            "double_tensor",
+        ],
+        extra_cflags=["-DTVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA"],
+    )
+
+    mod: Module = tvm_ffi.load_module(output_lib_path)
+
+    # Test square: int -> int
+    assert mod.square(5) == 25
+    metadata = mod.get_function_metadata("square")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[int], int]"
+
+    # Test reciprocal: float -> float
+    result = mod.reciprocal(2.0)
+    assert abs(result - 0.5) < 0.001
+    metadata = mod.get_function_metadata("reciprocal")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[float], float]"
+
+    # Test negate: bool -> bool
+    assert mod.negate(True) is False
+    assert mod.negate(False) is True
+    metadata = mod.get_function_metadata("negate")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[bool], bool]"
+
+    # Test uppercase_first: String -> String
+    result = mod.uppercase_first("hello")
+    assert result == "Hello"
+    metadata = mod.get_function_metadata("uppercase_first")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[str], str]"
+
+    # Test weighted_sum: int, float -> float
+    result = mod.weighted_sum(10, 2.5)
+    assert abs(result - 25.0) < 0.001
+    metadata = mod.get_function_metadata("weighted_sum")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[int, float], float]"
+
+    # Test repeat_string: String, int -> String
+    result = mod.repeat_string("ab", 3)
+    assert result == "ababab"
+    metadata = mod.get_function_metadata("repeat_string")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[str, int], str]"
+
+    # Test format_data: bool, int, float, String -> String
+    result = mod.format_data(True, 42, 3.14, "test")
+    assert "test:" in result
+    assert "flag=true" in result
+    assert "count=42" in result
+    assert "value=3.14" in result
+    metadata = mod.get_function_metadata("format_data")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[bool, int, float, str], str]"
+
+    # Test double_tensor: Tensor, Tensor -> None
+    x = numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32)
+    y = numpy.empty_like(x)
+    mod.double_tensor(x, y)
+    numpy.testing.assert_allclose(y, x * 2.0)
+    metadata = mod.get_function_metadata("double_tensor")
+    assert metadata is not None
+    schema = TypeSchema.from_json_str(metadata["type_schema"])
+    assert str(schema) == "Callable[[Tensor, Tensor], None]"
+
+    # Explicitly clean up all objects before module unload to avoid use-after-free
+    del metadata, schema, result, x, y, mod
+    gc.collect()

Review Comment:
   I suspect `gc` isn't actually useful here, since IIRC TVM-FFI caches the loaded DSOs.



##########
python/tvm_ffi/module.py:
##########
@@ -170,6 +180,93 @@ def get_function(self, name: str, query_imports: bool = False) -> core.Function:
             raise AttributeError(f"Module has no function '{name}'")
         return func
 
+    def get_function_metadata(
+        self, name: str, query_imports: bool = False
+    ) -> dict[str, Any] | None:
+        """Get metadata for a function exported from the module.
+
+        This retrieves metadata for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC
+        and when TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA is on, which includes type schema

Review Comment:
   nit
   
   ```suggestion
           This retrieves metadata for functions exported via :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC`
           and when :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` is on, which includes type schema
   ```
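For illustration, a minimal sketch of how a caller might combine the two accessors added in this hunk. It assumes only what the diff and its tests show: `get_function_doc` and `get_function_metadata` return `None` when nothing was exported, and the metadata dict carries a `type_schema` string. `describe_function` is a hypothetical helper, not part of tvm_ffi.

```python
from typing import Any, Optional


def describe_function(mod: Any, name: str) -> str:
    """Return a one-line description of a function exported from ``mod``.

    Hypothetical helper. ``mod`` is expected to behave like the Module in
    this PR: it exposes ``get_function_doc`` and ``get_function_metadata``,
    each returning None when that information was not exported.
    """
    doc: Optional[str] = mod.get_function_doc(name)
    if doc is not None:
        # Prefer a human-written docstring when one was exported.
        return f"{name}: {doc}"
    metadata: Optional[dict[str, Any]] = mod.get_function_metadata(name)
    if metadata is None:
        # Assumption: None means the function was exported without metadata.
        return f"{name}: <no metadata>"
    # type_schema is the JSON-encoded signature that the PR's tests parse
    # with TypeSchema.from_json_str; it is returned here unparsed.
    return f"{name}: {metadata['type_schema']}"
```

Per `test_build_cpp`, `get_function_doc("add_one_cpu")` returns `None`, so for that function this sketch would fall through to the `type_schema` branch.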



##########
include/tvm/ffi/function_details.h:
##########
@@ -253,6 +253,41 @@ struct TypeSchemaImpl {
   }
 };
 
+/*!
+ * \brief Helper to detect const-ness of a type parameter.
+ * Used for memory effect annotation in function metadata.
+ */
+template <typename T>
+struct IsConstParam {

Review Comment:
   Agreed that const-ness may not reflect the underlying need.
   
   However, I'd also like to point out that idiomatic function signatures for ML kernels may look like:
   
   ```
   SomeKernel(
     x1: IN, # const: input
     x2: IN, # const: input
     y1: IN-OUT, # mutable inputs
     y2: OUT, # output
   )
   ```
   
   where having const-ness indicators would help with production needs. @christopherbate also mentioned this feature as desirable in our discussion.
   
   Another example is the `torchgen` package installed alongside torch, which lets us parse the function schemas provided in [`native_functions.yaml`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and can potentially provide const-ness and aliasing info:
   
   ```python
   import torch
   import torchgen
   from torchgen.model import FunctionSchema
   
   op = torch.ops.aten.add.Tensor
   
   FunctionSchema.parse(str(op._schema))
   ```
   
   My recommendation is that we don't need to include this feature in this PR, but it may be worth a follow-up discussion.
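To make the IN / IN-OUT / OUT point concrete, a minimal sketch that labels arguments from the `arg_const` list, the per-argument boolean list that the tests in this PR read from the metadata. It assumes `True` marks a const parameter; `classify_args` and the example metadata are hypothetical. It also shows the limitation discussed above: const-ness separates IN from the rest, but cannot distinguish IN-OUT from OUT.

```python
from typing import Any


def classify_args(metadata: dict[str, Any]) -> list[str]:
    """Label each argument as "IN" or "IN-OUT|OUT" from its const-ness.

    Hypothetical helper: it reads only the ``arg_const`` list, assumed to
    hold one boolean per argument (True meaning the parameter is const).
    """
    labels = []
    for is_const in metadata["arg_const"]:
        # A const parameter is read-only, so it can only be an input.
        # A non-const parameter may be written, but const-ness alone cannot
        # tell whether it is also read (IN-OUT) or only written (OUT).
        labels.append("IN" if is_const else "IN-OUT|OUT")
    return labels


# Hypothetical metadata shaped like SomeKernel(x1, x2, y1, y2) above.
example = {"arg_const": [True, True, False, False]}
print(classify_args(example))  # ['IN', 'IN', 'IN-OUT|OUT', 'IN-OUT|OUT']
```

Distinguishing IN-OUT from OUT would need extra annotation beyond a single boolean, for example the aliasing/write information that `torchgen` schemas carry.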



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
