This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 63b170d80e [Unity] fp16 A x int B GEMM update - support int8, more bias shape (#15318)
63b170d80e is described below

commit 63b170d80e093890a0b9c9ac1119fed74df77b23
Author: masahi <[email protected]>
AuthorDate: Tue Jul 18 04:41:47 2023 +0900

    [Unity] fp16 A x int B GEMM update - support int8, more bias shape (#15318)
    
    * Support int8 FasterTransformer kernel
    
    * test strided bias
    
    * test gelu act
    
    * update decode check
    
    * update FT rev to disable GCC warning
---
 3rdparty/cutlass_fpA_intB_gemm                     |   2 +-
 python/tvm/contrib/cutlass/build.py                |  45 +++--
 python/tvm/contrib/cutlass/gemm_operation.py       |  26 +--
 python/tvm/contrib/cutlass/gen_tensor_op.py        |  19 ++-
 python/tvm/contrib/cutlass/layer_norm_operation.py |   4 +-
 python/tvm/relax/backend/contrib/cutlass.py        |  39 +++--
 src/runtime/contrib/cutlass/weight_preprocess.cc   |  11 +-
 tests/python/relax/test_codegen_cutlass.py         | 190 ++++++++++++++++++---
 8 files changed, 261 insertions(+), 75 deletions(-)

diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 634184f502..390e821fba 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 634184f50272227f95805fb8ef1eb8c4da373304
+Subproject commit 390e821fbad2356089aab603d7116c6c820eae65
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 2d99cca8f9..b6681ae80f 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -700,7 +700,12 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         arg_idx = _extract_arg_idx(op_type, f)
         signature = _extract_relax_function_signature(f)
         lhs_arg = f"arg{arg_idx['lhs']}"
+        rhs_arg = f"arg{arg_idx['w_encoded']}"
         lhs_shape = signature[f"{lhs_arg}_shape"]
+        rhs_shape = signature[f"{rhs_arg}_shape"]
+        ret_shape = signature["ret_shape"]
+        N = ret_shape[-1]
+
         attrs = {
             "op_type": op_type,
             "lhs_arg_idx": arg_idx["lhs"],
@@ -708,8 +713,23 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             "scales_arg_idx": arg_idx["scales"],
             "bias_arg_idx": arg_idx.get("bias"),
             "batch_offset": len(lhs_shape) - 2,
+            "activation": "identity",
         }
 
+        attrs["bias_stride"] = 0
+
+        if "bias" in arg_idx:
+            bias_shape = signature[f"arg{arg_idx['bias']}_shape"]
+            bias_shape_1d = reduce(operator.mul, bias_shape, 1)
+            if bias_shape_1d != bias_shape[-1]:
+                attrs["bias_stride"] = bias_shape[-1]
+
+        if N == rhs_shape[1]:
+            attrs["weight_nbit"] = 8
+        else:
+            assert N == rhs_shape[1] * 2
+            attrs["weight_nbit"] = 4
+
         if "residual" in op_type:
             residual_pos = op_type.find("residual_")
             postfix = op_type[residual_pos + len("residual_") :]
@@ -739,6 +759,11 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                     "residual_arg_idx": arg_idx["residual"],
                 }
             )
+        else:
+            for act in ["relu", "silu", "gelu"]:
+                if act in op_type:
+                    attrs["activation"] = act
+                    break
 
         return f.with_attrs(attrs)
 
@@ -912,18 +937,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             }
         )
 
-    def handle_layer_norm(self, f, _):
-        """Annotate a layer norm op."""
-        signature = _extract_relax_function_signature(f)
-        attrs = {}
-        attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
-        attrs["N"] = signature["arg0_shape"][-1]
-        dtype = signature["arg0_dtype"]
-        attrs["data_type"] = {"float32": "float", "float16": 
"cutlass::half_t"}[str(dtype)]
-        return f.with_attrs(attrs)
-
-    def handle_rms_norm(self, f, _):
-        """Annotate a rms norm op."""
+    def handle_norm(self, f, _):
+        """Annotate a layer or rms norm op."""
         signature = _extract_relax_function_signature(f)
         attrs = {}
         attrs["batch_rank"] = len(signature["arg0_shape"][:-1])
@@ -948,10 +963,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             return self.handle_matmul(f, op_type)
         elif "attention" in op_type:
             return self.handle_attention(f, op_type)
-        elif "layer_norm" in op_type:
-            return self.handle_layer_norm(f, op_type)
-        elif "rms_norm" in op_type:
-            return self.handle_rms_norm(f, op_type)
+        elif "layer_norm" in op_type or "rms_norm" in op_type:
+            return self.handle_norm(f, op_type)
 
         raise ValueError("Unsupported composite {}".format(op_type))
 
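A note on the build.py change above: the annotator now infers the weight bit width from the output and encoded-weight shapes instead of assuming int4, since int4 weights pack two values per stored int8 element while int8 weights do not. A minimal standalone sketch of that rule (the function name and plain-integer inputs below are illustrative, not part of the patch):

    def infer_weight_nbit(out_cols, encoded_weight_cols):
        """Illustrative restatement of the shape check added in build.py."""
        if out_cols == encoded_weight_cols:
            # int8 weights keep one value per stored byte, so N matches directly.
            return 8
        # int4 weights pack two values per stored int8 element, so N is twice the column count.
        assert out_cols == encoded_weight_cols * 2
        return 4

    assert infer_weight_nbit(64, 64) == 8    # e.g. the int8 test below: (64, 64) encoded weight, N = 64
    assert infer_weight_nbit(128, 64) == 4   # e.g. the int4 test: N = 128 from a (64, 64) packed weight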
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 3fa6e9be8d..e445316dfd 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -412,14 +412,13 @@ def instantiate_gemm_template(attrs):
     return substitute_template(template, attrs)
 
 
-def emit_fp16A_int4B_matmul(attrs):
-    """Return CUTLASS host code for fp16 A and int4 B GEMM."""
-
+def emit_fp16A_intB_matmul(attrs):
+    """Return CUTLASS host code for fp16 A and int4 or int8 B GEMM."""
     attrs["template_common"] = substitute_template(
         """
   using namespace fastertransformer;
   int m = ${A_arg}->shape[${batch_offset}];
-  int n = ${B_arg}->shape[1] * 2;
+  int n = ${B_arg}->shape[1] * ${float_per_int};
   int k = ${B_arg}->shape[0];
 
   auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
@@ -451,17 +450,20 @@ def emit_fp16A_int4B_matmul(attrs):
                 m, n, k, nullptr, 0, stream);
 """
 
-    if "residual_arg" in attrs and "bias_arg" in attrs:
-        template_residual = substitute_template(
-            template_residual, {"bias": "static_cast<cutlass::half_t*>(${bias_arg}->data)"}
-        )
-        return substitute_template(template_residual, attrs)
-
     if "residual_arg" in attrs:
-        template_residual = substitute_template(template_residual, {"bias": "nullptr"})
+        if "bias_arg" in attrs:
+            bias = "static_cast<cutlass::half_t*>(${bias_arg}->data)"
+        else:
+            bias = "nullptr"
+
+        template_residual = substitute_template(template_residual, {"bias": bias})
         return substitute_template(template_residual, attrs)
 
     if "bias_arg" in attrs:
-        return substitute_template(template_bias, attrs)
+        template = substitute_template(
+            template, {"bias": "static_cast<cutlass::half_t*>(${bias_arg}->data)"}
+        )
+    else:
+        template = substitute_template(template, {"bias": "nullptr"})
 
     return substitute_template(template, attrs)
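The emitter refactor above fills the ${bias} placeholder in two passes: it is first resolved to either a cast of the bias buffer or nullptr, and only then are the remaining ${...} attributes substituted. A rough standalone sketch of that two-pass idea, using string.Template.safe_substitute as a stand-in for the module's substitute_template helper (the snippet string and attribute values are illustrative):

    from string import Template

    snippet = "half_t* bias = ${bias};  // n = B->shape[1] * ${float_per_int}"

    def render(attrs):
        # Pass 1: resolve the bias pointer; other ${...} placeholders are left untouched.
        bias = (
            "static_cast<cutlass::half_t*>(${bias_arg}->data)"
            if "bias_arg" in attrs
            else "nullptr"
        )
        partial = Template(snippet).safe_substitute(bias=bias)
        # Pass 2: fill in everything else from the annotated attributes.
        return Template(partial).safe_substitute(**attrs)

    print(render({"bias_arg": "arg3", "float_per_int": 1}))  # int8 path with bias
    print(render({"float_per_int": 2}))                      # int4 path, no bias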
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 8c8bcc20c3..bf02d8f7b8 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -31,7 +31,7 @@ from tvm.tir import IntImm
 from . import _ffi_api as ffi
 from .attention_operation import instantiate_attention_template
 from .conv2d_operation import instantiate_conv2d_template
-from .gemm_operation import instantiate_gemm_template, emit_fp16A_int4B_matmul
+from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
 from .layer_norm_operation import instantiate_layer_norm_template
 from .rms_norm_operation import instantiate_rms_norm_template
 from .library import (
@@ -526,6 +526,7 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["scales_arg"] = func_args[scales_arg_idx]
         attrs["batch_offset"] = _get_optional_int_annotation(annotations, 
"batch_offset", 0)
         attrs["activation"] = annotations.get("activation", "identity")
+        attrs["bias_stride"] = annotations["bias_stride"]
 
         if bias_arg_idx is not None:
             attrs["bias_arg"] = func_args[bias_arg_idx]
@@ -535,7 +536,15 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["binary_op"] = annotations["binary_op"]
             attrs["unary_op"] = annotations["unary_op"]
 
-        code = emit_fp16A_int4B_matmul(attrs)
+        if annotations["weight_nbit"] == 4:
+            attrs["weight_dtype"] = "cutlass::uint4b_t"
+            attrs["float_per_int"] = 2
+        else:
+            assert annotations["weight_nbit"] == 8
+            attrs["weight_dtype"] = "uint8_t"
+            attrs["float_per_int"] = 1
+
+        code = emit_fp16A_intB_matmul(attrs)
         return CodegenResult(code, headers)
 
     elif "dense" in func_name or "matmul" in func_name:
@@ -792,6 +801,12 @@ def instantiate_template(func_name, annotations, func_args):
         headers.append("cutlass/layout/matrix.h")
         attrs = {"input": func_args[0], "gamma": func_args[1], "beta": 
func_args[2]}
         attrs.update(dict(annotations))
+
+        if isinstance(attrs["M"], tvm.tir.Var):
+            attrs["M"] = " * ".join(
+                ["{}->shape[{}]".format(func_args[0], i) for i in 
range(int(attrs["batch_rank"]))]
+            )
+
         code = instantiate_layer_norm_template(attrs)
         return CodegenResult(code, headers)
     elif "rms_norm" in func_name:
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
index ad2e730d27..760ddec3bf 100644
--- a/python/tvm/contrib/cutlass/layer_norm_operation.py
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -28,8 +28,8 @@ def instantiate_layer_norm_template(attrs):
     using data_type = ${data_type};
     using namespace cutlass::layout;
 
-    auto M = ${M};
-    auto N = ${N};
+    int M = ${M};
+    int N = ${N};
     cutlass::MatrixCoord size(M, N);
     auto layout_2D = RowMajor::packed(size);
     auto layout_channels = RowMajor::packed({1, N});
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 5cb5a6f3d7..cdce3225b5 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -102,6 +102,15 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
     return False
 
 
+def _is_same_shape(shape1, shape2):
+    analyzer = tvm.arith.Analyzer()
+    return all([analyzer.can_prove_equal(s1, s2) for s1, s2 in zip(shape1, shape2)])
+
+
+def _is_bias_like(shape, out_channel):
+    return shape[-1] == out_channel and _shape_1d(shape) == out_channel
+
+
 def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
     if "residual" in context.annotated_expr:
         residual = context.annotated_expr["residual"]
@@ -116,13 +125,11 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
             # If residual depends on the result of the root call, this cannot be handled by cutlass.
             return False
 
-        shape1 = [int(s) for s in root_var.struct_info.shape]
-        shape2 = [int(s) for s in residual.struct_info.shape]
-
+        shape1 = root_var.struct_info.shape
+        shape2 = residual.struct_info.shape
         out_channel = shape1[-1]
-        is_bias_like = lambda shape: (shape[-1] == out_channel and _shape_1d(shape) == out_channel)
 
-        if shape1 != shape2 and not is_bias_like(shape2):
+        if not _is_same_shape(shape1, shape2) and not _is_bias_like(shape2, out_channel):
             return False
 
     return True
@@ -239,7 +246,6 @@ def _check_decode_matmul(ctx):
         return False
 
     N = root.struct_info.shape[-1]
-    K = call_tir_decode.struct_info.shape[0]
 
     if ctx.annotated_expr["lhs"].struct_info.dtype != "float16":
         return False
@@ -254,14 +260,10 @@ def _check_decode_matmul(ctx):
     if (
         isinstance(packed_weight, Call)
         and isinstance(packed_weight.args[0], ExternFunc)
-        and packed_weight.args[0].global_symbol != "cutlass.ft_preprocess_weight_int4"
+        and packed_weight.args[0].global_symbol != "cutlass.ft_preprocess_weight"
     ):
         return False
 
-    # packed weight needs to be of shape (K, N // 2)
-    if packed_weight.struct_info.shape[0] != K or packed_weight.struct_info.shape[1] != N // 2:
-        return False
-
     scales = ctx.annotated_expr["scales"]
 
     if scales.struct_info.dtype != "float16":
@@ -271,11 +273,13 @@ def _check_decode_matmul(ctx):
     if len(scales.struct_info.shape) != 1 or scales.struct_info.shape[0] != N:
         return False
 
-    # bias shape needs to be (N,), possibly with additional axes on the front.
     if "bias" in ctx.annotated_expr:
+        out_shape = root.struct_info.shape
         bias_shape = ctx.annotated_expr["bias"].struct_info.shape
-        bias_shape_1d = reduce(operator.mul, bias_shape, 1)
-        if bias_shape_1d != bias_shape[-1]:
+
+        # bias shape needs to be (N,), possibly with additional axes on the front.
+        # It can also have the same shape as the output.
+        if not _is_bias_like(bias_shape, N) and not _is_same_shape(out_shape, bias_shape):
             return False
 
     return True
@@ -309,16 +313,15 @@ def decode_matmul_patterns():
         else:
             out = matmul
 
-        # TODO(masahi): Support more activations
-        if "silu" in name:
-            out = is_op("relax.nn.silu")(out)
+        if "gelu" in name:
+            out = is_op("relax.nn.gelu")(out)
 
         return name, out, annotations, _check_decode_matmul
 
     return [
         _decode_matmul_pattern("cutlass.decode_matmul"),
         _decode_matmul_pattern("cutlass.decode_matmul_bias"),
-        _decode_matmul_pattern("cutlass.decode_matmul_silu"),
+        _decode_matmul_pattern("cutlass.decode_matmul_bias_gelu"),
     ]
 
 
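The pattern checks above accept a bias (or bias-like residual) either when it broadcasts along the last axis, e.g. shape (N,) or (1, 1, N), or when it has exactly the same shape as the output. A standalone restatement of the bias-like rule on plain integer shapes (the helper name below is illustrative; the real checks operate on struct_info and use the arithmetic analyzer for symbolic shapes):

    from functools import reduce
    import operator

    def bias_like(shape, out_channel):
        # (N,) possibly with extra leading singleton axes: total size == last dim == N.
        return shape[-1] == out_channel and reduce(operator.mul, shape, 1) == out_channel

    assert bias_like((64,), 64)
    assert bias_like((1, 1, 64), 64)
    assert not bias_like((32, 64), 64)  # a full (M, N) bias must instead match the output shape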
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc
index 902926ea1b..ef80627cc7 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -35,8 +35,8 @@ namespace runtime {
 // black box.
 //
 // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU.
-TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight_int4")
-    .set_body_typed([](NDArray packed_weight, int sm) {
+TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
+    .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) {
       int rows = packed_weight->shape[0];
       int cols = packed_weight->shape[1];
       std::vector<int8_t> input_cpu(rows * cols);
@@ -44,8 +44,11 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight_int4")
       packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
       // multiply cols by 2 since the "col" params in preprocess_weights refers to the column of
       // the unpacked weight.
-      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols * 2,
-                                            true /*is_int4*/, sm);
+      if (is_int4) {
+        cols *= 2;
+      }
+      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols,
+                                            is_int4, sm);
       auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device);
       out.CopyFromBytes(output_cpu.data(), output_cpu.size());
       return out;
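With the change above, one packed function serves both bit widths and takes an explicit is_int4 flag after the target SM version. The tests below reach it through R.call_pure_packed; a direct call from Python would look roughly like this sketch, assuming a TVM build with the CUTLASS contrib sources compiled in:

    import numpy as np
    import tvm

    # A (K, N) int8 weight for the int8 path; for int4, each byte would hold two packed values.
    weight = tvm.nd.array(np.random.randint(-128, 128, size=(64, 64), dtype="int8"))

    preprocess = tvm.get_global_func("cutlass.ft_preprocess_weight")
    packed = preprocess(weight, 80, False)  # arguments: weight, target SM (e.g. 80), is_int4
    assert packed.shape == weight.shape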
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 1528141e4a..02f15ad3d7 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1258,6 +1258,25 @@ def test_attention_rewrite_fp16():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def split_transform_deploy_mod(mod):
+    mod_transform = tvm.IRModule()
+    mod_deploy = tvm.IRModule().with_attrs(mod.attrs)
+
+    transform_func_name = None
+
+    for gv, func in mod.functions.items():
+        if "transform_params" in gv.name_hint:
+            transform_func_name = gv.name_hint
+            mod_transform[gv] = func
+        elif isinstance(func, tvm.tir.PrimFunc):
+            mod_transform[gv] = func
+        else:
+            mod_deploy[gv] = func
+
+    assert transform_func_name is not None
+    return mod_transform, mod_deploy, transform_func_name
+
+
 def test_fp16A_int4B_gemm():
     @I.ir_module
     class Module:
@@ -1378,9 +1397,10 @@ def test_fp16A_int4B_gemm():
                 )
                 lv1 = lv[0]
                 lv2 = R.call_pure_packed(
-                    "cutlass.ft_preprocess_weight_int4",
+                    "cutlass.ft_preprocess_weight",
                     lv1,
                     80,
+                    True,
                     sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                 )
                 lv3: R.Tensor((128,), dtype="float16") = lv[1]
@@ -1409,9 +1429,10 @@ def test_fp16A_int4B_gemm():
                 )
                 lv1 = lv[0]
                 lv2 = R.call_pure_packed(
-                    "cutlass.ft_preprocess_weight_int4",
+                    "cutlass.ft_preprocess_weight",
                     lv1,
                     80,
+                    True,
                     sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
                 )
                 lv3: R.Tensor((128,), dtype="float16") = lv[1]
@@ -1424,24 +1445,6 @@ def test_fp16A_int4B_gemm():
                 R.output(lv3_1)
             return lv3_1
 
-    def split_transform_deploy_mod(mod):
-        mod_transform = tvm.IRModule()
-        mod_deploy = tvm.IRModule().with_attrs(mod.attrs)
-
-        transform_func_name = None
-
-        for gv, func in mod.functions.items():
-            if "transform_params" in gv.name_hint:
-                transform_func_name = gv.name_hint
-                mod_transform[gv] = func
-            elif isinstance(func, tvm.tir.PrimFunc):
-                mod_transform[gv] = func
-            else:
-                mod_deploy[gv] = func
-
-        assert transform_func_name is not None
-        return mod_transform, mod_deploy, transform_func_name
-
     x_shape = (64, 64)
     y_shape = (128, 64)
 
@@ -1496,6 +1499,153 @@ def test_fp16A_int4B_gemm():
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_fp16A_int8B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(64),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(64)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_j]
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(64),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
+            scale = T.alloc_buffer((T.int64(64),))
+            for i, k in T.grid(T.int64(64), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(64)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(64), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
+                    ) * T.float32(0.0078125)
+            for j, i in T.grid(T.int64(64), T.int64(64)):
+                with T.block("w_gathered"):
+                    v_j, v_i = T.axis.remap("SS", [j, i])
+                    T.reads(A[v_i, v_j], scale[v_i])
+                    T.writes(w_gathered[v_j, v_i])
+                    w_gathered[v_j, v_i] = T.Cast(
+                        "int8",
+                        T.min(
+                            T.max(
+                                T.round(T.Cast("float32", A[v_i, v_j]) / scale[v_i]),
+                                T.float32(-128),
+                            ),
+                            T.float32(127),
+                        ),
+                    )
+            for i0 in range(T.int64(64)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(64), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((64, 64), dtype="float16"),
+            bias: R.Tensor((64, 64), dtype="float16"),
+        ) -> R.Tensor((64, 64), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), R.Tensor((64,), dtype="float16")],
+                )
+                lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
+                lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    R.prim_value(80),
+                    R.prim_value(0),
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((64,), dtype="float16") = lv[1]
+                lv4: R.Tensor((64, 64), dtype="int8") = R.builtin.stop_lift_params(lv2)
+                lv5: R.Tensor((64,), dtype="float16") = R.builtin.stop_lift_params(lv3)
+                lv6 = R.call_tir(
+                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6, out_dtype="float16")
+                lv2_1: R.Tensor((64, 64), dtype="float16") = R.add(lv1_1, bias)
+                lv2_2: R.Tensor((64, 64), dtype="float16") = R.nn.gelu(lv2_1)
+                R.output(lv2_2)
+            return lv2_2
+
+    x_shape = (64, 64)
+    y_shape = (64, 64)
+
+    mod = partition_for_cutlass(Module)
+    func_names = [name.name_hint for (name, _) in mod.functions.items()]
+    assert "fused_decode_relax_matmul_relax_add_relax_nn_gelu_cutlass" in func_names
+
+    mod = relax.transform.RunCodegen(
+        {"cutlass": {"sm": 80, "find_first_valid": False}},
+    )(mod)
+
+    x = np.random.randn(*x_shape).astype("float16")
+    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
+    bias = np.random.randn(x_shape[0], y_shape[0]).astype("float16")
+
+    mod = relax.pipeline.get_pipeline()(mod)
+    mod = relax.transform.LiftTransformParams()(mod)
+
+    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)
+
+    ex = relax.build(mod_transform, target="llvm")
+    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))
+
+    packed_weight, scales, bias_trans = vm[transform_func_name](
+        (tvm.nd.array(y), tvm.nd.array(bias))
+    )
+
+    dev = tvm.device("cuda", 0)
+    ex = relax.build(mod_deploy, target="cuda")
+    vm = relax.vm.VirtualMachine(ex, dev)
+
+    x_nd = tvm.nd.array(x, dev)
+    params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
+    inp = [x_nd, params]
+    out = vm["main"](*inp).numpy()
+
+    def gelu_fp16(x):
+        erf_inp = x * (0.5**0.5)
+        from scipy.special import erf
+
+        erf_out = erf(erf_inp.astype("float32")).astype("float16")
+        return x * 0.5 * (1.0 + erf_out)
+
+    ref = gelu_fp16(np.dot(x, y.transpose()) + bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 def test_rms_norm():
     @I.ir_module
     class Module:
