junrushao commented on code in PR #230:
URL: https://github.com/apache/tvm-ffi/pull/230#discussion_r2529520652
##########
python/tvm_ffi/module.py:
##########
@@ -170,6 +180,93 @@ def get_function(self, name: str, query_imports: bool = False) -> core.Function:
raise AttributeError(f"Module has no function '{name}'")
return func
+ def get_function_metadata(
+ self, name: str, query_imports: bool = False
+ ) -> dict[str, Any] | None:
+ """Get metadata for a function exported from the module.
+
+ This retrieves metadata for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC
+ and when TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA is on, which includes type schema
+ and const-ness information.
+
+ Parameters
+ ----------
+ name
+ The name of the function
+
+ query_imports
+ Whether to also query modules imported by this module.
+
+ Returns
+ -------
+ metadata
+ A dictionary containing function metadata. The ``type_schema`` field
+ encodes the callable signature.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ import tvm_ffi
+ from tvm_ffi.core import TypeSchema
+
+ mod = tvm_ffi.load_module("add_one_cpu.so")
+ metadata = mod.get_function_metadata("add_one_cpu")
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ print(schema) # Shows function signature
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.get_global_func_metadata`
+ Get metadata for global registry functions.
+
+ """
+ metadata_str = _ffi_api.ModuleGetFunctionMetadata(self, name, query_imports)
+ if metadata_str is None:
+ return None
+ return json.loads(metadata_str)
+
+ def get_function_doc(self, name: str, query_imports: bool = False) -> str | None:
+ """Get documentation string for a function exported from the module.
+
+ This retrieves documentation for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC.
+ If the function was exported with TVM_FFI_DLL_EXPORT_TYPED_FUNC, this function will
Review Comment:
ditto
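For context, a minimal sketch of how the two accessors pair up, based on the assertions in `test_build.py` below (`add_one_cpu.so` is a placeholder path):
```python
import tvm_ffi
from tvm_ffi.core import TypeSchema

# Library assumed built with TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA on.
mod = tvm_ffi.load_module("add_one_cpu.so")
meta = mod.get_function_metadata("add_one_cpu")
assert "type_schema" in meta and "arg_const" in meta
print(TypeSchema.from_json_str(meta["type_schema"]))  # Callable[[Tensor, Tensor], None]

# Exported via plain TVM_FFI_DLL_EXPORT_TYPED_FUNC (no _DOC variant),
# so the doc accessor returns None rather than raising.
assert mod.get_function_doc("add_one_cpu") is None
```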
##########
tests/cpp/test_function.cc:
##########
@@ -16,9 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
+#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA
Review Comment:
I'm wondering whether we want to make `TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` a compile definition in CMake? CC: @tqchen
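Until that lands, the Python-side opt-in mirrors what the new test does via `extra_cflags`; a minimal sketch based on the snippet from `test_build.py`:
```python
import tvm_ffi

# The define enables metadata emission for TVM_FFI_DLL_EXPORT_TYPED_FUNC,
# as exercised in test_build.py below.
lib_path = tvm_ffi.cpp.build_inline(
    name="demo",
    cpp_sources="int square(int x) { return x * x; }",
    functions=["square"],
    extra_cflags=["-DTVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA"],
)
mod = tvm_ffi.load_module(lib_path)
assert mod.square(5) == 25
```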
##########
docs/get_started/quickstart.rst:
##########
@@ -76,7 +76,8 @@ Suppose we implement a C++ function ``AddOne`` that performs elementwise ``y = x
The macro :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` exports the C++ function ``AddOne``
-as a TVM FFI compatible symbol with the name ``__tvm_ffi_add_one_cpu/cuda`` in the resulting library.
+as a TVM FFI compatible symbol ``__tvm_ffi_add_one_cpu/cuda``. If ``TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA`` is on,
Review Comment:
IIRC this will get us a hyperlink to the C macro doc
```suggestion
as a TVM FFI compatible symbol ``__tvm_ffi_add_one_cpu/cuda``. If :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` is on,
```
##########
tests/python/test_build.py:
##########
@@ -31,11 +34,293 @@ def test_build_cpp() -> None:
mod: Module = tvm_ffi.load_module(output_lib_path)
+ metadata = mod.get_function_metadata("add_one_cpu")
+ assert metadata is not None, "add_one_cpu should have metadata"
+ assert "type_schema" in metadata, f"{'add_one_cpu'}: {metadata}"
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[Tensor, Tensor], None]", f"add_one_cpu: {schema}"
+ assert "arg_const" in metadata
+ arg_const = metadata["arg_const"]
+ assert len(arg_const) == 2, "Should have 2 arguments"
+ assert arg_const[0] is False and arg_const[1] is False, f"add_one_cpu: {arg_const}"
+ doc = mod.get_function_doc("add_one_cpu")
+ assert doc is None
+
x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
y = numpy.empty_like(x)
mod.add_one_cpu(x, y)
numpy.testing.assert_equal(x + 1, y)
+def test_build_inline_with_metadata() -> None: # noqa: PLR0915
+ """Test functions with various input and output types."""
+ # Keep module alive until all returned objects are destroyed
+ output_lib_path = tvm_ffi.cpp.build_inline(
+ name="test_io_types",
+ cpp_sources=r"""
+ // int input -> int output
+ int square(int x) {
+ return x * x;
+ }
+
+ // float input -> float output
+ float reciprocal(float x) {
+ return 1.0f / x;
+ }
+
+ // bool input -> bool output
+ bool negate(bool x) {
+ return !x;
+ }
+
+ // String input -> String output
+ tvm::ffi::String uppercase_first(tvm::ffi::String s) {
+ std::string result(s.c_str());
+ if (!result.empty()) {
+ result[0] = std::toupper(result[0]);
+ }
+ return tvm::ffi::String(result);
+ }
+
+ // Multiple inputs: int, float -> float
+ float weighted_sum(int count, float weight) {
+ return static_cast<float>(count) * weight;
+ }
+
+ // Multiple inputs: String, int -> String
+ tvm::ffi::String repeat_string(tvm::ffi::String s, int times) {
+ std::string result;
+ for (int i = 0; i < times; ++i) {
+ result += s.c_str();
+ }
+ return tvm::ffi::String(result);
+ }
+
+ // Mixed types: bool, int, float, String -> String
+ tvm::ffi::String format_data(bool flag, int count, float value, tvm::ffi::String label) {
+ std::ostringstream oss;
+ oss << label.c_str() << ": flag=" << (flag ? "true" : "false")
+ << ", count=" << count << ", value=" << value;
+ return tvm::ffi::String(oss.str());
+ }
+
+ // Tensor input/output
+ void double_tensor(tvm::ffi::TensorView input, tvm::ffi::TensorView output) {
+ TVM_FFI_ICHECK(input.ndim() == 1);
+ TVM_FFI_ICHECK(output.ndim() == 1);
+ TVM_FFI_ICHECK(input.size(0) == output.size(0));
+ DLDataType f32_dtype{kDLFloat, 32, 1};
+ TVM_FFI_ICHECK(input.dtype() == f32_dtype);
+ TVM_FFI_ICHECK(output.dtype() == f32_dtype);
+
+ for (int i = 0; i < input.size(0); ++i) {
+ static_cast<float*>(output.data_ptr())[i] =
+ static_cast<const float*>(input.data_ptr())[i] * 2.0f;
+ }
+ }
+ """,
+ functions=[
+ "square",
+ "reciprocal",
+ "negate",
+ "uppercase_first",
+ "weighted_sum",
+ "repeat_string",
+ "format_data",
+ "double_tensor",
+ ],
+ extra_cflags=["-DTVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA"],
+ )
+
+ mod: Module = tvm_ffi.load_module(output_lib_path)
+
+ # Test square: int -> int
+ assert mod.square(5) == 25
+ metadata = mod.get_function_metadata("square")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[int], int]"
+
+ # Test reciprocal: float -> float
+ result = mod.reciprocal(2.0)
+ assert abs(result - 0.5) < 0.001
+ metadata = mod.get_function_metadata("reciprocal")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[float], float]"
+
+ # Test negate: bool -> bool
+ assert mod.negate(True) is False
+ assert mod.negate(False) is True
+ metadata = mod.get_function_metadata("negate")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[bool], bool]"
+
+ # Test uppercase_first: String -> String
+ result = mod.uppercase_first("hello")
+ assert result == "Hello"
+ metadata = mod.get_function_metadata("uppercase_first")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[str], str]"
+
+ # Test weighted_sum: int, float -> float
+ result = mod.weighted_sum(10, 2.5)
+ assert abs(result - 25.0) < 0.001
+ metadata = mod.get_function_metadata("weighted_sum")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[int, float], float]"
+
+ # Test repeat_string: String, int -> String
+ result = mod.repeat_string("ab", 3)
+ assert result == "ababab"
+ metadata = mod.get_function_metadata("repeat_string")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[str, int], str]"
+
+ # Test format_data: bool, int, float, String -> String
+ result = mod.format_data(True, 42, 3.14, "test")
+ assert "test:" in result
+ assert "flag=true" in result
+ assert "count=42" in result
+ assert "value=3.14" in result
+ metadata = mod.get_function_metadata("format_data")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[bool, int, float, str], str]"
+
+ # Test double_tensor: Tensor, Tensor -> None
+ x = numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32)
+ y = numpy.empty_like(x)
+ mod.double_tensor(x, y)
+ numpy.testing.assert_allclose(y, x * 2.0)
+ metadata = mod.get_function_metadata("double_tensor")
+ assert metadata is not None
+ schema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(schema) == "Callable[[Tensor, Tensor], None]"
+
+ # Explicitly clean up all objects before module unload to avoid use-after-free
+ del metadata, schema, result, x, y, mod
+ gc.collect()
Review Comment:
I'm wondering whether the `gc` call is actually useful, as TVM-FFI will cache the loaded DSOs IIRC
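Concretely, the concern is something like this (a sketch, assuming the loader caches DSOs by path; `demo.so` is a placeholder for the library built in this test):
```python
import gc
import tvm_ffi

mod_a = tvm_ffi.load_module("demo.so")
mod_b = tvm_ffi.load_module("demo.so")  # assumed cache hit on the same path
del mod_a
gc.collect()
# If the DSO is cached, it stays mapped regardless of the del/gc above,
# so the explicit cleanup would not be what prevents use-after-free.
assert mod_b.square(3) == 9
```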
##########
python/tvm_ffi/module.py:
##########
@@ -170,6 +180,93 @@ def get_function(self, name: str, query_imports: bool = False) -> core.Function:
raise AttributeError(f"Module has no function '{name}'")
return func
+ def get_function_metadata(
+ self, name: str, query_imports: bool = False
+ ) -> dict[str, Any] | None:
+ """Get metadata for a function exported from the module.
+
+ This retrieves metadata for functions exported via TVM_FFI_DLL_EXPORT_TYPED_FUNC
+ and when TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA is on, which includes type schema
Review Comment:
nit
```suggestion
This retrieves metadata for functions exported via :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC`
and when :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC_METADATA` is on, which includes type schema
```
##########
include/tvm/ffi/function_details.h:
##########
@@ -253,6 +253,41 @@ struct TypeSchemaImpl {
}
};
+/*!
+ * \brief Helper to detect const-ness of a type parameter.
+ * Used for memory effect annotation in function metadata.
+ */
+template <typename T>
+struct IsConstParam {
Review Comment:
Agreed that const-ness may not reflect the underlying need. However, I'd also love to bring up that idiomatic function signatures for ML kernels may look like:
```
SomeKernel(
x1: IN, # const: input
x2: IN, # const: input
y1: IN-OUT, # mutable inputs
y2: OUT, # output
)
```
where having const-ness indicators would indeed help production needs.
@christopherbate also mentioned this feature as desirable in our discussion.
Another example is the `torchgen` package installed alongside torch, which allows us to parse the function schemas provided in [`native_functions.yaml`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and can potentially give const-ness and aliasing info:
```python
import torch
import torchgen
from torchgen.model import FunctionSchema
op = torch.ops.aten.add.Tensor
FunctionSchema.parse(str(op._schema))
```
My recommendation is that we don't have to include this feature in this PR, but a follow-up discussion may be worthwhile.
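FWIW, even with just the `arg_const` field from this PR, a consumer could derive a coarse split today; a sketch (const maps to IN, non-const to an ambiguous IN-OUT/OUT bucket that richer annotations would disambiguate):
```python
def classify_params(metadata: dict) -> list[str]:
    """Map per-argument const flags to coarse IN / IN-OUT-or-OUT labels.

    `metadata` is the dict returned by Module.get_function_metadata();
    its "arg_const" field is a list of booleans, True for const parameters.
    """
    return ["IN" if c else "IN-OUT/OUT" for c in metadata["arg_const"]]
```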