This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e1964eceb5 [Unity] Add runtime debugging method to RelaxVM (#16238)
e1964eceb5 is described below
commit e1964eceb59431147af43979393ecb1217be9fb7
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 14 14:59:29 2023 -0800
[Unity] Add runtime debugging method to RelaxVM (#16238)
---
 python/tvm/relax/frontend/nn/modules.py      | 19 -------
 python/tvm/relax/frontend/nn/op.py           | 73 +++++++++++++++++++++++-
 python/tvm/runtime/relax_vm.py               | 12 +++-
 src/runtime/relax_vm/builtin.cc              | 29 ++++++++++
 tests/python/relax/test_frontend_nn_debug.py | 83 ++++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py    | 36 ------------
 6 files changed, 191 insertions(+), 61 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index b73b37d4a8..68b719d9ad 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -22,8 +22,6 @@ import numpy as np
 
 from tvm import relax as rx
 from tvm import tir
-from tvm._ffi import register_func
-from tvm.runtime import NDArray
 
 from . import op
 from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
@@ -56,23 +54,6 @@ class IOEffect(Effect):
         self.effect = None
         return [result]
 
-    def print_(self, tensor: Tensor) -> None:
-        """Encloses the side effect of NDArray printing"""
-        self.effect = rx.BlockBuilder.current().emit(
-            rx.call_pure_packed(
-                rx.extern("effect.print"),
-                self.effect,
-                tensor._expr,  # pylint: disable=protected-access
-                sinfo_args=[rx.ObjectStructInfo()],
-            ),
-            name_hint=self.effect.name_hint,
-        )
-
-
-@register_func("effect.print")
-def _print(_, array: NDArray) -> None:
-    print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")
-
 
 class ReLU(Module):
     """Module for ReLU activation layer."""
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index d315331862..061465d085 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -17,6 +17,7 @@
 # pylint: disable=too-many-lines,invalid-name,protected-access,redefined-outer-name
 # pylint: disable=redefined-builtin
 """nn.Tensor operators."""
+import inspect
 import math
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
@@ -1491,7 +1492,73 @@ def tensor_expr_op(
     )
 
 
-def print_(array: Tensor):
+def debug_func(
+    name: str,
+    *args: Union[Tensor, _tir.PrimExpr, int, float, str],
+    _line_info: Optional[str] = None,
+):
+    """Call a debug function during runtime. The debug function must be registered with the
+    following type signature:
+
+    .. code-block:: python
+
+        @tvm.register_func(name_of_debug_func)
+        def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
+            ...
+
+    Parameters
+    ----------
+    name : str
+        The name of the debug function to call.
+
+    *args : Union[Tensor, _tir.PrimExpr, int, float, str]
+        The arguments to pass to the debug function.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relax as rx
+
+    from .modules import IOEffect
+
+    # pylint: enable=import-outside-toplevel
+
     if SpecBuilder.current().io_effect is None:
-        raise RuntimeError("Printing is only supported when debug mode is on.")
-    SpecBuilder.current().io_effect.print_(array)
+        raise RuntimeError("Debugging is only supported when debug mode is on.")
+    io: IOEffect = SpecBuilder.current().io_effect  # type: ignore
+
+    if _line_info is None:
+        filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
+        _line_info = f"{filename}:{line_number}"
+
+    converted_args = []
+    for arg in args:
+        if isinstance(arg, Tensor):
+            converted_args.append(arg._expr)  # pylint: disable=protected-access
+        elif isinstance(arg, int):
+            converted_args.append(rx.PrimValue(_tir.IntImm("int64", arg)))
+        elif isinstance(arg, float):
+            converted_args.append(rx.PrimValue(_tir.FloatImm("float32", arg)))
+        elif isinstance(arg, _tir.PrimExpr):
+            converted_args.append(rx.PrimValue(arg))
+        elif isinstance(arg, str):
+            converted_args.append(rx.StringImm(arg))
+        else:
+            raise TypeError(f"Unsupported type {type(arg)}")
+
+    io.effect = BlockBuilder.current().emit(
+        rx.call_pure_packed(
+            "vm.builtin.invoke_debug_func",
+            io.effect,
+            rx.StringImm(name),
+            rx.StringImm(_line_info),
+            *converted_args,
+            sinfo_args=[rx.ObjectStructInfo()],
+        ),
+        name_hint=io.effect.name_hint,
+    )
+
+
+def print_(tensor: Tensor):
+    """Debug printing a Tensor during runtime."""
+    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
+    line_info = f"{filename}:{line_number}"
+    debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index b9d9a38fcf..a925e048b2 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -16,14 +16,15 @@
 # under the License.
 # pylint: disable=invalid-name, redefined-builtin, no-else-return, consider-using-dict-items
 """The Relax virtual machine."""
-from typing import Callable, List, Optional, Union, Dict, Tuple, Any
 from enum import IntEnum
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
 import numpy as np  # type: ignore
 
 import tvm
 from tvm._ffi import base as _base
-
-from tvm.runtime import Device, PackedFunc, Object
+from tvm._ffi import register_func
+from tvm.runtime import Device, Object, PackedFunc
 from tvm.runtime.profiling import Report
 
 from ..rpc.base import RPC_SESS_MASK
@@ -510,3 +511,8 @@ class VirtualMachine(object):
 
         report_json = self.module["profile"](func_name, *cargs)
         return Report.from_json(report_json)
+
+
+@register_func("vm.builtin.debug_print")
+def _print(lineo: str, array) -> None:
+    print(f"{lineo}: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index d6b086f201..fb24a3699d 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -464,6 +464,35 @@ bool ReadIfCond(TVMArgValue cond) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond);
 
+//-------------------------------------
+// Debugging API
+//-------------------------------------
+
+TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func")
+    .set_body([](TVMArgs args, TVMRetValue* rv) -> void {
+      ICHECK_GE(args.size(), 3);
+      int num_args = args.size() - 3;
+      ObjectRef io_effect = args[0];
+      ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None.";
+      String debug_func_name = args[1];
+      const PackedFunc* debug_func = runtime::Registry::Get(debug_func_name);
+      CHECK(debug_func) << "ValueError: " << debug_func_name << " is not found. "
+                        << "Use the decorator `@tvm.register_func(\"" << debug_func_name
+                        << "\")` to register it.";
+      String line_info = args[2];
+      std::vector<TVMValue> call_args(num_args + 1);
+      std::vector<int> call_type_codes(num_args + 1);
+      {
+        TVMArgsSetter setter(call_args.data(), call_type_codes.data());
+        setter(0, line_info);
+        for (int i = 0; i < num_args; ++i) {
+          setter(i + 1, args[i + 3]);
+        }
+      }
+      debug_func->CallPacked(TVMArgs(call_args.data(), call_type_codes.data(), num_args + 1), rv);
+      *rv = io_effect;
+    });
+
 //-------------------------------------
 // Data structure API
 //-------------------------------------
diff --git a/tests/python/relax/test_frontend_nn_debug.py b/tests/python/relax/test_frontend_nn_debug.py
new file mode 100644
index 0000000000..a055631a4d
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_debug.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import torch
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.relax.frontend import nn
+from tvm.relax.frontend.nn import op, spec
+from tvm.runtime import NDArray
+
+
+def test_debug_print():
+    class Layer(nn.Module):
+        def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name
+            op.print_(x)
+            return x
+
+    model = Layer().jit(
+        spec={
+            "forward": {"x": spec.Tensor([10, 5], dtype="float32")},
+        },
+        debug=True,
+    )
+    x = torch.rand((10, 5), dtype=torch.float32)  # pylint: disable=invalid-name
+    y = model["forward"](x)  # pylint: disable=invalid-name
+    assert isinstance(y, torch.Tensor)
+
+
+def test_debug_func():
+    @tvm.register_func("testing.relax.frontend.nn.test_debug_func")
+    def _debug(  # pylint: disable=too-many-arguments
+        lineno: str,
+        tensor: NDArray,
+        const_int: int,
+        const_float: float,
+        const_str: str,
+        var_int: int,
+    ) -> None:
+        assert "test_frontend_nn_debug.py" in lineno
+        assert tensor.shape == (10, 5)
+        assert const_int == 1
+        assert const_float == 2.0
+        assert const_str == "test"
+        assert var_int == 8
+
+    class Layer(nn.Module):
+        def forward(self, x: nn.Tensor, v: tir.Var):  # pylint: disable=invalid-name
+            op.debug_func("testing.relax.frontend.nn.test_debug_func", x, 1, 2.0, "test", v)
+            return x
+
+    model = Layer().jit(
+        spec={
+            "forward": {
+                "x": spec.Tensor([10, 5], dtype="float32"),
+                "v": "int",
+            },
+        },
+        debug=True,
+    )
+    x = torch.rand((10, 5), dtype=torch.float32)  # pylint: disable=invalid-name
+    y = model["forward"](x, 8)  # pylint: disable=invalid-name
+    assert isinstance(y, torch.Tensor)
+
+
+if __name__ == "__main__":
+    test_debug_print()
+    test_debug_func()
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index e952316376..ddaec7234b 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -508,41 +508,5 @@ def test_tensor_expr_op():
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
-def test_print():
-    class Model(Module):
-        def test(self, x: Tensor):
-            z = op.add(x, x)
-            op.print_(z)
-            return x
-
-    # fmt: off
-    @I.ir_module
-    class Expected:
-        @R.function
-        def _initialize_effect() -> R.Tuple(R.Object):
-            with R.dataflow():
-                _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
-                R.output(gv)
-            return gv
-
-        @R.function
-        def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
-            R.func_attr({"num_input": 2})
-            with R.dataflow():
-                add: R.Tensor((10, 10), dtype="float32") = R.add(x, x)
-                _io1: R.Object = R.call_pure_packed("effect.print", _io, add, sinfo_args=(R.Object(),))
-                gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io1,)
-                R.output(gv1)
-            return gv1
-    # fmt: on
-
-    m = Model()
-    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10], "float32")}}, debug=True)
-
-    tvm.ir.assert_structural_equal(irmodule["test"], Expected["test"])
-
-
 if __name__ == "__main__":
     tvm.testing.main()