This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e516eaaea1 [testing][py_converter] Enhance py_converter to better support entire modules (#13769)
e516eaaea1 is described below

commit e516eaaea182f21c202ba6a8d60afe8ec969d1a9
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 14 22:06:08 2023 -0500

    [testing][py_converter] Enhance py_converter to better support entire modules (#13769)
    
    This PR makes several improvements to py_converter that make it more useful for fuzz testing, especially when running larger modules.
    
    In particular, these changes add support for returning the definition of a global var directly (e.g., `run_as_python(main_var, mod=mod)` now returns a function corresponding to `mod["main"]`), and they correct two bugs in the previous implementation:
    
    1. Previously, it was not possible to insert a function into a runtime container object such as an ADT, because the converter simply compiled Relay functions into plain Python functions. This change solves the problem by registering the generated functions as PackedFuncs. One more fix was also needed: even though PackedFunc is an ObjectRef in C++, the Python bindings do not recognize PackedFuncs as Objects, so the generated code now calls the FFI API tuple constructor directly.
    2. The implementation relied on IRModule.from_expr to wrap passed-in expressions in a module. However, from_expr will not overwrite the main function if one is passed in via the functions argument. Thus, if the user passed in a module that already had a main function defined, the wrapping would be done incorrectly and main would end up being copied many times. This PR corrects this error by not assuming that the name main will be available and instead constructing a new module with a re [...]
    None of the cases above had been tested before; tests for them are now included.
---
 python/tvm/relay/testing/py_converter.py | 83 +++++++++++++++++++++++++-------
 src/runtime/container.cc                 |  1 +
 tests/python/relay/test_py_converter.py  | 65 ++++++++++++++++++++++++-
 3 files changed, 129 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relay/testing/py_converter.py 
b/python/tvm/relay/testing/py_converter.py
index 1ec85faea6..44489aa9cf 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -38,7 +38,7 @@ OUTPUT_VAR_NAME = "_py_out"
 #     import tvm
 #     from tvm import relay
 #     from tvm import nd
-#     from tvm.runtime import import container as _container
+#     from tvm.runtime import container as _container
 #     from tvm.relay.backend.interpreter import RefValue, ConstructorValue
 PROLOGUE = [
     ast.Import([alias("numpy", None)]),
@@ -60,7 +60,7 @@ class PythonConverter(ExprFunctor):
     def __init__(self, mod, target) -> None:
         super().__init__()
         self.mod = mod
-        self.tgt = target
+        self.tgt = target if isinstance(target, tvm.target.Target) else 
tvm.target.Target(target)
         self.tec = te_compiler.get()
         self.fun_no = 0
         self.var_no = 0
@@ -98,15 +98,31 @@ class PythonConverter(ExprFunctor):
         # unwrap tuple wrappers (some op calls produce them)
         unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) 
else prog
         assert relay.analysis.well_formed(unwrapped)
-        mod = self.mod.from_expr(unwrapped, self.mod.functions, 
self.mod.type_definitions)
+        # For a lone global var, there is nothing we need to do
+        if isinstance(unwrapped, relay.GlobalVar):
+            return unwrapped
+
+        # main might be in the mod already and from_expr will not override it 
if it's there,
+        # so we need a new name
+        target_name = self.generate_function_name("target")
+
+        wrapped = unwrapped
+        if not isinstance(unwrapped, relay.Function):
+            wrapped = relay.Function(relay.analysis.free_vars(unwrapped), 
unwrapped)
+
+        # easiest way to make a deep copy -- note that main will not be 
overridden if it's present
+        copy_mod = tvm.IRModule.from_expr(
+            relay.Tuple([]), self.mod.functions, self.mod.type_definitions
+        )
+        copy_mod[target_name] = wrapped
 
         # necessary pass: SimplifyInference (otherwise we can't generate code 
for some operators)
         # and fusion (to get primitive functions)
         opts = tvm.transform.Sequential(
             [relay.transform.SimplifyInference(), 
relay.transform.FuseOps(fuse_opt_level=0)]
         )
-        mod = opts(mod)
-        optimized = mod["main"]
+        copy_mod = opts(copy_mod)
+        optimized = copy_mod[target_name]
         return optimized if isinstance(unwrapped, Function) else optimized.body
 
     def sanitize(self, name: str) -> str:
@@ -197,7 +213,7 @@ class PythonConverter(ExprFunctor):
 
         var_names = [self.get_var_name(var) for var in func.params]
         body, defs = self.visit(func.body)
-        ret = self.create_def(func_name, var_names, defs + [Return(body)])
+        ret = self.create_def(func_name, var_names, defs + [Return(body)], 
register_packed=True)
         return (ret, func_name)
 
     def convert_module(self):
@@ -219,10 +235,25 @@ class PythonConverter(ExprFunctor):
         """Creates a simple function call."""
         return ast.Call(self.parse_name(func_name), arguments, [])
 
-    def create_def(self, func_name: str, arguments: [str], body):
-        """Wrapper over function definition AST node, whose constructor is 
inconvenient."""
+    def create_def(self, func_name: str, arguments: [str], body, 
register_packed: bool = False):
+        """
+        Wrapper over function definition AST node, whose constructor is 
inconvenient.
+
+        register_packed includes a tvm.register_func decorator on the 
generated function if true.
+        This option should be used for Relay functions (warning: clobbers 
registry!)
+        """
         inner_args = [ast.arg(argument, None) for argument in arguments]
 
+        # add a decorator to register as a PackedFunc so the function will be 
an ObjectRef
+        # and will allow for putting functions into tuples or refs
+        decorator_list = [
+            ast.Call(
+                self.parse_name("tvm.register_func"),
+                [ast.Constant(value=func_name)],
+                [ast.keyword(arg="override", value=ast.Constant(value=True))],
+            )
+        ]
+
         global __MAJOR__, __MINOR__
         if __MAJOR__ == 3 and __MINOR__ >= 8:
             arguments = ast.arguments([], inner_args, None, [], [], None, [])
@@ -233,10 +264,19 @@ class PythonConverter(ExprFunctor):
             func_name,
             arguments,
             body,
-            [],
+            decorator_list if register_packed else [],
             None,
         )
 
+    def create_tuple(self, fields):
+        """
+        Given the ASTs for tuple fields, produce an AST that creates a
+        tuple value with those fields
+        """
+        # Use the FFI API directly so that PackedFuncs will be correctly 
converted to ObjectRef.
+        # Using tvm.runtime.container.tuple_object fails to convert 
PackedFuncs in Python
+        return self.create_call("_container._ffi_api.Tuple", fields)
+
     def create_op_call(self, op: Function, relay_args, py_args):
         """Lowers the passed primitive function, registers it in TVM's
         global compiler, and produces a call to the lowered function in
@@ -290,8 +330,7 @@ class PythonConverter(ExprFunctor):
                 assignments += inner_assignments
                 extra_args += inner_args
                 fields.append(inner_output)
-            fields = [ast.List(fields, Load())]
-            return (assignments, extra_args, 
self.create_call("_container.tuple_object", fields))
+            return (assignments, extra_args, self.create_tuple(fields))
 
         # create a function to wrap the call of the lowered op and return
         # a call to that function
@@ -418,7 +457,9 @@ class PythonConverter(ExprFunctor):
     def visit_global_var(self, gvar: Expr):
         # we don't need to add numbers to global var names because
         # the *names* are checked for uniqueness in the mod
-        return (Name(str(gvar.name_hint), Load()), [])
+        func_name = str(gvar.name_hint)
+        # load in the packed func
+        return (self.create_call("tvm.get_global_func", 
[ast.Constant(value=func_name)]), [])
 
     def visit_let(self, letexp: Expr):
         # To properly account for scoping and ensure that the entire node 
produces an expression,
@@ -456,8 +497,7 @@ class PythonConverter(ExprFunctor):
 
     def visit_tuple(self, tup: Expr):
         fields, ret_defs = self.convert_fields(tup.fields)
-        fields = [ast.List(fields, Load())]
-        return (self.create_call("_container.tuple_object", fields), ret_defs)
+        return (self.create_tuple(fields), ret_defs)
 
     def visit_tuple_getitem(self, tgi: Expr):
         tup, tup_defs = self.visit(tgi.tuple_value)
@@ -471,7 +511,7 @@ class PythonConverter(ExprFunctor):
 
         # need to get the value out of a NDArray to check the condition
         # equvialent to: val.numpy()
-        cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], 
[])
+        cond_check = ast.Call(ast.Attribute(cond_body, "numpy", Load()), [], 
[])
         ret = ast.IfExp(cond_check, true_body, false_body)
         return (ret, cond_defs + true_defs + false_defs)
 
@@ -490,7 +530,11 @@ class PythonConverter(ExprFunctor):
     def visit_function(self, func: Expr):
         # Python's lambdas are very restrictive, so we do "name" inline 
functions
         converted_func, func_name = self.convert_func_node(func)
-        return (Name(func_name, Load()), [converted_func])
+        # load in the PackedFunc
+        return (
+            self.create_call("tvm.get_global_func", 
[ast.Constant(value=func_name)]),
+            [converted_func],
+        )
 
     def visit_call(self, call: Expr):
         """For calls, we must distinguish between ordinary functions,
@@ -546,7 +590,7 @@ class PythonConverter(ExprFunctor):
             + val_defs
             + [
                 Assign([ast.Attribute(ref, "value", Store())], val),
-                Return(self.create_call("_container.tuple_object", [])),
+                Return(self.create_tuple([])),
             ],
         )
         return (self.create_call(thunk_name, []), [thunk])
@@ -602,7 +646,10 @@ def to_python(expr: Expr, mod=None, 
target=tvm.target.Target("llvm")):
 
 def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
     """Converts the given Relay expression into a Python script and
-    executes it."""
+    executes it.
+
+    Note that closures will be returned as PackedFuncs
+    """
     mod = mod if mod is not None else tvm.IRModule()
     py_ast = to_python(expr, mod, target)
     code = compile(py_ast, "<string>", "exec")
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index adcaecbc64..7b5105a3fc 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -202,5 +202,6 @@ 
TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple sh
   ICHECK_LT(idx, shape.size());
   return shape[idx];
 });
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/relay/test_py_converter.py 
b/tests/python/relay/test_py_converter.py
index bd5635e8cf..d43ec5861b 100644
--- a/tests/python/relay/test_py_converter.py
+++ b/tests/python/relay/test_py_converter.py
@@ -18,7 +18,7 @@ import numpy as np
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.testing import to_python, run_as_python
+from tvm.relay.testing import run_as_python
 from tvm.relay.prelude import Prelude
 from tvm.runtime.container import ADT
 from tvm.relay.backend.interpreter import RefValue, ConstructorValue
@@ -70,7 +70,6 @@ def test_create_empty_tuple():
 def test_create_scalar():
     scalar = relay.const(1)
     tensor_val = run_as_python(scalar)
-    print(type(tensor_val))
     assert_tensor_value(tensor_val, 1)
 
 
@@ -611,3 +610,65 @@ def test_batch_norm():
     verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)])
     verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)])
     verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)])
+
+
+def test_return_global_var():
+    tt = relay.TensorType([1], "float32")
+    x = relay.Var("x", type_annotation=tt)
+    identity = relay.Function([x], x, ret_type=tt)
+    mod = tvm.IRModule()
+    mod["main"] = identity
+    main_var = mod.get_global_var("main")
+    main_func = run_as_python(main_var, mod=mod)
+
+    arg = tvm.nd.array(np.array([0.0], dtype="float32"))
+    res = main_func(arg)
+    assert arg.numpy() == res.numpy()
+
+
+def test_closure_in_tuple():
+    tt = relay.TensorType([1], "float32")
+    x = relay.Var("x", type_annotation=tt)
+    identity = relay.Function([x], x, ret_type=tt)
+    tup = relay.Tuple([identity, identity])
+    index = relay.TupleGetItem(tup, 0)
+
+    func = run_as_python(index)
+    arg = tvm.nd.array(np.array([0.0], dtype="float32"))
+    res = func(arg)
+    assert arg.numpy() == res.numpy()
+
+
+def test_closure_in_ref():
+    tt = relay.TensorType([1], "float32")
+    x = relay.Var("x", type_annotation=tt)
+    identity = relay.Function([x], x, ret_type=tt)
+    gv = relay.GlobalVar("id")
+
+    r = relay.Var("r")
+    seq = relay.Let(
+        r,
+        relay.RefCreate(gv),
+        relay.Call(relay.RefRead(r), [relay.const(np.array([0.0], 
dtype="float32"))]),
+    )
+
+    mod = tvm.IRModule()
+    mod[gv] = identity
+    res = run_as_python(seq, mod=mod)
+    assert res.numpy() == np.array([0.0], dtype="float32")
+
+
+def test_compiling_with_main():
+    unit_type = relay.TupleType([])
+    unit = relay.Function([], relay.Tuple([]), ret_type=unit_type)
+
+    x = relay.Var("x", type_annotation=unit_type)
+    identity = relay.Function([x], x, ret_type=unit_type)
+
+    mod = tvm.IRModule()
+    mod["unit"] = unit
+    mod["main"] = identity
+
+    res = 
run_as_python(mod.get_global_var("main")(mod.get_global_var("unit")()), mod=mod)
+    assert isinstance(res, ADT)
+    assert len(res) == 0

Reply via email to