This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e1964eceb5 [Unity] Add runtime debugging method to RelaxVM (#16238)
e1964eceb5 is described below

commit e1964eceb59431147af43979393ecb1217be9fb7
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 14 14:59:29 2023 -0800

    [Unity] Add runtime debugging method to RelaxVM (#16238)
---
 python/tvm/relax/frontend/nn/modules.py      | 19 -------
 python/tvm/relax/frontend/nn/op.py           | 73 +++++++++++++++++++++++-
 python/tvm/runtime/relax_vm.py               | 12 +++-
 src/runtime/relax_vm/builtin.cc              | 29 ++++++++++
 tests/python/relax/test_frontend_nn_debug.py | 83 ++++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py    | 36 ------------
 6 files changed, 191 insertions(+), 61 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/modules.py 
b/python/tvm/relax/frontend/nn/modules.py
index b73b37d4a8..68b719d9ad 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -22,8 +22,6 @@ import numpy as np
 
 from tvm import relax as rx
 from tvm import tir
-from tvm._ffi import register_func
-from tvm.runtime import NDArray
 
 from . import op
 from .core import Effect, Module, ModuleList, Parameter, Tensor, 
get_default_dtype
@@ -56,23 +54,6 @@ class IOEffect(Effect):
         self.effect = None
         return [result]
 
-    def print_(self, tensor: Tensor) -> None:
-        """Encloses the side effect of NDArray printing"""
-        self.effect = rx.BlockBuilder.current().emit(
-            rx.call_pure_packed(
-                rx.extern("effect.print"),
-                self.effect,
-                tensor._expr,  # pylint: disable=protected-access
-                sinfo_args=[rx.ObjectStructInfo()],
-            ),
-            name_hint=self.effect.name_hint,
-        )
-
-
-@register_func("effect.print")
-def _print(_, array: NDArray) -> None:
-    print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data 
=\n{array}")
-
 
 class ReLU(Module):
     """Module for ReLU activation layer."""
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index d315331862..061465d085 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -17,6 +17,7 @@
 # pylint: 
disable=too-many-lines,invalid-name,protected-access,redefined-outer-name
 # pylint: disable=redefined-builtin
 """nn.Tensor operators."""
+import inspect
 import math
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
@@ -1491,7 +1492,73 @@ def tensor_expr_op(
     )
 
 
-def print_(array: Tensor):
+def debug_func(
+    name: str,
+    *args: Union[Tensor, _tir.PrimExpr, int, float, str],
+    _line_info: Optional[str] = None,
+):
+    """Call a debug function during runtime. The debug function must be 
registered with the
+    following type signature:
+
+    .. code-block:: python
+
+        @tvm.register_func(name_of_debug_func)
+        def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
+            ...
+
+    Parameters
+    ----------
+    name : str
+        The name of the debug function to call.
+
+    *args : Union[Tensor, _tir.PrimExpr, int, float, str]
+        The arguments to pass to the debug function.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relax as rx
+
+    from .modules import IOEffect
+
+    # pylint: enable=import-outside-toplevel
+
     if SpecBuilder.current().io_effect is None:
-        raise RuntimeError("Printing is only supported when debug mode is on.")
-    SpecBuilder.current().io_effect.print_(array)
+        raise RuntimeError("Debugging is only supported when debug mode is 
on.")
+    io: IOEffect = SpecBuilder.current().io_effect  # type: ignore
+
+    if _line_info is None:
+        filename, line_number = 
inspect.getframeinfo(inspect.currentframe().f_back)[:2]
+        _line_info = f"{filename}:{line_number}"
+
+    converted_args = []
+    for arg in args:
+        if isinstance(arg, Tensor):
+            converted_args.append(arg._expr)  # pylint: 
disable=protected-access
+        elif isinstance(arg, int):
+            converted_args.append(rx.PrimValue(_tir.IntImm("int64", arg)))
+        elif isinstance(arg, float):
+            converted_args.append(rx.PrimValue(_tir.FloatImm("float32", arg)))
+        elif isinstance(arg, _tir.PrimExpr):
+            converted_args.append(rx.PrimValue(arg))
+        elif isinstance(arg, str):
+            converted_args.append(rx.StringImm(arg))
+        else:
+            raise TypeError(f"Unsupported type {type(arg)}")
+
+    io.effect = BlockBuilder.current().emit(
+        rx.call_pure_packed(
+            "vm.builtin.invoke_debug_func",
+            io.effect,
+            rx.StringImm(name),
+            rx.StringImm(_line_info),
+            *converted_args,
+            sinfo_args=[rx.ObjectStructInfo()],
+        ),
+        name_hint=io.effect.name_hint,
+    )
+
+
+def print_(tensor: Tensor):
+    """Debug printing a Tensor during runtime."""
+    filename, line_number = 
inspect.getframeinfo(inspect.currentframe().f_back)[:2]
+    line_info = f"{filename}:{line_number}"
+    debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index b9d9a38fcf..a925e048b2 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -16,14 +16,15 @@
 # under the License.
 # pylint: disable=invalid-name, redefined-builtin, no-else-return, 
consider-using-dict-items
 """The Relax virtual machine."""
-from typing import Callable, List, Optional, Union, Dict, Tuple, Any
 from enum import IntEnum
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
 import numpy as np  # type: ignore
 
 import tvm
 from tvm._ffi import base as _base
-
-from tvm.runtime import Device, PackedFunc, Object
+from tvm._ffi import register_func
+from tvm.runtime import Device, Object, PackedFunc
 from tvm.runtime.profiling import Report
 
 from ..rpc.base import RPC_SESS_MASK
@@ -510,3 +511,8 @@ class VirtualMachine(object):
 
         report_json = self.module["profile"](func_name, *cargs)
         return Report.from_json(report_json)
+
+
+@register_func("vm.builtin.debug_print")
+def _print(lineo: str, array) -> None:
+    print(f"{lineo}: shape = {array.shape}, dtype = {array.dtype}, data 
=\n{array}")
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index d6b086f201..fb24a3699d 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -464,6 +464,35 @@ bool ReadIfCond(TVMArgValue cond) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond);
 
+//-------------------------------------
+//  Debugging API
+//-------------------------------------
+
+TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func")
+    .set_body([](TVMArgs args, TVMRetValue* rv) -> void {
+      ICHECK_GE(args.size(), 3);
+      int num_args = args.size() - 3;
+      ObjectRef io_effect = args[0];
+      ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be 
lowered to None.";
+      String debug_func_name = args[1];
+      const PackedFunc* debug_func = runtime::Registry::Get(debug_func_name);
+      CHECK(debug_func) << "ValueError: " << debug_func_name << " is not 
found. "
+                        << "Use the decorator `@tvm.register_func(\"" << 
debug_func_name
+                        << "\")` to register it.";
+      String line_info = args[2];
+      std::vector<TVMValue> call_args(num_args + 1);
+      std::vector<int> call_type_codes(num_args + 1);
+      {
+        TVMArgsSetter setter(call_args.data(), call_type_codes.data());
+        setter(0, line_info);
+        for (int i = 0; i < num_args; ++i) {
+          setter(i + 1, args[i + 3]);
+        }
+      }
+      debug_func->CallPacked(TVMArgs(call_args.data(), call_type_codes.data(), 
num_args + 1), rv);
+      *rv = io_effect;
+    });
+
 //-------------------------------------
 //  Data structure API
 //-------------------------------------
diff --git a/tests/python/relax/test_frontend_nn_debug.py 
b/tests/python/relax/test_frontend_nn_debug.py
new file mode 100644
index 0000000000..a055631a4d
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_debug.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import torch
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.relax.frontend import nn
+from tvm.relax.frontend.nn import op, spec
+from tvm.runtime import NDArray
+
+
+def test_debug_print():
+    class Layer(nn.Module):
+        def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name
+            op.print_(x)
+            return x
+
+    model = Layer().jit(
+        spec={
+            "forward": {"x": spec.Tensor([10, 5], dtype="float32")},
+        },
+        debug=True,
+    )
+    x = torch.rand((10, 5), dtype=torch.float32)  # pylint: 
disable=invalid-name
+    y = model["forward"](x)  # pylint: disable=invalid-name
+    assert isinstance(y, torch.Tensor)
+
+
+def test_debug_func():
+    @tvm.register_func("testing.relax.frontend.nn.test_debug_func")
+    def _debug(  # pylint: disable=too-many-arguments
+        lineno: str,
+        tensor: NDArray,
+        const_int: int,
+        const_float: float,
+        const_str: str,
+        var_int: int,
+    ) -> None:
+        assert "test_frontend_nn_debug.py" in lineno
+        assert tensor.shape == (10, 5)
+        assert const_int == 1
+        assert const_float == 2.0
+        assert const_str == "test"
+        assert var_int == 8
+
+    class Layer(nn.Module):
+        def forward(self, x: nn.Tensor, v: tir.Var):  # pylint: 
disable=invalid-name
+            op.debug_func("testing.relax.frontend.nn.test_debug_func", x, 1, 
2.0, "test", v)
+            return x
+
+    model = Layer().jit(
+        spec={
+            "forward": {
+                "x": spec.Tensor([10, 5], dtype="float32"),
+                "v": "int",
+            },
+        },
+        debug=True,
+    )
+    x = torch.rand((10, 5), dtype=torch.float32)  # pylint: 
disable=invalid-name
+    y = model["forward"](x, 8)  # pylint: disable=invalid-name
+    assert isinstance(y, torch.Tensor)
+
+
+if __name__ == "__main__":
+    test_debug_print()
+    test_debug_func()
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index e952316376..ddaec7234b 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -508,41 +508,5 @@ def test_tensor_expr_op():
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
-def test_print():
-    class Model(Module):
-        def test(self, x: Tensor):
-            z = op.add(x, x)
-            op.print_(z)
-            return x
-
-    # fmt: off
-    @I.ir_module
-    class Expected:
-        @R.function
-        def _initialize_effect() -> R.Tuple(R.Object):
-            with R.dataflow():
-                _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
-                R.output(gv)
-            return gv
-
-        @R.function
-        def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> 
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
-            R.func_attr({"num_input": 2})
-            with R.dataflow():
-                add: R.Tensor((10, 10), dtype="float32") = R.add(x, x)
-                _io1: R.Object = R.call_pure_packed("effect.print", _io, add, 
sinfo_args=(R.Object(),))
-                gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), 
R.Tuple(R.Object)) = x, (_io1,)
-                R.output(gv1)
-            return gv1
-    # fmt: on
-
-    m = Model()
-    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], 
"float32")}}, debug=True)
-
-    tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"])
-
-
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to