This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e516eaaea1 [testing][py_converter] Enhance py_converter to better
support entire modules (#13769)
e516eaaea1 is described below
commit e516eaaea182f21c202ba6a8d60afe8ec969d1a9
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 14 22:06:08 2023 -0500
[testing][py_converter] Enhance py_converter to better support entire
modules (#13769)
This PR makes a few improvements to py_converter to make it more useful for
fuzz testing, especially when running larger modules.
In particular, these changes are to support returning the definition of a
global var directly (e.g., if you do run_as_python(main_var, mod=mod), the
result will be a function corresponding to mod["main"]) and to correct two bugs
in the previous implementation:
Previously, it was not possible to insert a function into a runtime
container object like an ADT. This was because the converter was simply
compiling Relay functions into Python functions. This change solves this
problem by registering the functions as PackedFuncs. However, another fix was
also needed: Even though PackedFunc is an ObjectRef in C++, the Python bindings
do not recognize PackedFuncs as Objects, so the code now calls the FFI API
tuple constructor directly.
The implementation relied on IRModule.from_expr to wrap passed-in
expressions in a module. However, from_expr will not overwrite the main
function if one is passed in via the functions argument. Thus, if the user
passed in a module that already had a main function defined, the wrapping would
be done incorrectly and result in the main function being copied many times. This PR
corrects this error by not assuming that the name main will be available and
instead constructing a new module with a re [...]
None of the cases above had been tested before (there are now tests
included)
---
python/tvm/relay/testing/py_converter.py | 83 +++++++++++++++++++++++++-------
src/runtime/container.cc | 1 +
tests/python/relay/test_py_converter.py | 65 ++++++++++++++++++++++++-
3 files changed, 129 insertions(+), 20 deletions(-)
diff --git a/python/tvm/relay/testing/py_converter.py
b/python/tvm/relay/testing/py_converter.py
index 1ec85faea6..44489aa9cf 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -38,7 +38,7 @@ OUTPUT_VAR_NAME = "_py_out"
# import tvm
# from tvm import relay
# from tvm import nd
-# from tvm.runtime import import container as _container
+# from tvm.runtime import container as _container
# from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
ast.Import([alias("numpy", None)]),
@@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor):
def __init__(self, mod, target) -> None:
super().__init__()
self.mod = mod
- self.tgt = target
+ self.tgt = target if isinstance(target, tvm.target.Target) else
tvm.target.Target(target)
self.tec = te_compiler.get()
self.fun_no = 0
self.var_no = 0
@@ -98,15 +98,31 @@ class PythonConverter(ExprFunctor):
# unwrap tuple wrappers (some op calls produce them)
unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper)
else prog
assert relay.analysis.well_formed(unwrapped)
- mod = self.mod.from_expr(unwrapped, self.mod.functions,
self.mod.type_definitions)
+ # For a lone global var, there is nothing we need to do
+ if isinstance(unwrapped, relay.GlobalVar):
+ return unwrapped
+
+ # main might be in the mod already and from_expr will not override it
if it's there,
+ # so we need a new name
+ target_name = self.generate_function_name("target")
+
+ wrapped = unwrapped
+ if not isinstance(unwrapped, relay.Function):
+ wrapped = relay.Function(relay.analysis.free_vars(unwrapped),
unwrapped)
+
+ # easiest way to make a deep copy -- note that main will not be
overridden if it's present
+ copy_mod = tvm.IRModule.from_expr(
+ relay.Tuple([]), self.mod.functions, self.mod.type_definitions
+ )
+ copy_mod[target_name] = wrapped
# necessary pass: SimplifyInference (otherwise we can't generate code
for some operators)
# and fusion (to get primitive functions)
opts = tvm.transform.Sequential(
[relay.transform.SimplifyInference(),
relay.transform.FuseOps(fuse_opt_level=0)]
)
- mod = opts(mod)
- optimized = mod["main"]
+ copy_mod = opts(copy_mod)
+ optimized = copy_mod[target_name]
return optimized if isinstance(unwrapped, Function) else optimized.body
def sanitize(self, name: str) -> str:
@@ -197,7 +213,7 @@ class PythonConverter(ExprFunctor):
var_names = [self.get_var_name(var) for var in func.params]
body, defs = self.visit(func.body)
- ret = self.create_def(func_name, var_names, defs + [Return(body)])
+ ret = self.create_def(func_name, var_names, defs + [Return(body)],
register_packed=True)
return (ret, func_name)
def convert_module(self):
@@ -219,10 +235,25 @@ class PythonConverter(ExprFunctor):
"""Creates a simple function call."""
return ast.Call(self.parse_name(func_name), arguments, [])
- def create_def(self, func_name: str, arguments: [str], body):
- """Wrapper over function definition AST node, whose constructor is
inconvenient."""
+ def create_def(self, func_name: str, arguments: [str], body,
register_packed: bool = False):
+ """
+ Wrapper over function definition AST node, whose constructor is
inconvenient.
+
+ register_packed includes a tvm.register_func decorator on the
generated function if true.
+ This option should be used for Relay functions (warning: clobbers
registry!)
+ """
inner_args = [ast.arg(argument, None) for argument in arguments]
+ # add a decorator to register as a PackedFunc so the function will be
an ObjectRef
+ # and will allow for putting functions into tuples or refs
+ decorator_list = [
+ ast.Call(
+ self.parse_name("tvm.register_func"),
+ [ast.Constant(value=func_name)],
+ [ast.keyword(arg="override", value=ast.Constant(value=True))],
+ )
+ ]
+
global __MAJOR__, __MINOR__
if __MAJOR__ == 3 and __MINOR__ >= 8:
arguments = ast.arguments([], inner_args, None, [], [], None, [])
@@ -233,10 +264,19 @@ class PythonConverter(ExprFunctor):
func_name,
arguments,
body,
- [],
+ decorator_list if register_packed else [],
None,
)
+ def create_tuple(self, fields):
+ """
+ Given the ASTs for tuple fields, produce an AST that creates a
+ tuple value with those fields
+ """
+ # Use the FFI API directly so that PackedFuncs will be correctly
converted to ObjectRef.
+ # Using tvm.runtime.container.tuple_object fails to convert
PackedFuncs in Python
+ return self.create_call("_container._ffi_api.Tuple", fields)
+
def create_op_call(self, op: Function, relay_args, py_args):
"""Lowers the passed primitive function, registers it in TVM's
global compiler, and produces a call to the lowered function in
@@ -290,8 +330,7 @@ class PythonConverter(ExprFunctor):
assignments += inner_assignments
extra_args += inner_args
fields.append(inner_output)
- fields = [ast.List(fields, Load())]
- return (assignments, extra_args,
self.create_call("_container.tuple_object", fields))
+ return (assignments, extra_args, self.create_tuple(fields))
# create a function to wrap the call of the lowered op and return
# a call to that function
@@ -418,7 +457,9 @@ class PythonConverter(ExprFunctor):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
- return (Name(str(gvar.name_hint), Load()), [])
+ func_name = str(gvar.name_hint)
+ # load in the packed func
+ return (self.create_call("tvm.get_global_func",
[ast.Constant(value=func_name)]), [])
def visit_let(self, letexp: Expr):
# To properly account for scoping and ensure that the entire node
produces an expression,
@@ -456,8 +497,7 @@ class PythonConverter(ExprFunctor):
def visit_tuple(self, tup: Expr):
fields, ret_defs = self.convert_fields(tup.fields)
- fields = [ast.List(fields, Load())]
- return (self.create_call("_container.tuple_object", fields), ret_defs)
+ return (self.create_tuple(fields), ret_defs)
def visit_tuple_getitem(self, tgi: Expr):
tup, tup_defs = self.visit(tgi.tuple_value)
@@ -471,7 +511,7 @@ class PythonConverter(ExprFunctor):
# need to get the value out of a NDArray to check the condition
# equvialent to: val.numpy()
- cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [],
[])
+ cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [],
[])
ret = ast.IfExp(cond_check, true_body, false_body)
return (ret, cond_defs + true_defs + false_defs)
@@ -490,7 +530,11 @@ class PythonConverter(ExprFunctor):
def visit_function(self, func: Expr):
# Python's lambdas are very restrictive, so we do "name" inline
functions
converted_func, func_name = self.convert_func_node(func)
- return (Name(func_name, Load()), [converted_func])
+ # load in the PackedFunc
+ return (
+ self.create_call("tvm.get_global_func",
[ast.Constant(value=func_name)]),
+ [converted_func],
+ )
def visit_call(self, call: Expr):
"""For calls, we must distinguish between ordinary functions,
@@ -546,7 +590,7 @@ class PythonConverter(ExprFunctor):
+ val_defs
+ [
Assign([ast.Attribute(ref, "value", Store())], val),
- Return(self.create_call("_container.tuple_object", [])),
+ Return(self.create_tuple([])),
],
)
return (self.create_call(thunk_name, []), [thunk])
@@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None,
target=tvm.target.Target("llvm")):
def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
"""Converts the given Relay expression into a Python script and
- executes it."""
+ executes it.
+
+ Note that closures will be returned as PackedFuncs
+ """
mod = mod if mod is not None else tvm.IRModule()
py_ast = to_python(expr, mod, target)
code = compile(py_ast, "<string>", "exec")
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index adcaecbc64..7b5105a3fc 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -202,5 +202,6 @@
TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh
ICHECK_LT(idx, shape.size());
return shape[idx];
});
+
} // namespace runtime
} // namespace tvm
diff --git a/tests/python/relay/test_py_converter.py
b/tests/python/relay/test_py_converter.py
index bd5635e8cf..d43ec5861b 100644
--- a/tests/python/relay/test_py_converter.py
+++ b/tests/python/relay/test_py_converter.py
@@ -18,7 +18,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.testing import to_python, run_as_python
+from tvm.relay.testing import run_as_python
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
@@ -70,7 +70,6 @@ def test_create_empty_tuple():
def test_create_scalar():
scalar = relay.const(1)
tensor_val = run_as_python(scalar)
- print(type(tensor_val))
assert_tensor_value(tensor_val, 1)
@@ -611,3 +610,65 @@ def test_batch_norm():
verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])
+
+
+def test_return_global_var():
+ tt = relay.TensorType([1], "float32")
+ x = relay.Var("x", type_annotation=tt)
+ identity = relay.Function([x], x, ret_type=tt)
+ mod = tvm.IRModule()
+ mod["main"] = identity
+ main_var = mod.get_global_var("main")
+ main_func = run_as_python(main_var, mod=mod)
+
+ arg = tvm.nd.array(np.array([0.0], dtype="float32"))
+ res = main_func(arg)
+ assert arg.numpy() == res.numpy()
+
+
+def test_closure_in_tuple():
+ tt = relay.TensorType([1], "float32")
+ x = relay.Var("x", type_annotation=tt)
+ identity = relay.Function([x], x, ret_type=tt)
+ tup = relay.Tuple([identity, identity])
+ index = relay.TupleGetItem(tup, 0)
+
+ func = run_as_python(index)
+ arg = tvm.nd.array(np.array([0.0], dtype="float32"))
+ res = func(arg)
+ assert arg.numpy() == res.numpy()
+
+
+def test_closure_in_ref():
+ tt = relay.TensorType([1], "float32")
+ x = relay.Var("x", type_annotation=tt)
+ identity = relay.Function([x], x, ret_type=tt)
+ gv = relay.GlobalVar("id")
+
+ r = relay.Var("r")
+ seq = relay.Let(
+ r,
+ relay.RefCreate(gv),
+ relay.Call(relay.RefRead(r), [relay.const(np.array([0.0],
dtype="float32"))]),
+ )
+
+ mod = tvm.IRModule()
+ mod[gv] = identity
+ res = run_as_python(seq, mod=mod)
+ assert res.numpy() == np.array([0.0], dtype="float32")
+
+
+def test_compiling_with_main():
+ unit_type = relay.TupleType([])
+ unit = relay.Function([], relay.Tuple([]), ret_type=unit_type)
+
+ x = relay.Var("x", type_annotation=unit_type)
+ identity = relay.Function([x], x, ret_type=unit_type)
+
+ mod = tvm.IRModule()
+ mod["unit"] = unit
+ mod["main"] = identity
+
+ res =
run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod)
+ assert isinstance(res, ADT)
+ assert len(res) == 0