sunggg commented on code in PR #15111:
URL: https://github.com/apache/tvm/pull/15111#discussion_r1232777164


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1250,5 +1250,243 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_fp16A_int4B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(128),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(128)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j // T.int64(2)], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = (
+                        T.Cast(
+                            "float16",
+                            T.shift_right(
+                                T.shift_left(
+                                    T.bitwise_and(
+                                        T.shift_right(
+                                            T.Cast("int32", A[v_i, v_j // 
T.int64(2)]),
+                                            T.Cast("int32", v_j % T.int64(2)) 
* 4,
+                                        ),
+                                        15,
+                                    ),
+                                    28,
+                                ),
+                                28,
+                            ),
+                        )
+                        * B[v_j]
+                    )
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(128),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(128),), "float16")
+            scale = T.alloc_buffer((T.int64(128),))
+            for i, k in T.grid(T.int64(128), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], 
T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(128)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(128), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), 
T.float32(0.0001)
+                    ) * T.float32(0.125)
+            for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)):
+                with T.block("w_gathered"):
+                    v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k])
+                    T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * 
T.int64(2) + v_k])
+                    T.writes(w_gathered[v_j, v_i])
+                    with T.init():
+                        w_gathered[v_j, v_i] = T.int8(0)
+                    w_gathered[v_j, v_i] = T.bitwise_or(
+                        w_gathered[v_j, v_i],
+                        T.if_then_else(
+                            v_i * T.int64(2) + v_k < T.int64(128),
+                            T.shift_left(
+                                T.bitwise_and(
+                                    T.Cast(
+                                        "int8",
+                                        T.min(
+                                            T.max(
+                                                T.round(
+                                                    T.Cast(
+                                                        "float32", A[v_i * 
T.int64(2) + v_k, v_j]
+                                                    )
+                                                    / scale[v_i * T.int64(2) + 
v_k]
+                                                ),
+                                                T.float32(-8),
+                                            ),
+                                            T.float32(7),
+                                        ),
+                                    ),
+                                    T.int8(15),
+                                ),
+                                T.Cast("int8", v_k) * T.int8(4),
+                            ),
+                            T.int8(0),
+                        ),
+                    )
+            for i0 in range(T.int64(128)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(128), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main_bias(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight_int4",
+                    lv1,
+                    80,
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((128,), dtype="float16") = lv[1]
+                lv6 = R.call_tir(

Review Comment:
   I'm wondering: in the future, once we settle on representative quantization
schemes, would it make sense to introduce a Relax-level decode op?



##########
src/runtime/contrib/cutlass/weight_preprocess.cc:
##########
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
+
+namespace tvm {
+namespace runtime {
+
+TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight_int4")
+    .set_body_typed([](NDArray packed_weight, int sm) {
+      int rows = packed_weight->shape[0];
+      int cols = packed_weight->shape[1];
+      std::vector<int8_t> input_cpu(rows * cols);
+      std::vector<int8_t> output_cpu(rows * cols);
+      packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
+      // multiply cols by 2 since the "col" params in preprocess_weights 
refers to the column of
+      // the unpacked weight.
+      fastertransformer::preprocess_weights(output_cpu.data(), 
input_cpu.data(), rows, cols * 2,

Review Comment:
   It may be good to add some documentation about why we need this preprocessing
step, why it happens on the CPU, and whether it is possible (or better) to move it to the GPU.



##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1250,5 +1250,243 @@ def main(
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_fp16A_int4B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(128),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(128)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j // T.int64(2)], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = (
+                        T.Cast(
+                            "float16",
+                            T.shift_right(
+                                T.shift_left(
+                                    T.bitwise_and(
+                                        T.shift_right(
+                                            T.Cast("int32", A[v_i, v_j // 
T.int64(2)]),
+                                            T.Cast("int32", v_j % T.int64(2)) 
* 4,
+                                        ),
+                                        15,
+                                    ),
+                                    28,
+                                ),
+                                28,
+                            ),
+                        )
+                        * B[v_j]
+                    )
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(128),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(128),), "float16")
+            scale = T.alloc_buffer((T.int64(128),))
+            for i, k in T.grid(T.int64(128), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], 
T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(128)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(128), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), 
T.float32(0.0001)
+                    ) * T.float32(0.125)
+            for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)):
+                with T.block("w_gathered"):
+                    v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k])
+                    T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * 
T.int64(2) + v_k])
+                    T.writes(w_gathered[v_j, v_i])
+                    with T.init():
+                        w_gathered[v_j, v_i] = T.int8(0)
+                    w_gathered[v_j, v_i] = T.bitwise_or(
+                        w_gathered[v_j, v_i],
+                        T.if_then_else(
+                            v_i * T.int64(2) + v_k < T.int64(128),
+                            T.shift_left(
+                                T.bitwise_and(
+                                    T.Cast(
+                                        "int8",
+                                        T.min(
+                                            T.max(
+                                                T.round(
+                                                    T.Cast(
+                                                        "float32", A[v_i * 
T.int64(2) + v_k, v_j]
+                                                    )
+                                                    / scale[v_i * T.int64(2) + 
v_k]
+                                                ),
+                                                T.float32(-8),
+                                            ),
+                                            T.float32(7),
+                                        ),
+                                    ),
+                                    T.int8(15),
+                                ),
+                                T.Cast("int8", v_k) * T.int8(4),
+                            ),
+                            T.int8(0),
+                        ),
+                    )
+            for i0 in range(T.int64(128)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(128), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main_bias(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight_int4",
+                    lv1,
+                    80,
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((128,), dtype="float16") = lv[1]
+                lv6 = R.call_tir(
+                    cls.decode, (lv2, lv3), out_sinfo=R.Tensor((64, 128), 
dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 128), dtype="float16") = R.matmul(x, lv6, 
out_dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, 
bias)
+                R.output(lv2_1)
+            return lv2_1
+
+        @R.function
+        def main_residual(
+            x: R.Tensor((64, 64), dtype="float16"),
+            residual: R.Tensor((64, 128), dtype="float16"),
+            y: R.Tensor((128, 64), dtype="float16"),
+            bias: R.Tensor((1, 128), dtype="float16"),
+        ) -> R.Tensor((64, 128), dtype="float16"):
+            R.func_attr({"num_input": 2})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((128,), dtype="float16")],
+                )
+                lv1 = lv[0]
+                lv2 = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight_int4",

Review Comment:
   If we always have to apply this preprocessing, can we apply it once ahead of
time and then dump the preprocessed weights?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to