masahi commented on code in PR #15111:
URL: https://github.com/apache/tvm/pull/15111#discussion_r1232880042


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1250,5 +1250,243 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_fp16A_int4B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(128),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(128)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j // T.int64(2)], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = (
+                        T.Cast(
+                            "float16",
+                            T.shift_right(
+                                T.shift_left(
+                                    T.bitwise_and(
+                                        T.shift_right(
+                                            T.Cast("int32", A[v_i, v_j // 
T.int64(2)]),
+                                            T.Cast("int32", v_j % T.int64(2)) 
* 4,
+                                        ),
+                                        15,
+                                    ),
+                                    28,
+                                ),
+                                28,
+                            ),
+                        )
+                        * B[v_j]
+                    )
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(128),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(128),), "float16")
+            scale = T.alloc_buffer((T.int64(128),))
+            for i, k in T.grid(T.int64(128), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], 
T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(128)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(128), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), 
T.float32(0.0001)
+                    ) * T.float32(0.125)
+            for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)):
+                with T.block("w_gathered"):
+                    v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k])
+                    T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * 
T.int64(2) + v_k])
+                    T.writes(w_gathered[v_j, v_i])
+                    with T.init():
+                        w_gathered[v_j, v_i] = T.int8(0)
+                    w_gathered[v_j, v_i] = T.bitwise_or(
+                        w_gathered[v_j, v_i],
+                        T.if_then_else(
+                            v_i * T.int64(2) + v_k < T.int64(128),
+                            T.shift_left(
+                                T.bitwise_and(
+                                    T.Cast(
+                                        "int8",
+                                        T.min(
+                                            T.max(
+                                                T.round(
+                                                    T.Cast(
+                                                        "float32", A[v_i * 
T.int64(2) + v_k, v_j]
+                                                    )
+                                                    / scale[v_i * T.int64(2) + 
v_k]
+                                                ),
+                                                T.float32(-8),
+                                            ),
+                                            T.float32(7),
+                                        ),
+                                    ),
+                                    T.int8(15),
+                                ),
+                                T.Cast("int8", v_k) * T.int8(4),
+                            ),
+                            T.int8(0),
+                        ),
+                    )
+            for i0 in range(T.int64(128)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(128), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main_bias(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight_int4",
+                    lv1,
+                    80,
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((128,), dtype="float16") = lv[1]
+                lv6 = R.call_tir(
+                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), 
dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, 
out_dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, 
bias)
+                R.output(lv2_1)
+            return lv2_1
+
+        @R.function
+        def main_residual(
+            x: R.Tensor((64, 64), dtype="float16"),
+            residual: R.Tensor((64, 128), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 2})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight_int4",

Review Comment:
   Yes, that's possible: encode + preprocess and dump. In practice, if we run 
the `LiftTransformParams` pass, this preprocessing is already done as part of 
running the `transform_params` function (together with encode). So we are 
already doing what you described.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to