This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d7a6285f47 [Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#14128)
d7a6285f47 is described below

commit d7a6285f473dad912dd90183248f05b07a18e7e4
Author: Lite Ye <[email protected]>
AuthorDate: Mon Feb 27 03:33:50 2023 -0500

    [Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#14128)
    
    Add transposed matmul support for Relax CUTLASS
---
 python/tvm/contrib/cutlass/build.py          | 88 ++++++++++++++++++--------
 python/tvm/contrib/cutlass/gemm_operation.py | 11 ++--
 python/tvm/contrib/cutlass/gen_tensor_op.py  | 93 +++++++++++++++++++++-------
 python/tvm/relax/__init__.py                 |  1 +
 python/tvm/relax/backend/contrib/cutlass.py  | 30 +++++++++
 python/tvm/relax/dpl/pattern.py              | 25 ++++++++
 src/relax/ir/dataflow_matcher.cc             |  2 +
 tests/python/relax/test_codegen_cutlass.py   | 68 ++++++++++++++++++--
 8 files changed, 259 insertions(+), 59 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index c6e5adacec..954aef60c2 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -19,6 +19,7 @@
 import logging
 import multiprocessing
 import os
+from typing import Optional
 
 import tvm
 from tvm import relax, relay, runtime
@@ -522,7 +523,19 @@ def tune_cutlass_function(
     )
 
 
-def _extract_relax_function_info(f):
+def _get_call_node(expr: relax.Expr, op_name: str) -> Optional[relax.Call]:
+    node = None
+
+    def fvisit(e):
+        nonlocal node
+        if isinstance(e, relax.Call) and e.op.name == op_name:
+            node = e
+
+    relax.analysis.post_order_visit(expr, fvisit)
+    return node
+
+
+def _extract_relax_function_signature(f):
     signature = {}
 
     for i, arg in enumerate(f.params):
@@ -534,16 +547,26 @@ def _extract_relax_function_info(f):
     signature["ret_shape"] = list(ret_sinfo.shape)
     signature["ret_dtype"] = ret_sinfo.dtype
 
-    op_attrs = {}
+    return signature
 
-    def fvisit(e):
-        nonlocal op_attrs
-        if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]:
-            op_attrs = e.attrs
 
-    relax.analysis.post_order_visit(f.body, fvisit)
+def _extract_arg_idx(pattern_name, f):
+    pattern_entry = relax.backend.get_pattern(pattern_name)
+    if pattern_entry is None:
+        raise ValueError(f"Unsupported op_type {pattern_name}")
+    var2val = relax.analysis.get_var2val(f)
+    matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, 
var2val)
 
-    return signature, op_attrs
+    func_args = list(f.params)
+
+    arg_idx = {}
+    for arg_name, arg_pattern in pattern_entry.arg_patterns.items():
+        arg_expr = matched_expr[arg_pattern]
+        if arg_expr not in func_args:
+            raise ValueError(f"Cannot find arg {arg_name} in the fused 
function parameters")
+        arg_idx[arg_name] = func_args.index(arg_expr)
+
+    return arg_idx
 
 
 @relax.expr_functor.mutator
@@ -566,7 +589,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
 
     def handle_conv2d(self, f, op_type):
         """Tune and annotate a conv2d op."""
-        signature, op_attrs = _extract_relax_function_info(f)
+        signature = _extract_relax_function_signature(f)
+        op_attrs = _get_call_node(f.body, "relax.nn.conv2d").attrs
 
         d_shape = signature["arg0_shape"]
         w_shape = signature["arg1_shape"]
@@ -622,18 +646,29 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
 
     def handle_matmul(self, f, op_type):
         """Tune and annotate a dense op."""
-        signature, _ = _extract_relax_function_info(f)
+        signature = _extract_relax_function_signature(f)
+        arg_idx = _extract_arg_idx(op_type, f)
+
+        lhs_arg = f"arg{arg_idx['lhs']}"
+        rhs_arg = f"arg{arg_idx['rhs']}"
 
-        arg0_shape = signature["arg0_shape"]
-        arg1_shape = signature["arg1_shape"]
+        lhs_shape = signature[f"{lhs_arg}_shape"]
+        rhs_shape = signature[f"{rhs_arg}_shape"]
         out_shape = signature["ret_shape"]
-        arg0_dtype = signature["arg0_dtype"]
-        arg1_dtype = signature["arg1_dtype"]
+        lhs_dtype = signature[f"{lhs_arg}_dtype"]
+        rhs_dtype = signature[f"{rhs_arg}_dtype"]
         out_dtype = signature["ret_dtype"]
 
-        MM = arg0_shape[0]
-        KK = arg0_shape[1]
-        NN = arg1_shape[1]
+        MM = lhs_shape[0]
+        KK = lhs_shape[1]
+        if "transposed" in op_type:
+            NN = rhs_shape[0]
+            ldb = "K"
+            layout_b = LayoutType.ColumnMajor
+        else:
+            NN = rhs_shape[1]
+            ldb = "N"
+            layout_b = LayoutType.RowMajor
 
         use_3xtf32 = self.options.get("use_3xtf32", False)
         find_first_valid = self.options.get("find_first_valid", True)
@@ -645,26 +680,29 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             NN,
             KK,
             out_dtype,
-            arg0_dtype,
-            arg1_dtype,
+            lhs_dtype,
+            rhs_dtype,
             use_3xtf32,
             batched=False,
             find_first_valid=find_first_valid,
             use_multiprocessing=use_multiprocessing,
-            layout_b=LayoutType.RowMajor,
+            layout_b=layout_b,
         )
 
         return f.with_attrs(
             {
                 "op_type": op_type,
-                "arg0_dtype": arg0_dtype,
-                "arg1_dtype": arg1_dtype,
+                "lhs_arg_idx": arg_idx["lhs"],
+                "rhs_arg_idx": arg_idx["rhs"],
+                "bias_arg_idx": arg_idx.get("bias"),
+                "arg0_dtype": signature["arg0_dtype"],
+                "arg1_dtype": signature["arg1_dtype"],
                 "ret_dtype": out_dtype,
-                "arg0_shape": arg0_shape,
-                "arg1_shape": arg1_shape,
+                "arg0_shape": signature["arg0_shape"],
+                "arg1_shape": signature["arg1_shape"],
                 "ret_shape": out_shape,
                 "lda": "K",
-                "ldb": "N",
+                "ldb": ldb,
                 "ldc": "N",
                 "cutlass_op_name": op_name,
                 "cutlass_op_def": op_def,
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py 
b/python/tvm/contrib/cutlass/gemm_operation.py
index 58f5de6a9c..3e74cbaec8 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -259,7 +259,7 @@ class EmitGemmInstance:
         return substitute_template(gemm_template, values)
 
 
-def instantiate_gemm_template(attrs, func_args):
+def instantiate_gemm_template(attrs):
     """Return CUTLASS host code for GEMM based on a template and the provided 
attribute map."""
 
     template = """
@@ -277,8 +277,8 @@ def instantiate_gemm_template(attrs, func_args):
   cutlass::gemm::GemmCoord problem_size(M, N, K);
   ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
   ElementComputeEpilogue beta = ElementComputeEpilogue(${beta});
-  void* ptr_a = (void*)(${arg0}->data);
-  void* ptr_b = (void*)(${arg1}->data);
+  void* ptr_a = (void*)(${lhs_arg}->data);
+  void* ptr_b = (void*)(${rhs_arg}->data);
   ${bias_decl}
   void* ptr_out = (void*)(out0->data);
 
@@ -310,7 +310,7 @@ def instantiate_gemm_template(attrs, func_args):
     if has_bias:
         aux_map.update(
             {
-                "bias_decl": "void* ptr_c_bias = (void*)(${arg2}->data);\n",
+                "bias_decl": "void* ptr_c_bias = 
(void*)(${bias_arg}->data);\n",
                 "ptr_c": "ptr_c_bias",
                 "c_stride": "0",
             }
@@ -342,7 +342,4 @@ def instantiate_gemm_template(attrs, func_args):
 
     template = substitute_template(template, aux_map)
 
-    for i, arg in enumerate(func_args):
-        attrs["arg{}".format(i)] = arg
-
     return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index d3ab020839..92bf04e863 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,27 +17,28 @@
 # pylint: disable=invalid-name
 """Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
 import logging
+import multiprocessing
 import os
 import re
-import tempfile
 import subprocess
-import multiprocessing
+import tempfile
+
 import tvm._ffi
-from tvm.tir import IntImm
 from tvm.runtime import Object
+from tvm.tir import IntImm
+
 from . import _ffi_api as ffi
+from .conv2d_operation import instantiate_conv2d_template
+from .gemm_operation import instantiate_gemm_template
 from .library import (
-    MathInstruction,
     DataType,
     DataTypeTag,
-    OpcodeClass,
+    EpilogueFunctor,
+    MathInstruction,
     MathOperation,
+    OpcodeClass,
     TileDescription,
-    EpilogueFunctor,
 )
-from .gemm_operation import instantiate_gemm_template
-from .conv2d_operation import instantiate_conv2d_template
-
 
 logger = logging.getLogger("cutlass")
 
@@ -371,6 +372,10 @@ EPILOGUE_MAP = {
     "cutlass.matmul_bias": (EpilogueFunctor.LinearCombinationBias, True),
     "cutlass.matmul_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
     "cutlass.matmul_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False),
+    "cutlass.matmul_transposed": (EpilogueFunctor.LinearCombination, False),
+    "cutlass.matmul_transposed_bias": (EpilogueFunctor.LinearCombinationBias, 
True),
+    "cutlass.matmul_transposed_bias_relu": 
(EpilogueFunctor.LinearCombinationRelu, True),
+    "cutlass.matmul_transposed_bias_gelu": 
(EpilogueFunctor.LinearCombinationGelu, False),
     "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False),
     "cutlass.conv2d_bias_hardswish": 
(EpilogueFunctor.LinearCombinationHardSwish, False),
     "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
@@ -454,6 +459,13 @@ class CodegenResult(Object):
         self.__init_handle_by_constructor__(ffi.CodegenResult, code, headers)
 
 
+def _get_optional_int_annotation(annotations, key, default=None):
+    value = annotations.get(key, default)
+    if value is not None:
+        return int(value)
+    return value
+
+
 @tvm._ffi.register_func("contrib.cutlass.instantiate_template")
 def instantiate_template(func_name, annotations, func_args):
     """Return CUTLASS host code based on a template and the provided 
annotations.
@@ -519,32 +531,69 @@ def instantiate_template(func_name, annotations, 
func_args):
     if "dense" in func_name or "matmul" in func_name:
         batched = "batch_matmul" in func_name
         batched_offset = 1 if batched else 0
-        attrs["K"] = str(int(arg0_shape[batched_offset + 1]))
-        attrs["M"] = get_dim(arg0_shape[batched_offset], func_args[0], 0, 
batched_offset)
-
-        if annotations["ldb"] == "N":
-            attrs["N"] = get_dim(arg1_shape[batched_offset + 1], func_args[1], 
1, batched_offset)
+        transposed = "transposed" in func_name
+        lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 
0)
+        rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 
1)
+        bias_arg_idx = _get_optional_int_annotation(annotations, 
"bias_arg_idx", 2)
+        lhs_arg = func_args[lhs_arg_idx]
+        rhs_arg = func_args[rhs_arg_idx]
+        lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
+        rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"]
+
+        attrs["lhs_arg"] = lhs_arg
+        attrs["rhs_arg"] = rhs_arg
+        if len(func_args) > 2:
+            attrs["bias_arg"] = func_args[bias_arg_idx]
+        attrs["ElementInputA"] = 
DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]]
+        attrs["ElementInputB"] = 
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
+        attrs["ElementOutput"] = 
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
+
+        attrs["K"] = str(int(lhs_shape[batched_offset + 1]))
+        attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0, 
batched_offset)
+
+        if transposed:
+            attrs["N"] = get_dim(rhs_shape[batched_offset], rhs_arg, 0, 
batched_offset)
         else:
-            attrs["N"] = get_dim(arg1_shape[batched_offset], func_args[1], 0, 
batched_offset)
+            attrs["N"] = get_dim(rhs_shape[batched_offset + 1], rhs_arg, 1, 
batched_offset)
 
         if batched:
             headers.append("cutlass/gemm/device/gemm_batched.h")
-            attrs["batch"] = get_dim(arg0_shape[0], func_args[0], 0)
-            attrs["batch_stride_A"] = 
get_batch_stride(annotations["batch_stride_A"], 0, 0, 1, 2)
-            attrs["batch_stride_B"] = 
get_batch_stride(annotations["batch_stride_B"], 1, 1, 1, 2)
+            attrs["batch"] = get_dim(lhs_shape[0], lhs_arg, 0)
+            attrs["batch_stride_A"] = get_batch_stride(
+                annotations["batch_stride_A"],
+                lhs_arg_idx,
+                lhs_arg_idx,
+                1,
+                2,
+            )
+            attrs["batch_stride_B"] = get_batch_stride(
+                annotations["batch_stride_B"],
+                rhs_arg_idx,
+                rhs_arg_idx,
+                1,
+                2,
+            )
 
-            if annotations["ldb"] == "N":
+            if transposed:
                 attrs["batch_stride_C"] = get_batch_stride(
-                    annotations["batch_stride_C"], 0, 1, 1, 2
+                    annotations["batch_stride_C"],
+                    lhs_arg_idx,
+                    rhs_arg_idx,
+                    1,
+                    1,
                 )
             else:
                 attrs["batch_stride_C"] = get_batch_stride(
-                    annotations["batch_stride_C"], 0, 1, 1, 1
+                    annotations["batch_stride_C"],
+                    lhs_arg_idx,
+                    rhs_arg_idx,
+                    1,
+                    2,
                 )
         else:
             headers.append("cutlass/gemm/device/gemm.h")
 
-        code = instantiate_gemm_template(attrs, func_args)
+        code = instantiate_gemm_template(attrs)
         return CodegenResult(code, headers)
 
     elif "conv2d" in func_name:
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index d0a1942ebd..e86f8c6074 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -87,6 +87,7 @@ from . import transform
 from . import block_builder
 from . import op
 from . import struct_info
+from . import backend
 
 # VM
 from .vm_build import build, Executable
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 20cf57a40a..51684abb06 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -66,6 +66,36 @@ register_patterns(
                 activation="relax.nn.gelu",
             ),
         ),
+        (
+            "cutlass.matmul_transposed",
+            make_matmul_pattern(
+                with_bias=False,
+                transposed_rhs=True,
+            ),
+        ),
+        (
+            "cutlass.matmul_transposed_bias",
+            make_matmul_pattern(
+                with_bias=True,
+                transposed_rhs=True,
+            ),
+        ),
+        (
+            "cutlass.matmul_transposed_bias_relu",
+            make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+                transposed_rhs=True,
+            ),
+        ),
+        (
+            "cutlass.matmul_transposed_bias_gelu",
+            make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+                transposed_rhs=True,
+            ),
+        ),
     ]
 )
 
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 9e1963f7ed..300b0af568 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -204,6 +204,31 @@ class DFPattern(Node):
         """
         return ffi.match_expr(self, expr, var2val)  # type: ignore
 
+    def extract_matched_expr(
+        self, expr, var2val: Optional[Dict[Var, Expr]] = None
+    ) -> Optional[Dict["DFPattern", Expr]]:
+        """
+        Match a relax.Expr and return a map from matching patterns to matched 
expressions.
+
+        Parameters
+        ----------
+        expr : tvm.relax.Expr
+            The expression to match
+        var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]]
+            A mapping from relax.Var to relax.Expr for autojump.
+
+        Returns
+        -------
+        result: Optional[Dict[DFPattern, Expr]]
+            Map from matching patterns to matched expressions.
+            Return None if the pattern does not match expr.
+
+        Note
+        ----
+        Check the note of `match` for the meaning of var2val.
+        """
+        return ffi.extract_matched_expr(self, expr, var2val)
+
     def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> 
"PatternSeq":
         """
         The current pattern being used by another pattern (sequence)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 92eb452a00..da8c6ce2da 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -515,6 +515,8 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(DFPattern 
pattern, Expr expr,
   return matching;
 }
 
+TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr);
+
 bool MatchExpr(DFPattern pattern, Expr expr, Optional<Map<Var, Expr>> 
bindings_opt) {
   return static_cast<bool>(ExtractMatchedExpr(pattern, expr, bindings_opt));
 }
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 673155342c..af3d40d9c4 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -255,9 +255,12 @@ def test_conv2d_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
-def get_relax_matmul_module(x, y, with_bias=False, activation=None):
+def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, 
activation=None):
     m, k = x.shape
-    n = y.shape[-1]
+    if transposed_y:
+        n = y.shape[-2]
+    else:
+        n = y.shape[-1]
     dtype = str(x.dtype)
 
     from tvm.script.ir_builder import IRBuilder
@@ -266,13 +269,15 @@ def get_relax_matmul_module(x, y, with_bias=False, 
activation=None):
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
-            x = R.arg("x", R.Tensor((m, k), dtype))
-            y = R.arg("y", R.Tensor((k, n), dtype))
+            x = R.arg("x", R.Tensor(x.shape, dtype))
+            y = R.arg("y", R.Tensor(y.shape, dtype))
             if with_bias:
                 bias = R.arg("bias", R.Tensor((n,), dtype))
 
             with R.dataflow() as frame:
-                result = R.emit(R.matmul(x, y))
+                if transposed_y:
+                    y = R.emit(R.permute_dims(y))
+                result = R.emit(R.matmul(x, y, out_dtype=dtype))
                 if with_bias:
                     result = R.emit(result + bias)
                 if activation is not None:
@@ -380,5 +385,58 @@ def test_kernel_sharing():
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+def test_matmul_transposed_offload(matmul_x, matmul_y):
+    x, y = matmul_x, matmul_y
+
+    mod = get_relax_matmul_module(x, y.transpose(), transposed_y=True)
+    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose())
+    ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1])
+    ref = get_relay_ref(ref_relay_expr, x, y.transpose())
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_offload(matmul_x, matmul_y, matmul_bias):
+    x, y, bias = matmul_x, matmul_y, matmul_bias
+
+    mod = get_relax_matmul_module(
+        x, y.transpose(), transposed_y=True, with_bias=True, activation=None
+    )
+    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+    ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1])
+    ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
+    x, y, bias = matmul_x, matmul_y, matmul_bias
+
+    mod = get_relax_matmul_module(
+        x, y.transpose(), transposed_y=True, with_bias=True, 
activation=R.nn.relu
+    )
+    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+    ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1])
+    ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
+    x, y, bias = matmul_x, matmul_y, matmul_bias
+
+    mod = get_relax_matmul_module(
+        x, y.transpose(), transposed_y=True, with_bias=True, 
activation=R.nn.gelu
+    )
+    out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+    ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1])
+    ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to