This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 925cb2bbaa [Unity][BYOC] Add cutlass finegrained decode matmul (#16144)
925cb2bbaa is described below

commit 925cb2bbaa59efedf5b32bee436961e77d979147
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Nov 21 10:43:00 2023 -0800

    [Unity][BYOC] Add cutlass finegrained decode matmul (#16144)
    
    * [Unity][BYOC] Add cutlass finegrained decode matmul
    
    * clean up
    
    * lint
    
    * fix
    
    * trigger ci
    
    * Update submodule
    
    * Update submodule
---
 3rdparty/cutlass_fpA_intB_gemm               |   2 +-
 python/tvm/contrib/cutlass/build.py          |   4 +
 python/tvm/contrib/cutlass/gemm_operation.py |  16 ++-
 python/tvm/contrib/cutlass/gen_tensor_op.py  |   1 +
 python/tvm/relax/backend/contrib/cutlass.py  |   4 +-
 tests/python/relax/test_codegen_cutlass.py   | 157 +++++++++++++++++++++++++++
 6 files changed, 177 insertions(+), 7 deletions(-)

diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 390e821fba..942c01db41 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 390e821fbad2356089aab603d7116c6c820eae65
+Subproject commit 942c01db41c7bba09e14dc0c4ee35f9d73d2568b
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 671bca7d02..1c0a30c62d 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -707,6 +707,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         lhs_shape = signature[f"{lhs_arg}_shape"]
         rhs_shape = signature[f"{rhs_arg}_shape"]
         ret_shape = signature["ret_shape"]
+        scale_arg = f"arg{arg_idx['scales']}"
+        scale_shape = signature[f"{scale_arg}_shape"]
         N = ret_shape[-1]
 
         attrs = {
@@ -717,6 +719,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             "bias_arg_idx": arg_idx.get("bias"),
             "activation": "identity",
         }
+        # TODO(wuwei): find a better way to get group size
+        attrs["group_size"] = 64 if len(scale_shape) == 2 and scale_shape[0] != 1 else -1
 
         attrs["batch_rank"] = len(lhs_shape[:-1])
         attrs["M"] = reduce(operator.mul, lhs_shape[:-1], 1)
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index e20f1e0d0c..2639a0359a 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -414,9 +414,17 @@ def instantiate_gemm_template(attrs):
 
 def emit_fp16A_intB_matmul(attrs):
     """Return CUTLASS host code for fp16 A and int4 or int8 B GEMM."""
+    if attrs["group_size"] > 0:
+        attrs["quant_op"] = "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY"
+    else:
+        attrs["quant_op"] = "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY"
+        attrs["group_size"] = "k"
+
     attrs["template_common"] = substitute_template(
         """
   using namespace fastertransformer;
+  constexpr auto QuantOp = ${quant_op};
+
   int m = ${M};
   int n = ${B_arg}->shape[1] * ${float_per_int};
   int k = ${B_arg}->shape[0];
@@ -430,24 +438,24 @@ def emit_fp16A_intB_matmul(attrs):
 
     template = """
   ${template_common}
-  gemm_fp16_int_bias_act(static_cast<cutlass::half_t*>(${A_arg}->data),
+  gemm_fp16_int_bias_act<${weight_dtype}, QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
                 static_cast<${weight_dtype}*>(${B_arg}->data),
                 static_cast<cutlass::half_t*>(${scales_arg}->data),
                 ${bias},
                 static_cast<cutlass::half_t*>(out0->data),
                 "${activation}",
-                m, n, k, ${bias_stride}, nullptr, 0, stream);
+                m, n, k, ${group_size}, ${bias_stride}, nullptr, 0, stream);
 """
 
     template_residual = """
   ${template_common}
-  gemm_fp16_int_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
+  gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
                 static_cast<${weight_dtype}*>(${B_arg}->data),
                 static_cast<cutlass::half_t*>(${scales_arg}->data),
                 ${bias},
                 static_cast<cutlass::half_t*>(${residual_arg}->data),
                static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
-                m, n, k, nullptr, 0, stream);
+                m, n, k, ${group_size}, nullptr, 0, stream);
 """
 
     if "residual_arg" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index d42791d71b..15629ddd3d 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -534,6 +534,7 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["activation"] = annotations.get("activation", "identity")
         attrs["bias_stride"] = annotations["bias_stride"]
         attrs["M"] = annotations["M"]
+        attrs["group_size"] = annotations["group_size"]
 
         if not isinstance(attrs["M"], tvm.tir.IntImm):
             attrs["M"] = get_flattened_batch_dim(
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 100b7aa2fd..c7780b7c67 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -269,8 +269,8 @@ def _check_decode_matmul(ctx):
     if scales.struct_info.dtype != "float16":
         return False
 
-    # scale shape needs to be (N,)
-    if len(scales.struct_info.shape) != 1 or scales.struct_info.shape[0] != N:
+    # scale shape needs to be (N,) or (1, N) or (K // group_size, N)
+    if len(scales.struct_info.shape) > 2 or scales.struct_info.shape[-1] != N:
         return False
 
     if "bias" in ctx.annotated_expr:
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 9bec214ab9..8ad8fbc531 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1996,6 +1996,163 @@ def test_fp16A_int8B_gemm_batched():
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_fp16A_int8B_gemm_batched_finegrained():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(128), T.int64(128)), "int8"),
+            B: T.Buffer((T.int64(2), T.int64(128)), "float16"),
+            decode_1: T.Buffer((T.int64(128), T.int64(128)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i, j in T.grid(T.int64(128), T.int64(128)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j], B[v_i // T.int64(64), v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_i // T.int64(64), v_j]
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(128), T.int64(128)), "float16"),
+            w_gathered: T.Buffer((T.int64(128), T.int64(128)), "int8"),
+            compute: T.Buffer(
+                (
+                    T.int64(2),
+                    T.int64(128),
+                ),
+                "float16",
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            max_abs_value = T.alloc_buffer(
+                (
+                    T.int64(2),
+                    T.int64(128),
+                ),
+                "float16",
+            )
+            scale = T.alloc_buffer(
+                (
+                    T.int64(2),
+                    T.int64(128),
+                )
+            )
+            for i, j, k in T.grid(T.int64(2), T.int64(128), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                    T.reads(A[v_j, v_i * T.int64(64) + v_k])
+                    T.writes(max_abs_value[v_i, v_j])
+                    with T.init():
+                        max_abs_value[v_i, v_j] = T.float16(-65504)
+                    max_abs_value[v_i, v_j] = T.max(
+                        max_abs_value[v_i, v_j], T.fabs(A[v_j, v_i * T.int64(64) + v_k])
+                    )
+            for i, j in T.grid(T.int64(2), T.int64(128)):
+                with T.block("scale"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(max_abs_value[v_i, v_j])
+                    T.writes(scale[v_i, v_j])
+                    scale[v_i, v_j] = T.max(
+                        T.Cast("float32", max_abs_value[v_i, v_j]), T.float32(0.0001)
+                    ) * T.float32(0.0078125)
+            for j, i in T.grid(T.int64(128), T.int64(128)):
+                with T.block("w_gathered"):
+                    v_j, v_i = T.axis.remap("SS", [j, i])
+                    T.reads(A[v_i, v_j], scale[v_j // T.int64(64), v_i])
+                    T.writes(w_gathered[v_j, v_i])
+                    w_gathered[v_j, v_i] = T.Cast(
+                        "int8",
+                        T.min(
+                            T.max(
+                                T.round(
+                                    T.Cast("float32", A[v_i, v_j]) / scale[v_j // T.int64(64), v_i]
+                                ),
+                                T.float32(-128),
+                            ),
+                            T.float32(127),
+                        ),
+                    )
+            for i0, i1 in T.grid(T.int64(2), T.int64(128)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(scale[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Cast("float16", scale[v_i0, v_i1])
+
+        @R.function
+        def main(
+            x: R.Tensor(("b", 128, 128), dtype="float16"),
+            y: R.Tensor((128, 128), dtype="float16"),
+        ) -> R.Tensor(("b", 128, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            b = T.int64()
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[
+                        R.Tensor((128, 128), dtype="int8"),
+                        R.Tensor((2, 128), dtype="float16"),
+                    ],
+                )
+                lv1: R.Tensor((128, 128), dtype="int8") = lv[0]
+                lv2: R.Tensor((128, 128), dtype="int8") = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    R.prim_value(80),
+                    R.prim_value(0),
+                    sinfo_args=(R.Tensor((128, 128), dtype="int8"),),
+                )
+                lv3: R.Tensor((2, 128), dtype="float16") = lv[1]
+                lv4: R.Tensor((128, 128), dtype="int8") = R.builtin.stop_lift_params(lv2)
+                lv5: R.Tensor((2, 128), dtype="float16") = R.builtin.stop_lift_params(lv3)
+                lv6 = R.call_tir(
+                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((128, 128), dtype="float16")
+                )
+                lv1_1: R.Tensor((b, 128, 128), dtype="float16") = R.matmul(
+                    x, lv6, out_dtype="float16"
+                )
+                R.output(lv1_1)
+            return lv1_1
+
+    x_shape = (4, 128, 128)
+    y_shape = (128, 128)
+
+    mod = partition_for_cutlass(Module)
+
+    mod = relax.transform.RunCodegen(
+        {"cutlass": {"sm": 80, "find_first_valid": False}},
+    )(mod)
+
+    x = np.random.randn(*x_shape).astype("float16")
+    y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
+
+    mod = relax.pipeline.get_pipeline()(mod)
+    mod = relax.transform.LiftTransformParams()(mod)
+
+    mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)
+
+    ex = relax.build(mod_transform, target="llvm")
+    vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))
+
+    (packed_weight, scales,) = vm[
+        transform_func_name
+    ]((tvm.nd.array(y),))
+
+    dev = tvm.device("cuda", 0)
+    ex = relax.build(mod_deploy, target="cuda")
+    vm = relax.vm.VirtualMachine(ex, dev)
+
+    x_nd = tvm.nd.array(x, dev)
+    inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)]
+    out = vm["main"](*inp).numpy()
+    ref = np.dot(x, y.transpose())
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 def test_attention_rewrite_multi_query():
     @I.ir_module
     class Module:

Reply via email to