This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 925cb2bbaa [Unity][BYOC] Add cutlass finegrained decode matmul (#16144)
925cb2bbaa is described below
commit 925cb2bbaa59efedf5b32bee436961e77d979147
Author: Wuwei Lin <[email protected]>
AuthorDate: Tue Nov 21 10:43:00 2023 -0800
[Unity][BYOC] Add cutlass finegrained decode matmul (#16144)
* [Unity][BYOC] Add cutlass finegrained decode matmul
* clean up
* lint
* fix
* trigger ci
* Update submodule
* Update submodule
---
3rdparty/cutlass_fpA_intB_gemm | 2 +-
python/tvm/contrib/cutlass/build.py | 4 +
python/tvm/contrib/cutlass/gemm_operation.py | 16 ++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 1 +
python/tvm/relax/backend/contrib/cutlass.py | 4 +-
tests/python/relax/test_codegen_cutlass.py | 157 +++++++++++++++++++++++++++
6 files changed, 177 insertions(+), 7 deletions(-)
diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 390e821fba..942c01db41 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 390e821fbad2356089aab603d7116c6c820eae65
+Subproject commit 942c01db41c7bba09e14dc0c4ee35f9d73d2568b
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 671bca7d02..1c0a30c62d 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -707,6 +707,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
lhs_shape = signature[f"{lhs_arg}_shape"]
rhs_shape = signature[f"{rhs_arg}_shape"]
ret_shape = signature["ret_shape"]
+ scale_arg = f"arg{arg_idx['scales']}"
+ scale_shape = signature[f"{scale_arg}_shape"]
N = ret_shape[-1]
attrs = {
@@ -717,6 +719,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"bias_arg_idx": arg_idx.get("bias"),
"activation": "identity",
}
+ # TODO(wuwei): find a better way to get group size
+ attrs["group_size"] = 64 if len(scale_shape) == 2 and scale_shape[0]
!= 1 else -1
attrs["batch_rank"] = len(lhs_shape[:-1])
attrs["M"] = reduce(operator.mul, lhs_shape[:-1], 1)
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index e20f1e0d0c..2639a0359a 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -414,9 +414,17 @@ def instantiate_gemm_template(attrs):
def emit_fp16A_intB_matmul(attrs):
"""Return CUTLASS host code for fp16 A and int4 or int8 B GEMM."""
+ if attrs["group_size"] > 0:
+ attrs["quant_op"] =
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY"
+ else:
+ attrs["quant_op"] = "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY"
+ attrs["group_size"] = "k"
+
attrs["template_common"] = substitute_template(
"""
using namespace fastertransformer;
+ constexpr auto QuantOp = ${quant_op};
+
int m = ${M};
int n = ${B_arg}->shape[1] * ${float_per_int};
int k = ${B_arg}->shape[0];
@@ -430,24 +438,24 @@ def emit_fp16A_intB_matmul(attrs):
template = """
${template_common}
- gemm_fp16_int_bias_act(static_cast<cutlass::half_t*>(${A_arg}->data),
+ gemm_fp16_int_bias_act<${weight_dtype},
QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<${weight_dtype}*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(out0->data),
"${activation}",
- m, n, k, ${bias_stride}, nullptr, 0, stream);
+ m, n, k, ${group_size}, ${bias_stride}, nullptr, 0, stream);
"""
template_residual = """
${template_common}
-
gemm_fp16_int_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
+ gemm_fp16_int_bias_act_residual<${weight_dtype},
QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<${weight_dtype}*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(${residual_arg}->data),
static_cast<cutlass::half_t*>(out0->data), "${activation}",
"${binary_op}", "${unary_op}",
- m, n, k, nullptr, 0, stream);
+ m, n, k, ${group_size}, nullptr, 0, stream);
"""
if "residual_arg" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index d42791d71b..15629ddd3d 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -534,6 +534,7 @@ def instantiate_template(func_name, annotations, func_args):
attrs["activation"] = annotations.get("activation", "identity")
attrs["bias_stride"] = annotations["bias_stride"]
attrs["M"] = annotations["M"]
+ attrs["group_size"] = annotations["group_size"]
if not isinstance(attrs["M"], tvm.tir.IntImm):
attrs["M"] = get_flattened_batch_dim(
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 100b7aa2fd..c7780b7c67 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -269,8 +269,8 @@ def _check_decode_matmul(ctx):
if scales.struct_info.dtype != "float16":
return False
- # scale shape needs to be (N,)
- if len(scales.struct_info.shape) != 1 or scales.struct_info.shape[0] != N:
+ # scale shape needs to be (N,) or (1, N) or (K // group_size, N)
+ if len(scales.struct_info.shape) > 2 or scales.struct_info.shape[-1] != N:
return False
if "bias" in ctx.annotated_expr:
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 9bec214ab9..8ad8fbc531 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1996,6 +1996,163 @@ def test_fp16A_int8B_gemm_batched():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_fp16A_int8B_gemm_batched_finegrained():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def decode(
+ A: T.Buffer((T.int64(128), T.int64(128)), "int8"),
+ B: T.Buffer((T.int64(2), T.int64(128)), "float16"),
+ decode_1: T.Buffer((T.int64(128), T.int64(128)), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i, j in T.grid(T.int64(128), T.int64(128)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(A[v_i, v_j], B[v_i // T.int64(64), v_j])
+ T.writes(decode_1[v_i, v_j])
+ decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) *
B[v_i // T.int64(64), v_j]
+
+ @T.prim_func
+ def encode(
+ A: T.Buffer((T.int64(128), T.int64(128)), "float16"),
+ w_gathered: T.Buffer((T.int64(128), T.int64(128)), "int8"),
+ compute: T.Buffer(
+ (
+ T.int64(2),
+ T.int64(128),
+ ),
+ "float16",
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ max_abs_value = T.alloc_buffer(
+ (
+ T.int64(2),
+ T.int64(128),
+ ),
+ "float16",
+ )
+ scale = T.alloc_buffer(
+ (
+ T.int64(2),
+ T.int64(128),
+ )
+ )
+ for i, j, k in T.grid(T.int64(2), T.int64(128), T.int64(64)):
+ with T.block("max_abs_value"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(A[v_j, v_i * T.int64(64) + v_k])
+ T.writes(max_abs_value[v_i, v_j])
+ with T.init():
+ max_abs_value[v_i, v_j] = T.float16(-65504)
+ max_abs_value[v_i, v_j] = T.max(
+ max_abs_value[v_i, v_j], T.fabs(A[v_j, v_i *
T.int64(64) + v_k])
+ )
+ for i, j in T.grid(T.int64(2), T.int64(128)):
+ with T.block("scale"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(max_abs_value[v_i, v_j])
+ T.writes(scale[v_i, v_j])
+ scale[v_i, v_j] = T.max(
+ T.Cast("float32", max_abs_value[v_i, v_j]),
T.float32(0.0001)
+ ) * T.float32(0.0078125)
+ for j, i in T.grid(T.int64(128), T.int64(128)):
+ with T.block("w_gathered"):
+ v_j, v_i = T.axis.remap("SS", [j, i])
+ T.reads(A[v_i, v_j], scale[v_j // T.int64(64), v_i])
+ T.writes(w_gathered[v_j, v_i])
+ w_gathered[v_j, v_i] = T.Cast(
+ "int8",
+ T.min(
+ T.max(
+ T.round(
+ T.Cast("float32", A[v_i, v_j]) / scale[v_j
// T.int64(64), v_i]
+ ),
+ T.float32(-128),
+ ),
+ T.float32(127),
+ ),
+ )
+ for i0, i1 in T.grid(T.int64(2), T.int64(128)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(scale[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.Cast("float16", scale[v_i0, v_i1])
+
+ @R.function
+ def main(
+ x: R.Tensor(("b", 128, 128), dtype="float16"),
+ y: R.Tensor((128, 128), dtype="float16"),
+ ) -> R.Tensor(("b", 128, 128), dtype="float16"):
+ R.func_attr({"num_input": 1})
+ cls = Module
+ b = T.int64()
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.encode,
+ (y,),
+ out_sinfo=[
+ R.Tensor((128, 128), dtype="int8"),
+ R.Tensor((2, 128), dtype="float16"),
+ ],
+ )
+ lv1: R.Tensor((128, 128), dtype="int8") = lv[0]
+ lv2: R.Tensor((128, 128), dtype="int8") = R.call_pure_packed(
+ "cutlass.ft_preprocess_weight",
+ lv1,
+ R.prim_value(80),
+ R.prim_value(0),
+ sinfo_args=(R.Tensor((128, 128), dtype="int8"),),
+ )
+ lv3: R.Tensor((2, 128), dtype="float16") = lv[1]
+ lv4: R.Tensor((128, 128), dtype="int8") =
R.builtin.stop_lift_params(lv2)
+ lv5: R.Tensor((2, 128), dtype="float16") =
R.builtin.stop_lift_params(lv3)
+ lv6 = R.call_tir(
+ cls.decode, (lv4, lv5), out_sinfo=R.Tensor((128, 128),
dtype="float16")
+ )
+ lv1_1: R.Tensor((b, 128, 128), dtype="float16") = R.matmul(
+ x, lv6, out_dtype="float16"
+ )
+ R.output(lv1_1)
+ return lv1_1
+
+ x_shape = (4, 128, 128)
+ y_shape = (128, 128)
+
+ mod = partition_for_cutlass(Module)
+
+ mod = relax.transform.RunCodegen(
+ {"cutlass": {"sm": 80, "find_first_valid": False}},
+ )(mod)
+
+ x = np.random.randn(*x_shape).astype("float16")
+ y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
+
+ mod = relax.pipeline.get_pipeline()(mod)
+ mod = relax.transform.LiftTransformParams()(mod)
+
+ mod_transform, mod_deploy, transform_func_name =
split_transform_deploy_mod(mod)
+
+ ex = relax.build(mod_transform, target="llvm")
+ vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))
+
+ (packed_weight, scales,) = vm[
+ transform_func_name
+ ]((tvm.nd.array(y),))
+
+ dev = tvm.device("cuda", 0)
+ ex = relax.build(mod_deploy, target="cuda")
+ vm = relax.vm.VirtualMachine(ex, dev)
+
+ x_nd = tvm.nd.array(x, dev)
+ inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)]
+ out = vm["main"](*inp).numpy()
+ ref = np.dot(x, y.transpose())
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
def test_attention_rewrite_multi_query():
@I.ir_module
class Module: