This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit da893fe3b215b79484c785714f8083b94da4dcb6
Author: Yuchen Jin <[email protected]>
AuthorDate: Fri Feb 17 11:37:12 2023 -0800

    [Unity] Relay -> Relax translator  (#14026)
    
    This PR implements a Relay to Relax translator, which allows us to import
    Relay workloads to Relax for benchmarking and development purposes (tests and
    examples are added).
---
 apps/relax_examples/e2e_auto_tir.py          | 253 ++++++++++++++++++++++
 apps/relax_examples/mlp.py                   |  57 +++++
 apps/relax_examples/nn_module.py             |  69 ++++++
 apps/relax_examples/resnet.py                |  53 +++++
 python/tvm/relax/testing/__init__.py         |   1 +
 python/tvm/relax/testing/relay_translator.py | 251 ++++++++++++++++++++++
 python/tvm/relax/testing/transform.py        | 125 +++++++++++
 src/relay/backend/utils.cc                   |   7 +
 tests/python/relax/test_relay_translator.py  | 300 +++++++++++++++++++++++++++
 9 files changed, 1116 insertions(+)

diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py
new file mode 100644
index 0000000000..92cda16f79
--- /dev/null
+++ b/apps/relax_examples/e2e_auto_tir.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import datetime
+import os
+import csv
+import json
+import argparse
+import logging
+from typing import Dict
+import numpy as np  # type: ignore
+
+import tvm
+from tvm import relay, relax, runtime, transform
+from tvm.ir.module import IRModule
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.relax.testing import relay_translator
+from tvm.target.target import Target
+
+
+def _parse_args():
+    """Parse and post-process the command-line arguments of this benchmark.
+
+    Besides plain parsing, this converts ``--target`` to a ``Target``, decodes
+    ``--input-shape`` from JSON, chooses the runner ``alloc_repeat`` and, when
+    ``--rpc-host``/``--rpc-port``/``--rpc-key`` are all set, builds an
+    ``RPCConfig`` and counts the RPC servers available for it.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--workload",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--input-shape",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--target",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+    )
+    parser.add_argument(
+        "--rpc-host",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--rpc-port",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--rpc-key",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--rpc-timeout-sec",
+        type=int,
+        default=180,
+    )
+    parser.add_argument("--num-measurement-repeats", type=int, default=5)
+    parser.add_argument("--num-measurements", type=int, default=10)
+    parser.add_argument("--results-file", type=str, required=False, default=None)
+    parsed = parser.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    parsed.input_shape = json.loads(parsed.input_shape)
+    # alloc_repeat is forwarded to the MetaSchedule runner (see get_runner);
+    # aarch64-linux-gnu targets use 3, everything else 1.
+    if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
+        parsed.alloc_repeat = 3
+    else:
+        parsed.alloc_repeat = 1
+    if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key:
+        parsed.rpc_config = ms.runner.RPCConfig(
+            tracker_host=parsed.rpc_host,
+            tracker_port=parsed.rpc_port,
+            tracker_key=parsed.rpc_key,
+            session_timeout_sec=parsed.rpc_timeout_sec,
+        )
+        parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False)
+    else:
+        # RPC settings are all-or-nothing: a partial set is a usage error.
+        # parser.error (unlike assert) is not stripped under `python -O`.
+        if not (parsed.rpc_host is None and parsed.rpc_port is None and parsed.rpc_key is None):
+            parser.error("Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use RPC server")
+        parsed.rpc_config = None
+        parsed.workers = 1
+    return parsed
+
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+ARGS = _parse_args()
+
+
+def apply_opt_before_tuning(
+    relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target
+) -> IRModule:
+    """Simplify a Relay module, translate it to Relax and fuse its operators.
+
+    ``params`` are bound into ``relay_mod["main"]`` as constants, a fixed list of
+    Relay simplification passes is applied, the result is translated to Relax for
+    ``target``, and the Relax module is fused with AnnotateTIROpPattern, FuseOps
+    and FuseTIR. All of this runs inside an ``opt_level=3`` PassContext.
+    """
+    with transform.PassContext(opt_level=3):
+        main_func = relay_mod["main"]
+        bind_main_func = relay.build_module.bind_params_by_name(main_func, params)
+        relay_mod = IRModule.from_expr(bind_main_func)
+        relay_mod = relay.transform.SimplifyInference()(relay_mod)
+        relay_mod = relay.transform.FoldConstant()(relay_mod)
+        relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
+        relay_mod = relay.transform.CanonicalizeOps()(relay_mod)
+        relay_mod = relay.transform.AlterOpLayout()(relay_mod)
+        # Fold again, presumably to fold weight layout transforms that
+        # AlterOpLayout inserted.
+        relay_mod = relay.transform.FoldConstant()(relay_mod)
+
+        relax_mod = relay_translator.from_relay(relay_mod["main"], target=target)
+        relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
+        relax_mod = relax.transform.FuseOps()(relax_mod)
+        relax_mod = relax.transform.FuseTIR()(relax_mod)
+    return relax_mod
+
+
+def f_measurement(
+    rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
+):
+    """Time the ``main`` function of a compiled Relax VM executable.
+
+    ``main`` is saved as ``measure_func`` with ``input_data`` bound and its
+    return value dropped (``include_return=False``) before timing. The
+    repeat/number counts come from ``ARGS``. Returns the time evaluator's result.
+    """
+    vm = relax.vm.VirtualMachine(exec=rt_mod, device=device)
+    vm.save_function("main", "measure_func", **input_data, include_return=False)
+    evaluator = vm.time_evaluator(
+        func_name="measure_func",
+        dev=device,
+        repeat=ARGS.num_measurement_repeats,
+        number=ARGS.num_measurements,
+        min_repeat_ms=500,
+    )
+    return evaluator()
+
+
+def get_runner():
+    """Create the MetaSchedule runner used during tuning.
+
+    Returns an ``RPCRunner`` with ``ARGS.workers`` workers when an RPC
+    configuration was given on the command line, otherwise a ``LocalRunner``.
+    """
+    runner_config = {
+        "evaluator_config": ms.runner.EvaluatorConfig(
+            number=3,
+            repeat=1,
+            min_repeat_ms=100,
+            enable_cpu_cache_flush=False,
+        ),
+        "alloc_repeat": ARGS.alloc_repeat,
+    }
+    if ARGS.rpc_config:
+        runner = ms.runner.RPCRunner(
+            rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config
+        )
+    else:
+        runner = ms.runner.LocalRunner(**runner_config)
+
+    return runner
+
+
+def main():
+    """Tune, compile and benchmark ``ARGS.workload`` through the Relax pipeline.
+
+    The Relay workload returned by ``get_network`` is optimized and translated
+    to Relax, tuned with MetaSchedule, compiled, and timed either locally or over
+    RPC. When ``--results-file`` is given, the run parameters followed by the
+    entries of ``result.results`` are written to that file as CSV rows.
+    """
+    relay_mod, params, (input_name, input_shape, input_dtype) = get_network(
+        ARGS.workload,
+        ARGS.input_shape,
+        cache_dir=ARGS.cache_dir,
+    )
+    input_info = {input_name: input_shape}
+    input_data = {}
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
+
+    # translate the Relay workload to Relax (Relay-level optimization + Relax fusion)
+    relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target)
+    assert isinstance(relax_mod, tvm.IRModule)
+
+    db = ms.relax_integration.tune_relax(
+        mod=relax_mod,
+        target=ARGS.target,
+        params=params,
+        num_trials_per_iter=64,
+        max_trials_per_task=ARGS.num_trials,
+        max_trials_global=ARGS.num_trials,
+        runner=get_runner(),
+        work_dir=ARGS.work_dir,
+    )
+    executable = ms.relax_integration.compile_relax(
+        db,
+        mod=relax_mod,
+        target=ARGS.target,
+        params=params,
+    )
+
+    # random inputs: uniform [0, 1) for float dtypes, integers in [0, 10000) otherwise
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
+        else:
+            input_data[input_name] = np.random.randint(
+                low=0, high=10000, size=input_shape, dtype=input_dtype
+            )
+
+    # for documentation purposes
+    start_time = datetime.datetime.now()
+
+    if ARGS.rpc_config:
+        result = run_module_via_rpc(
+            rpc_config=ARGS.rpc_config,
+            lib=executable.mod,
+            dev_type=ARGS.target.kind.name,
+            args=input_data,
+            continuation=f_measurement,
+        )
+    else:
+        dev = tvm.device(ARGS.target.kind.name)
+        result = f_measurement(executable.mod, dev, input_data)
+
+    print(result)
+
+    if not ARGS.results_file:
+        return
+
+    out_path = os.path.abspath(os.path.expanduser(ARGS.results_file))
+    # The csv module requires newline="" so that it controls row terminators itself.
+    with open(out_path, "w", newline="") as out_file:
+        writer = csv.writer(out_file)
+        # write experiment parameters at the top as a record
+        writer.writerow(["start", str(start_time)])
+        writer.writerow(["workload", ARGS.workload])
+        writer.writerow(["input_shape", ARGS.input_shape])
+        writer.writerow(["target", ARGS.target])
+        writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats])
+        for res in result.results:
+            writer.writerow([str(res)])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py
new file mode 100644
index 0000000000..02e17dc304
--- /dev/null
+++ b/apps/relax_examples/mlp.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Example code on creating, compiling, and running an MLP model in relax
+
+
+import tvm
+from tvm import relax, tir, topi
+import numpy as np
+
+
+def build_mlp(data, weight):
+    """Build an IRModule whose Relax function ``mlp`` computes relu(data @ weight).
+
+    ``data`` and ``weight`` are the Relax parameter vars of ``mlp``. The matmul
+    is emitted through ``tvm.contrib.cblas.matmul`` (presumably requires a TVM
+    build with BLAS support -- confirm), and ReLU through ``topi.nn.relu``.
+    """
+    # Import the submodule explicitly: ``import tvm`` alone does not guarantee
+    # that ``tvm.contrib.cblas`` has been loaded.
+    from tvm.contrib import cblas  # pylint: disable=import-outside-toplevel
+
+    bb = relax.BlockBuilder()
+
+    with bb.function("mlp", [data, weight]):
+        gv0 = bb.emit_te(cblas.matmul, data, weight, transa=False, transb=False)
+        gv1 = bb.emit_te(topi.nn.relu, gv0)
+        bb.emit_func_output(gv1)
+
+    mod = bb.get()
+    return mod
+
+
+if __name__ == "__main__":
+    # Two symbolic int64 dims: data is (n, m) and weight is (m, n).
+    n = tir.Var("n", "int64")
+    m = tir.Var("m", "int64")
+    data = relax.Var("data", relax.TensorStructInfo([n, m], "float32"))
+    weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32"))
+
+    # Construct the MLP, compile it for LLVM and load it into a Relax VM.
+    mlp_mod = build_mlp(data, weight)
+    target = tvm.target.Target("llvm", host="llvm")
+    executable = relax.vm.build(mlp_mod, target)
+    vm = relax.VirtualMachine(executable, tvm.cpu())
+
+    # Bind concrete sizes n=16, m=32 and run.
+    inputs = [
+        tvm.nd.array(np.random.rand(16, 32).astype(np.float32)),
+        tvm.nd.array(np.random.rand(32, 16).astype(np.float32)),
+    ]
+    print(vm["mlp"](*inputs))
diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py
new file mode 100644
index 0000000000..b57cb00685
--- /dev/null
+++ b/apps/relax_examples/nn_module.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Example code on creating, compiling, and running a neural network with a PyTorch-like API
+
+
+import tvm
+from tvm.relay import Call
+from tvm import relax, tir
+from tvm.relax.testing import nn
+from tvm.script import relax as R
+import numpy as np
+
+
+if __name__ == "__main__":
+    builder = relax.BlockBuilder()
+
+    # the minibatch size stays symbolic
+    n = tir.Var("n", "int64")
+    input_size = 784
+    hidden_sizes = [128, 32]
+    output_size = 10
+
+    # Three linear layers (784 -> 128 -> 32 -> 10) with ReLU between them and
+    # a final LogSoftmax, for a classification task.
+    with builder.function("main"):
+        layer_sizes = [input_size, *hidden_sizes, output_size]
+        layers = []
+        for idx, (in_features, out_features) in enumerate(zip(layer_sizes, layer_sizes[1:])):
+            if idx > 0:
+                layers.append(nn.ReLU())
+            layers.append(nn.Linear(in_features, out_features))
+        layers.append(nn.LogSoftmax())
+        model = nn.Sequential(*layers)
+
+        data = nn.Placeholder((n, input_size), name="data")
+        output = model(data)
+        builder.emit_func_output(output, params=[data, *model.parameters()])
+
+    # inspect the IRModule that was built
+    mod = builder.get()
+    mod.show()
+
+    # compile for LLVM and load into the Relax VM
+    target = tvm.target.Target("llvm", host="llvm")
+    vm = relax.VirtualMachine(relax.vm.build(mod, target), tvm.cpu())
+
+    # randomly initialize the weights, then run a minibatch of 3 samples
+    params = nn.init_params(mod)
+    data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
+    print(vm["main"](data, *params))
diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py
new file mode 100644
index 0000000000..df0cab02f1
--- /dev/null
+++ b/apps/relax_examples/resnet.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example ResNet workload by translating the Relay program to Relax"""
+
+import tvm
+import tvm.testing
+from tvm.relay import testing
+from tvm import relax, relay
+from tvm.relax.testing import relay_translator, nn
+from tvm.runtime import vm as vm_rt
+from tvm.script import relax as R
+import numpy as np
+
+if __name__ == "__main__":
+    relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")
+
+    # translate the ResNet model from Relay to Relax
+    target = tvm.target.Target("llvm", host="llvm")
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target)
+
+    # print the translated ResNet IRModule
+    relax_mod.show()
+
+    # build the IRModule and create relax vm
+    ex = relax.vm.build(relax_mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    # init weights randomly and run the model on relax vm; the Relay params from
+    # get_workload are discarded and the same random weights are fed to Relay below
+    shape = (1, 3, 224, 224)
+    data = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
+    params = nn.init_params(relax_mod)
+    res = vm["main"](data, *params)
+
+    # check correctness by comparing with relay result
+    exe = relay.vm.compile(relay_mod, target)
+    relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu())
+    inputs = [data] + params
+    expected_output = relay_vm.run(*inputs)
+    tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4)
diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py
index ab1dd6f515..7344798f70 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -18,3 +18,4 @@
 """The Relax testing namespace containing nn and translator."""
 
 from .nn import *
+from .relay_translator import *
diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py
new file mode 100644
index 0000000000..fd5aab89fa
--- /dev/null
+++ b/python/tvm/relax/testing/relay_translator.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, invalid-name, no-else-return, too-many-nested-blocks
+"""Relay to Relax translator."""
+
+from typing import Any, Dict, List, Optional
+
+import tvm
+from tvm import relax, relay
+from tvm.ir.module import IRModule
+from tvm.relax.testing import nn
+from tvm.relay.backend.te_compiler import select_implementation
+from tvm.runtime import NDArray
+from tvm.target import Target
+from tvm.meta_schedule.relay_integration import _autotvm_silencer
+
+
+def from_relay(
+    func: relay.Function,
+    target: Target,
+    relay_params: Optional[Dict[str, NDArray]] = None,
+    *,
+    opt_level: int = 3,
+    pass_config: Optional[Dict[str, Any]] = None,
+    disabled_pass: Optional[List[str]] = None,
+    translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
+) -> IRModule:
+    """Convert a Relay function into a Relax program.
+
+    The Relay function is first run through the Relay backend pass prefix
+    (``relay.backend.GetPassPrefixSeq``). Then every operator call is lowered
+    via Relay OpStrategy (or to a user-supplied PrimFunc) and emitted into a
+    single dataflow block of the Relax function ``main``.
+
+    Parameters
+    ----------
+    func : relay.Function
+        Relay function to be converted.
+
+    target: Target
+        The target to compile the model, used for selecting topi functions.
+
+    relay_params: Optional[Dict[str, NDArray]]
+        Parameters to bind.
+
+    opt_level: int
+        The optimization level.
+
+    pass_config: Optional[Dict[str, Any]]
+        Pass configuration. When omitted, Relay fusion is disabled
+        (``relay.FuseOps.max_depth = 1``) and MetaSchedule dispatch is enabled.
+
+    disabled_pass: Optional[List[str]]
+        Passes to disable.
+
+    translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]]
+        Dict that maps op names to user-defined PrimFuncs.
+        Takes relay operator names and forces them to user-defined PrimFuncs during translation.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The Relax IRModule for compilation
+
+    Raises
+    ------
+    TypeError
+        If a Relay parameter is not of TensorType, a call argument is neither a
+        tensor nor a tuple of tensors, or an unsupported Relay node (e.g. Let,
+        If) is encountered.
+    RuntimeError
+        If a tuple field or tuple operand has not been translated yet.
+    """
+    # Maps each visited Relay expr to its Relax counterpart (var, constant, ...).
+    var_map = {}
+
+    if not isinstance(target, Target):
+        target = Target(target)
+    if disabled_pass is None:
+        disabled_pass = []
+    if pass_config is None:
+        pass_config = {
+            "relay.FuseOps.max_depth": 1,  # Disable relay fusion
+            "relay.backend.use_meta_schedule": True,
+            "relay.backend.use_meta_schedule_dispatch": True,
+        }
+
+    if relay_params:
+        func = relay.build_module.bind_params_by_name(func, relay_params)
+
+    params = []
+    tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr] = dict()
+
+    def convert_shape(shape: List[tvm.tir.PrimExpr]) -> List[tvm.tir.PrimExpr]:
+        """Convert the relay shape to relax shape by changing Any dim to symbolic dim"""
+        ret = []
+        for dim in shape:
+            if isinstance(dim, tvm.tir.IntImm):
+                ret.append(tvm.tir.IntImm("int64", int(dim)))
+            elif isinstance(dim, tvm.tir.Any):
+                ret.append(tvm.tir.Var("d", "int64"))
+            else:
+                ret.append(dim)
+        return ret
+
+    def _copy_undefined_var_in_shape(sinfo: relax.TensorStructInfo):
+        """Map every TIR var in ``sinfo``'s shape not yet in ``tir_var_map`` to a fresh var."""
+
+        def _visit_expr(e: tvm.tir.PrimExpr):
+            if isinstance(e, tvm.tir.Var) and e not in tir_var_map:
+                new_var = tvm.tir.Var(e.name, e.dtype)
+                tir_var_map[e] = new_var
+
+        assert isinstance(
+            sinfo.shape, relax.ShapeExpr
+        ), "arg with TensorStructInfo in Relay translator must have ShapeExpr shape"
+        for shape_value in sinfo.shape.values:
+            tvm.tir.stmt_functor.post_order_visit(shape_value, _visit_expr)
+
+    def visit_func(node):
+        """Translate one Relay node; children are always visited first (post-order)."""
+        if isinstance(node, relay.Var):
+            if isinstance(node.type_annotation, relay.TensorType):
+                var_map[node] = nn.Placeholder(
+                    tuple(convert_shape(node.type_annotation.shape)),
+                    node.type_annotation.dtype,
+                    node.name_hint,
+                )
+                params.append(var_map[node])
+            else:
+                raise TypeError("The type of relay.Var to be translated must be of TensorType.")
+        elif isinstance(node, relay.Call):
+            args = node.args
+            new_args = []
+            te_inputs = []
+            for arg in args:
+                if arg in var_map:
+                    arg_expr = var_map[arg]
+                    if isinstance(arg_expr.struct_info, relax.TensorStructInfo):
+                        _copy_undefined_var_in_shape(arg_expr.struct_info)
+                        new_args.append(arg_expr)
+                        te_inputs.append(tvm.relax.expr.te_tensor(arg_expr, tir_var_map))
+                    elif isinstance(arg_expr.struct_info, relax.TupleStructInfo):
+                        # Flatten a tuple argument into its tensor fields, reusing the
+                        # fields of the bound relax.Tuple when that binding is visible.
+                        n_tensor = len(arg_expr.struct_info.fields)
+                        bound_tuple = bb.lookup_binding(arg_expr)
+                        if isinstance(bound_tuple, relax.Tuple):
+                            assert len(bound_tuple) == n_tensor
+                        for i in range(n_tensor):
+                            if isinstance(bound_tuple, relax.Tuple):
+                                item = bb.emit(bound_tuple[i])
+                            else:
+                                item = bb.emit(relax.TupleGetItem(arg_expr, i))
+
+                            assert isinstance(item.struct_info, relax.TensorStructInfo), (
+                                "Relay translator doesn't support Call "
+                                "argument being nested Tensor tuple."
+                            )
+                            _copy_undefined_var_in_shape(item.struct_info)
+                            new_args.append(item)
+                            te_inputs.append(tvm.relax.expr.te_tensor(item, tir_var_map))
+                    else:
+                        raise TypeError(
+                            f"CallTIR argument type being {type(arg_expr.struct_info)} is not "
+                            "supported."
+                        )
+
+            op_name = node.op.name
+            attrs = node.attrs
+            out_type = node.checked_type
+
+            if translate_op_with_tir and op_name in translate_op_with_tir:
+                # A user-supplied PrimFunc replaces the OpStrategy implementation.
+                tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name)
+                call = relax.call_tir(
+                    tir_gvar, new_args, relax.TensorStructInfo(out_type.shape, out_type.dtype)
+                )
+                var = bb.emit(call)
+            else:
+                with target:
+                    best_impl, _ = select_implementation(
+                        node.op,
+                        attrs,
+                        te_inputs,
+                        out_type,
+                        target,
+                        use_autotvm=False,
+                    )
+                    compute_func = best_impl.compute
+                    name_hint = op_name.split(".")[-1]
+                    var = bb.emit_te(
+                        compute_func,
+                        attrs,
+                        new_args,
+                        node.checked_type,
+                        primfunc_name_hint=name_hint,
+                    )
+
+            var_map[node] = var
+        elif isinstance(node, relay.Constant):
+            # fill the shape and checked_type fields of the Constant
+            new_constant = relax.Constant(node.data)
+            var_map[node] = new_constant
+        elif isinstance(node, relay.Tuple):
+            new_fields = []
+            for field in node.fields:
+                if field in var_map:
+                    new_fields.append(var_map[field])
+                else:
+                    raise RuntimeError("field is not in var_map.")
+            var_map[node] = bb.emit(relax.Tuple(new_fields))
+        elif isinstance(node, relay.TupleGetItem):
+            if node.tuple_value in var_map:
+                new_tuple = var_map[node.tuple_value]
+                var_map[node] = bb.emit(relax.TupleGetItem(new_tuple, node.index))
+            else:
+                raise RuntimeError("tuple is not in var_map")
+        elif isinstance(node, relay.Function):
+            # The body was translated before the function (post-order), so look it
+            # up directly. This also covers bodies that are a bare parameter or a
+            # constant, for which no binding was emitted.
+            gv = bb.emit_output(var_map[node.body])
+            df_block = bb._end_block()
+            bb._blocks.append(df_block)
+            bb.emit_func_output(gv, params)
+        elif isinstance(node, tvm.ir.Op):
+            pass
+        else:
+            raise TypeError("{} is not supported yet.".format(str(type(node))))
+
+    # List of subset of relay->relay optimizations
+    # See src/relay/backend/utils.cc::GetPassPrefix() for full list
+    seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True)
+
+    # Since optimization passes and OpStrategy are highly context-dependent,
+    # we match the exact same context with `extract_task_from_relay()` env
+    with _autotvm_silencer(), tvm.transform.PassContext(
+        opt_level=opt_level,
+        config=pass_config,
+        disabled_pass=disabled_pass,
+    ):
+        mod = tvm.IRModule.from_expr(func)
+        mod = seq(mod)
+        bb = relax.BlockBuilder()
+        with bb.function("main"):
+            bb._begin_dataflow_block()
+            relay.analysis.post_order_visit(mod["main"], visit_func)
+
+    return bb.get()
diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py
new file mode 100644
index 0000000000..c8ca618d4c
--- /dev/null
+++ b/python/tvm/relax/testing/transform.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ
+"""Relax transformation passes for testing"""
+
+from tvm import ir
+from tvm import relax
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.target import Target
+from tvm.ir import transform
+from tvm.relax import PyExprMutator
+from tvm.relax.expr import Call
+from tvm.relay.backend.te_compiler import select_implementation
+
+
+@ir.transform.module_pass(opt_level=0)
+class LowerWithRelayOpStrategyPass(transform.Pass):
+    """Lower Relax Op into TIR by using Relay OpStrategy.
+
+    Since operators like conv2d, add, matmul are relay-, relax- independent,
+    this pass assumes we can always find relay op equivalent for such relax ops,
+    and use Relay Op Strategy (legacy) to perform lowering and find the TOPI implementation.
+
+    Parameters
+    ----------
+    target : Target
+        target info
+
+    Returns
+    -------
+    pass : transform.Pass
+        lowering pass
+    """
+
+    def __init__(self, target: Target):
+        self.target = target
+
+    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
+        """Implement lowering mechanism.
+
+        Each call to an operator ``relax.<name>`` for which ``<name>`` is a
+        registered Relay operator is replaced by a ``call_te`` of the compute
+        selected by Relay OpStrategy for the target; other calls are kept as is.
+
+        Parameters
+        ----------
+        mod : IRModule
+            Input IRModule with Relax ops
+
+        ctx: PassContext
+            Pass context
+
+        Returns
+        -------
+        out_mod : IRModule
+            Output IRModule with lowered TIR functions
+        """
+        # Accept target strings as well as Target objects.
+        target = self.target if isinstance(self.target, Target) else Target(self.target)
+        # The op registry does not change while lowering; query it once per module.
+        relay_op_names = set(ir.Op.list_op_names())
+        relax_prefix = "relax."
+
+        @relax.expr_functor.mutator
+        class Lowerer(PyExprMutator):
+            """Mutator that performs lowering."""
+
+            def visit_call_(self, call_node: Call):
+                # Only operator calls are lowered; calls to functions (GlobalVar,
+                # ExternFunc, function-valued vars) are returned unchanged.
+                if not isinstance(call_node.op, ir.Op):
+                    return call_node
+
+                # Current relax op name simply adds "relax." prefix to relay op name.
+                # Thus, remove "relax." prefix to deduce relay op name.
+                op_name = call_node.op.name
+                if not op_name.startswith(relax_prefix):
+                    return call_node
+                relay_op_name = op_name[len(relax_prefix) :]
+                # Check if equivalent relay op exists. If not, return the original call.
+                if relay_op_name not in relay_op_names:
+                    return call_node
+                relay_op = ir.Op.get(relay_op_name)
+
+                # Todo(relax-team): to be revisited - support dyn shape or deprecate.
+                tir_var_map = dict()
+                te_inputs = [relax.expr.te_tensor(arg, tir_var_map) for arg in call_node.args]
+                # OpStrategy/TOPI computes may consult Target.current(), so both
+                # implementation selection and emission run under the target scope.
+                with target:
+                    best_impl, _ = select_implementation(
+                        relay_op,
+                        call_node.attrs,
+                        te_inputs,
+                        call_node.checked_type,
+                        target,
+                        use_autotvm=False,
+                    )
+                    # Extract the name of the operator without the prefix
+                    # e.g., for relay op "nn.conv2d", name_hint would be conv2d
+                    name_hint = relay_op_name.split(".")[-1]
+
+                    # Relay computes take (attrs, inputs, out_type).
+                    return self.builder_.call_te(
+                        best_impl.compute,
+                        call_node.attrs,
+                        call_node.args,
+                        call_node.checked_type,
+                        primfunc_name_hint=name_hint,
+                    )
+
+            # TODO(@team): the transform() wrapper is necessary to include TIR functions.
+            # IMO, this is a bit unintuitive. Can we improve this?
+            def transform(self):
+                """Visit each Relax function of ``mod``, update it in the builder, return the result."""
+                for gv, func in mod.functions.items():
+                    if isinstance(func, relax.Function):
+                        updated_func = self.visit_expr(func)
+                        self.builder_.update_func(gv, updated_func)
+                new_mod = self.builder_.get()
+                # carry over the input module's attributes, if any
+                new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod
+                return new_mod
+
+        return Lowerer().transform()
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 4ff8a59b34..3fb1c89c28 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -443,6 +443,13 @@ 
TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern")
       return DefaultTIRConverterImpl(args, constants, true);
     });
 
+TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq")
+    .set_body_typed([](bool is_homogeneous, bool is_vm) {
+      // Package the passes returned by GetPassPrefix into a single Sequential pass so
+      // the whole prefix can be fetched through the global registry and applied at once.
+      return transform::Sequential(GetPassPrefix(is_homogeneous, is_vm));
+    });
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relax/test_relay_translator.py 
b/tests/python/relax/test_relay_translator.py
new file mode 100644
index 0000000000..5f7e05b02d
--- /dev/null
+++ b/tests/python/relax/test_relay_translator.py
@@ -0,0 +1,300 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tempfile
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import meta_schedule as ms
+from tvm import relax, relay, tir, topi
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.testing import relay_translator
+from tvm.relay import testing
+from tvm.runtime import vm
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def get_resnet(batch_size, dtype, layout, image_shape):
+    relay_mod, params = testing.resnet.get_workload(
+        num_layers=18,
+        batch_size=batch_size,
+        dtype=dtype,
+        layout=layout,
+        image_shape=image_shape,
+    )
+
+    return relay_mod, params
+
+
+def relay_build_and_run(mod, target, dev, params, data):
+    with tempfile.TemporaryDirectory() as work_dir:
+        db = ms.relay_integration.tune_relay(
+            mod=mod,
+            params=params,
+            target=target,
+            num_trials_per_iter=32,
+            max_trials_per_task=32,
+            max_trials_global=1024,
+            task_scheduler="round-robin",
+            work_dir=work_dir,
+        )
+        ex = ms.relay_integration.compile_relay(
+            db,
+            mod=mod,
+            target=target,
+            params=params,
+        )
+    rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev))
+    rt_mod.set_input("data", data)
+    rt_mod.run()
+    out = rt_mod.get_output(0).numpy()
+    return ex, rt_mod, out
+
+
+def relax_build_and_run(mod, target, dev, params, data):
+    mod = relax.transform.BindParams("main", params)(mod)
+    with tempfile.TemporaryDirectory() as work_dir:
+        db = ms.relax_integration.tune_relax(
+            mod=mod,
+            target=target,
+            task_scheduler="round-robin",
+            num_trials_per_iter=32,
+            max_trials_per_task=32,
+            max_trials_global=1024,
+            work_dir=work_dir,
+        )
+        ex = ms.relax_integration.compile_relax(
+            db,
+            mod=mod,
+            target=target,
+            params=params,
+        )
+    vm = relax.VirtualMachine(ex, dev)
+    res = vm["main"](data)
+    out = res.numpy()
+    return ex, vm, out
+
+
+def verify_e2e_translation(target_str, layout, batch_size, image_shape):
+    target = Target(target_str)
+    dev = tvm.device(str(target), dev_id=0)
+    relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
+    input_shape = (1, *image_shape)
+    data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev)
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target, params)
+    assert relax_mod["main"].attrs["global_symbol"] == "main"
+
+    _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data)
+    _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data)
+    tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5)
+
+
[email protected](reason="take too much time")
[email protected](
+    "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 
1, (224, 224, 3))]
+)
+def test_verify_e2e_translation_cpu(layout, batch_size, image_shape):
+    verify_e2e_translation("llvm --num-cores=16", layout, batch_size, 
image_shape)
+
+
[email protected](reason="take too much time")
[email protected]_gpu
[email protected](
+    "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 
1, (224, 224, 3))]
+)
+def test_verify_e2e_translation_gpu(layout, batch_size, image_shape):
+    verify_e2e_translation("cuda", layout, batch_size, image_shape)
+
+
+def verify_extracted_tasks(target_str, layout, batch_size, image_shape):
+    target = Target(target_str)
+    relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
+    relax_mod = relay_translator.from_relay(
+        relay_mod["main"],
+        target,
+        params,
+        pass_config={
+            "relay.backend.use_meta_schedule": True,
+            "relay.FuseOps.max_depth": 1,  # Disable relay fusion
+        },
+    )
+    relay_tasks = ms.relay_integration.extract_tasks(
+        relay_mod,
+        target=target,
+        params=params,
+        pass_config={
+            "relay.backend.use_meta_schedule": True,
+            "relay.FuseOps.max_depth": 1,  # Disable relay fusion
+        },
+    )
+    relax_tasks = ms.relax_integration.extract_tasks(
+        relax_mod,
+        target=target,
+        params=params,
+    )
+    # TODO (yongwww, yuchen): tophub guides relay passes, which causes 
inconsistent tasks
+    # assert len(relay_tasks) == len(relax_tasks)
+    # TODO: Can we compare extracted tasks as well?
+
+
[email protected](
+    "layout, batch_size, image_shape",
+    [
+        ("NCHW", 1, (3, 224, 224)),
+        ("NHWC", 1, (224, 224, 3)),
+    ],
+)
+def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape):
+    verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, 
image_shape)
+
+
[email protected]_gpu
[email protected](
+    "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 
1, (224, 224, 3))]
+)
+def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape):
+    verify_extracted_tasks("cuda", layout, batch_size, image_shape)
+
+
+def translate_and_build_vms(relay_mod, target_str="llvm", 
translate_op_with_tir=None):
+    target = tvm.target.Target(target_str)
+
+    # build the relay IRModule and create relay vm
+    relay_ex = relay.vm.compile(relay_mod, target)
+    relay_vm = vm.VirtualMachine(relay_ex, tvm.cpu())
+
+    # build the relax IRModule and create relax vm
+    relax_mod = relay_translator.from_relay(
+        relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir
+    )
+    relax_ex = relax.vm.build(relax_mod, target)
+    relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu())
+
+    return relay_vm, relax_vm, relax_mod
+
+
+def verify_vm_outputs(
+    input_shape,
+    relay_vm,
+    relax_vm,
+    extra_args=[],
+):
+    input = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32))
+
+    # check correctness by comparing relax and relay result
+    args = [input] + extra_args
+    relax_output = relax_vm["main"](*args)
+    relay_output = relay_vm.run(*args)
+    tvm.testing.assert_allclose(relay_output.numpy(), relax_output.numpy())
+
+
+def test_single_dynamic_dim():
+    wx, wy = 64, 128
+    # create relay module: y = data * weights + bias with dynamic batch 
dimension
+    data = relay.var("data", shape=(relay.Any(), wx))
+    weights = relay.var("weights", shape=(wx, wy))
+    bias = relay.var("bias", shape=(wy,))
+    y = relay.nn.matmul(data, weights)
+    relay_mod = tvm.IRModule.from_expr(relay.Function([data, weights, bias], y 
+ bias))
+
+    relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod)
+    weights = tvm.nd.array(np.random.rand(wx, wy).astype(np.float32))
+    bias = tvm.nd.array(np.random.rand(wy).astype(np.float32))
+    # verify for different batch sizes
+    verify_vm_outputs([10, wx], relay_vm, relax_vm, [weights, bias])
+    verify_vm_outputs([32, wx], relay_vm, relax_vm, [weights, bias])
+
+
+def test_multiple_dynamic_dims():
+    # create relay module: y = a + a, where a has shape = (?, 5, ?)
+    shape = (relay.Any(), 5, relay.Any())
+    a = relay.var("a", shape=shape)
+
+    relay_mod = tvm.IRModule.from_expr(relay.Function([a], a + a))
+    relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod)
+    # verify for different shapes
+    verify_vm_outputs([2, 5, 10], relay_vm, relax_vm)
+    verify_vm_outputs([12, 5, 24], relay_vm, relax_vm)
+
+
+def test_layout_transform():
+    shape = (1, 3, 224, 224)
+    a = relay.var("a", shape=shape)
+    b = relay.layout_transform(a, "NCHW", "NHWC")
+    relay_mod = tvm.IRModule.from_expr(relay.Function([a], b))
+
+    relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod)
+    verify_vm_outputs([1, 3, 224, 224], relay_vm, relax_vm)
+
+
+def test_translate_op_with_tir():
+    """Check that ``translate_op_with_tir`` overrides how a Relay op is translated.
+
+    The Relay module computes ``a * a``. The override maps the key ``"multiply"``
+    to the hand-written 512x512 matmul PrimFunc below, and the test asserts that
+    ``relax_mod["multiply"]`` is structurally equal to that PrimFunc. The VMs are
+    built but never run, so the matmul-vs-multiply mismatch does not affect the
+    result (presumably intentional: only the op -> PrimFunc mapping is under test).
+    """
+
+    @T.prim_func
+    def tir_matmul(
+        A: T.Buffer((512, 512), "float32"),
+        B: T.Buffer((512, 512), "float32"),
+        C: T.Buffer((512, 512), "float32"),
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "multiply", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2 in T.grid(512, 512, 512):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                T.reads(C[i, j], A[i, k], B[k, j])
+                T.writes(C[i, j])
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+    shape = (512, 512)
+    a = relay.var("a", shape=shape)
+
+    relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a))
+    # Only the translated module is inspected; the returned VMs are discarded.
+    _, _, relax_mod = translate_and_build_vms(
+        relay_mod, translate_op_with_tir={"multiply": tir_matmul}
+    )
+    assert_structural_equal(relax_mod["multiply"], tir_matmul)
+
+
+def test_translate_tuple_arg():
+    """Check the Relax module produced for ``concatenate`` over a tuple of two inputs.
+
+    The expected module is rebuilt by hand with a BlockBuilder and compared
+    structurally, so the order of the emitted bindings below must match what
+    the translator produces.
+    """
+    x = relay.var("x", shape=(10, 16))
+    y = relay.var("y", shape=(10, 16))
+    relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], 
relay.concatenate((x, y), axis=-1)))
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm")
+
+    # Construct the expected module
+    bb = relax.BlockBuilder()
+    x_relax = relax.Var("x", relax.TensorStructInfo([10, 16], "float32"))
+    y_relax = relax.Var("y", relax.TensorStructInfo([10, 16], "float32"))
+    with bb.function("main", [x_relax, y_relax]):
+        with bb.dataflow():
+            # Expected bindings: the tuple itself (unused afterwards), a re-binding of
+            # each input, then the x86 concatenate TE compute applied to those bindings.
+            _ = bb.emit(relax.Tuple((x_relax, y_relax)))
+            lv1 = bb.emit(x_relax)
+            lv2 = bb.emit(y_relax)
+            lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1)
+            gv = bb.emit_output(lv3)
+        bb.emit_func_output(gv)
+
+    assert_structural_equal(relax_mod, bb.get())
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])


Reply via email to