This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 63b170d80e [Unity] fp16 A x int B GEMM update - support int8, more bias shape (#15318)
63b170d80e is described below
commit 63b170d80e093890a0b9c9ac1119fed74df77b23
Author: masahi <[email protected]>
AuthorDate: Tue Jul 18 04:41:47 2023 +0900
[Unity] fp16 A x int B GEMM update - support int8, more bias shape (#15318)
* Support int8 FasterTransformer kernel
* test strided bias
* test gelu act
* update decode check
* update FT rev to disable GCC warning
---
3rdparty/cutlass_fpA_intB_gemm | 2 +-
python/tvm/contrib/cutlass/build.py | 45 +++--
python/tvm/contrib/cutlass/gemm_operation.py | 26 +--
python/tvm/contrib/cutlass/gen_tensor_op.py | 19 ++-
python/tvm/contrib/cutlass/layer_norm_operation.py | 4 +-
python/tvm/relax/backend/contrib/cutlass.py | 39 +++--
src/runtime/contrib/cutlass/weight_preprocess.cc | 11 +-
tests/python/relax/test_codegen_cutlass.py | 190 ++++++++++++++++++---
8 files changed, 261 insertions(+), 75 deletions(-)
diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 634184f502..390e821fba 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 634184f50272227f95805fb8ef1eb8c4da373304
+Subproject commit 390e821fbad2356089aab603d7116c6c820eae65
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 2d99cca8f9..b6681ae80f 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -700,7 +700,12 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
arg_idx = _extract_arg_idx(op_type, f)
signature = _extract_relax_function_signature(f)
lhs_arg = f"arg{arg_idx['lhs']}"
+ rhs_arg = f"arg{arg_idx['w_encoded']}"
lhs_shape = signature[f"{lhs_arg}_shape"]
+ rhs_shape = signature[f"{rhs_arg}_shape"]
+ ret_shape = signature["ret_shape"]
+ N = ret_shape[-1]
+
attrs = {
"op_type": op_type,
"lhs_arg_idx": arg_idx["lhs"],
@@ -708,8 +713,23 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"scales_arg_idx": arg_idx["scales"],
"bias_arg_idx": arg_idx.get("bias"),
"batch_offset": len(lhs_shape) - 2,
+ "activation": "identity",
}
+ attrs["bias_stride"] = 0
+
+ if "bias" in arg_idx:
+ bias_shape = signature[f"arg{arg_idx['bias']}_shape"]
+ bias_shape_1d = reduce(operator.mul, bias_shape, 1)
+ if bias_shape_1d != bias_shape[-1]:
+ attrs["bias_stride"] = bias_shape[-1]
+
+ if N == rhs_shape[1]:
+ attrs["weight_nbit"] = 8
+ else:
+ assert N == rhs_shape[1] * 2
+ attrs["weight_nbit"] = 4
+
if "residual" in op_type:
residual_pos = op_type.find("residual_")
postfix = op_type[residual_pos + len("residual_") :]
@@ -739,6 +759,11 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"residual_arg_idx": arg_idx["residual"],
}
)
+ else:
+ for act in ["relu", "silu", "gelu"]:
+ if act in op_type:
+ attrs["activation"] = act
+ break
return f.with_attrs(attrs)
@@ -912,18 +937,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
}
)
- def handle_layer_norm(self, f, _):
- """Annotate a layer norm op."""
- signature = _extract_relax_function_signature(f)
- attrs = {}
- attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
- attrs["N"] = signature["arg0_shape"][-1]
- dtype = signature["arg0_dtype"]
- attrs["data_type"] = {"float32": "float", "float16":
"cutlass::half_t"}[str(dtype)]
- return f.with_attrs(attrs)
-
- def handle_rms_norm(self, f, _):
- """Annotate a rms norm op."""
+ def handle_norm(self, f, _):
+ """Annotate a layer or rms norm op."""
signature = _extract_relax_function_signature(f)
attrs = {}
attrs["batch_rank"] = len(signature["arg0_shape"][:-1])
@@ -948,10 +963,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
return self.handle_matmul(f, op_type)
elif "attention" in op_type:
return self.handle_attention(f, op_type)
- elif "layer_norm" in op_type:
- return self.handle_layer_norm(f, op_type)
- elif "rms_norm" in op_type:
- return self.handle_rms_norm(f, op_type)
+ elif "layer_norm" in op_type or "rms_norm" in op_type:
+ return self.handle_norm(f, op_type)
raise ValueError("Unsupported composite {}".format(op_type))
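The weight_nbit inference added above depends only on how the packed weight's second dimension relates to the output's last dimension: int8 weights hold one value per byte, while int4 weights pack two values into each int8 element. A minimal standalone sketch of that rule, using a hypothetical helper name that is not part of this patch:

    def infer_weight_nbit(ret_shape, rhs_shape):
        # N is the last dimension of the GEMM output; rhs_shape is the packed int8 buffer.
        n = ret_shape[-1]
        if n == rhs_shape[1]:
            return 8   # one weight per byte -> int8
        assert n == rhs_shape[1] * 2
        return 4       # two weights per byte -> int4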
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 3fa6e9be8d..e445316dfd 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -412,14 +412,13 @@ def instantiate_gemm_template(attrs):
return substitute_template(template, attrs)
-def emit_fp16A_int4B_matmul(attrs):
- """Return CUTLASS host code for fp16 A and int4 B GEMM."""
-
+def emit_fp16A_intB_matmul(attrs):
+ """Return CUTLASS host code for fp16 A and int4 or int8 B GEMM."""
attrs["template_common"] = substitute_template(
"""
using namespace fastertransformer;
int m = ${A_arg}->shape[${batch_offset}];
- int n = ${B_arg}->shape[1] * 2;
+ int n = ${B_arg}->shape[1] * ${float_per_int};
int k = ${B_arg}->shape[0];
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
@@ -451,17 +450,20 @@ def emit_fp16A_int4B_matmul(attrs):
m, n, k, nullptr, 0, stream);
"""
- if "residual_arg" in attrs and "bias_arg" in attrs:
- template_residual = substitute_template(
- template_residual, {"bias":
"static_cast<cutlass::half_t*>(${bias_arg}->data)"}
- )
- return substitute_template(template_residual, attrs)
-
if "residual_arg" in attrs:
- template_residual = substitute_template(template_residual, {"bias":
"nullptr"})
+ if "bias_arg" in attrs:
+ bias = "static_cast<cutlass::half_t*>(${bias_arg}->data)"
+ else:
+ bias = "nullptr"
+
+ template_residual = substitute_template(template_residual, {"bias":
bias})
return substitute_template(template_residual, attrs)
if "bias_arg" in attrs:
- return substitute_template(template_bias, attrs)
+ template = substitute_template(
+ template, {"bias":
"static_cast<cutlass::half_t*>(${bias_arg}->data)"}
+ )
+ else:
+ template = substitute_template(template, {"bias": "nullptr"})
return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 8c8bcc20c3..bf02d8f7b8 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -31,7 +31,7 @@ from tvm.tir import IntImm
from . import _ffi_api as ffi
from .attention_operation import instantiate_attention_template
from .conv2d_operation import instantiate_conv2d_template
-from .gemm_operation import instantiate_gemm_template, emit_fp16A_int4B_matmul
+from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
from .layer_norm_operation import instantiate_layer_norm_template
from .rms_norm_operation import instantiate_rms_norm_template
from .library import (
@@ -526,6 +526,7 @@ def instantiate_template(func_name, annotations, func_args):
attrs["scales_arg"] = func_args[scales_arg_idx]
attrs["batch_offset"] = _get_optional_int_annotation(annotations,
"batch_offset", 0)
attrs["activation"] = annotations.get("activation", "identity")
+ attrs["bias_stride"] = annotations["bias_stride"]
if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
@@ -535,7 +536,15 @@ def instantiate_template(func_name, annotations, func_args):
attrs["binary_op"] = annotations["binary_op"]
attrs["unary_op"] = annotations["unary_op"]
- code = emit_fp16A_int4B_matmul(attrs)
+ if annotations["weight_nbit"] == 4:
+ attrs["weight_dtype"] = "cutlass::uint4b_t"
+ attrs["float_per_int"] = 2
+ else:
+ assert annotations["weight_nbit"] == 8
+ attrs["weight_dtype"] = "uint8_t"
+ attrs["float_per_int"] = 1
+
+ code = emit_fp16A_intB_matmul(attrs)
return CodegenResult(code, headers)
elif "dense" in func_name or "matmul" in func_name:
@@ -792,6 +801,12 @@ def instantiate_template(func_name, annotations, func_args):
headers.append("cutlass/layout/matrix.h")
attrs = {"input": func_args[0], "gamma": func_args[1], "beta":
func_args[2]}
attrs.update(dict(annotations))
+
+ if isinstance(attrs["M"], tvm.tir.Var):
+ attrs["M"] = " * ".join(
+ ["{}->shape[{}]".format(func_args[0], i) for i in
range(int(attrs["batch_rank"]))]
+ )
+
code = instantiate_layer_norm_template(attrs)
return CodegenResult(code, headers)
elif "rms_norm" in func_name:
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
index ad2e730d27..760ddec3bf 100644
--- a/python/tvm/contrib/cutlass/layer_norm_operation.py
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -28,8 +28,8 @@ def instantiate_layer_norm_template(attrs):
using data_type = ${data_type};
using namespace cutlass::layout;
- auto M = ${M};
- auto N = ${N};
+ int M = ${M};
+ int N = ${N};
cutlass::MatrixCoord size(M, N);
auto layout_2D = RowMajor::packed(size);
auto layout_channels = RowMajor::packed({1, N});
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 5cb5a6f3d7..cdce3225b5 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -102,6 +102,15 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
return False
+def _is_same_shape(shape1, shape2):
+ analyzer = tvm.arith.Analyzer()
+ return all([analyzer.can_prove_equal(s1, s2) for s1, s2 in zip(shape1, shape2)])
+
+
+def _is_bias_like(shape, out_channel):
+ return shape[-1] == out_channel and _shape_1d(shape) == out_channel
+
+
def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
if "residual" in context.annotated_expr:
residual = context.annotated_expr["residual"]
@@ -116,13 +125,11 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
# If residual depends on the result of the root call, this cannot be handled by cutlass.
return False
- shape1 = [int(s) for s in root_var.struct_info.shape]
- shape2 = [int(s) for s in residual.struct_info.shape]
-
+ shape1 = root_var.struct_info.shape
+ shape2 = residual.struct_info.shape
out_channel = shape1[-1]
- is_bias_like = lambda shape: (shape[-1] == out_channel and _shape_1d(shape) == out_channel)
- if shape1 != shape2 and not is_bias_like(shape2):
+ if not _is_same_shape(shape1, shape2) and not _is_bias_like(shape2, out_channel):
return False
return True
@@ -239,7 +246,6 @@ def _check_decode_matmul(ctx):
return False
N = root.struct_info.shape[-1]
- K = call_tir_decode.struct_info.shape[0]
if ctx.annotated_expr["lhs"].struct_info.dtype != "float16":
return False
@@ -254,14 +260,10 @@ def _check_decode_matmul(ctx):
if (
isinstance(packed_weight, Call)
and isinstance(packed_weight.args[0], ExternFunc)
- and packed_weight.args[0].global_symbol != "cutlass.ft_preprocess_weight_int4"
+ and packed_weight.args[0].global_symbol != "cutlass.ft_preprocess_weight"
):
return False
- # packed weight needs to be of shape (K, N // 2)
- if packed_weight.struct_info.shape[0] != K or packed_weight.struct_info.shape[1] != N // 2:
- return False
-
scales = ctx.annotated_expr["scales"]
if scales.struct_info.dtype != "float16":
@@ -271,11 +273,13 @@ def _check_decode_matmul(ctx):
if len(scales.struct_info.shape) != 1 or scales.struct_info.shape[0] != N:
return False
- # bias shape needs to be (N,), possibly with additional axes on the front.
if "bias" in ctx.annotated_expr:
+ out_shape = root.struct_info.shape
bias_shape = ctx.annotated_expr["bias"].struct_info.shape
- bias_shape_1d = reduce(operator.mul, bias_shape, 1)
- if bias_shape_1d != bias_shape[-1]:
+
+ # bias shape needs to be (N,), possibly with additional axes on the front.
+ # It can also have the same shape as the output.
+ if not _is_bias_like(bias_shape, N) and not _is_same_shape(out_shape, bias_shape):
return False
return True
@@ -309,16 +313,15 @@ def decode_matmul_patterns():
else:
out = matmul
- # TODO(masahi): Support more activations
- if "silu" in name:
- out = is_op("relax.nn.silu")(out)
+ if "gelu" in name:
+ out = is_op("relax.nn.gelu")(out)
return name, out, annotations, _check_decode_matmul
return [
_decode_matmul_pattern("cutlass.decode_matmul"),
_decode_matmul_pattern("cutlass.decode_matmul_bias"),
- _decode_matmul_pattern("cutlass.decode_matmul_silu"),
+ _decode_matmul_pattern("cutlass.decode_matmul_bias_gelu"),
]
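Replacing the int() comparison with _is_same_shape lets the residual and bias checks handle symbolic dimensions as well as static ones, since tvm.arith.Analyzer.can_prove_equal operates on TIR expressions. A small standalone illustration of the idea (not part of the patch):

    import tvm
    from tvm import tir

    analyzer = tvm.arith.Analyzer()
    n = tir.Var("n", "int64")
    # Provably equal symbolic expressions compare equal, where int() would fail on them.
    print(analyzer.can_prove_equal(n + 1, 1 + n))  # expected: True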
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc
index 902926ea1b..ef80627cc7 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -35,8 +35,8 @@ namespace runtime {
// black box.
//
// The preprocessing functions are defined in C++, so we need to copy the input weight to CPU.
-TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight_int4")
- .set_body_typed([](NDArray packed_weight, int sm) {
+TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
+ .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) {
int rows = packed_weight->shape[0];
int cols = packed_weight->shape[1];
std::vector<int8_t> input_cpu(rows * cols);
@@ -44,8 +44,11 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight_int4")
packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
// multiply cols by 2 since the "col" params in preprocess_weights
refers to the column of
// the unpacked weight.
- fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols * 2,
- true /*is_int4*/, sm);
+ if (is_int4) {
+ cols *= 2;
+ }
+ fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols,
+ is_int4, sm);
auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype,
packed_weight->device);
out.CopyFromBytes(output_cpu.data(), output_cpu.size());
return out;
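With the rename, a single packed function covers both widths and takes an explicit is_int4 flag. A hedged usage sketch from Python, assuming TVM was built with the CUTLASS contrib sources and using illustrative shapes:

    import numpy as np
    import tvm

    preprocess = tvm.get_global_func("cutlass.ft_preprocess_weight")
    packed = tvm.nd.array(np.zeros((64, 64), dtype="int8"))  # quantized weight on CPU
    out_int8 = preprocess(packed, 80, False)  # sm=80, plain int8 weights
    out_int4 = preprocess(packed, 80, True)   # same buffer treated as packed int4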
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 1528141e4a..02f15ad3d7 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1258,6 +1258,25 @@ def test_attention_rewrite_fp16():
tvm.ir.assert_structural_equal(mod, Expected)
+def split_transform_deploy_mod(mod):
+ mod_transform = tvm.IRModule()
+ mod_deploy = tvm.IRModule().with_attrs(mod.attrs)
+
+ transform_func_name = None
+
+ for gv, func in mod.functions.items():
+ if "transform_params" in gv.name_hint:
+ transform_func_name = gv.name_hint
+ mod_transform[gv] = func
+ elif isinstance(func, tvm.tir.PrimFunc):
+ mod_transform[gv] = func
+ else:
+ mod_deploy[gv] = func
+
+ assert transform_func_name is not None
+ return mod_transform, mod_deploy, transform_func_name
+
+
def test_fp16A_int4B_gemm():
@I.ir_module
class Module:
@@ -1378,9 +1397,10 @@ def test_fp16A_int4B_gemm():
)
lv1 = lv[0]
lv2 = R.call_pure_packed(
- "cutlass.ft_preprocess_weight_int4",
+ "cutlass.ft_preprocess_weight",
lv1,
80,
+ True,
sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
)
lv3: R.Tensor((128,), dtype="float16") = lv[1]
@@ -1409,9 +1429,10 @@ def test_fp16A_int4B_gemm():
)
lv1 = lv[0]
lv2 = R.call_pure_packed(
- "cutlass.ft_preprocess_weight_int4",
+ "cutlass.ft_preprocess_weight",
lv1,
80,
+ True,
sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
)
lv3: R.Tensor((128,), dtype="float16") = lv[1]
@@ -1424,24 +1445,6 @@ def test_fp16A_int4B_gemm():
R.output(lv3_1)
return lv3_1
- def split_transform_deploy_mod(mod):
- mod_transform = tvm.IRModule()
- mod_deploy = tvm.IRModule().with_attrs(mod.attrs)
-
- transform_func_name = None
-
- for gv, func in mod.functions.items():
- if "transform_params" in gv.name_hint:
- transform_func_name = gv.name_hint
- mod_transform[gv] = func
- elif isinstance(func, tvm.tir.PrimFunc):
- mod_transform[gv] = func
- else:
- mod_deploy[gv] = func
-
- assert transform_func_name is not None
- return mod_transform, mod_deploy, transform_func_name
-
x_shape = (64, 64)
y_shape = (128, 64)
@@ -1496,6 +1499,153 @@ def test_fp16A_int4B_gemm():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_fp16A_int8B_gemm():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def decode(
+ A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+ B: T.Buffer((T.int64(64),), "float16"),
+ decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i, j in T.grid(T.int64(64), T.int64(64)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(A[v_i, v_j], B[v_j])
+ T.writes(decode_1[v_i, v_j])
+ decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * B[v_j]
+
+ @T.prim_func
+ def encode(
+ A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+ w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+ compute: T.Buffer((T.int64(64),), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
+ scale = T.alloc_buffer((T.int64(64),))
+ for i, k in T.grid(T.int64(64), T.int64(64)):
+ with T.block("max_abs_value"):
+ v_i, v_k = T.axis.remap("SR", [i, k])
+ T.reads(A[v_i, v_k])
+ T.writes(max_abs_value[v_i])
+ with T.init():
+ max_abs_value[v_i] = T.float16(-65504)
+ max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
+ for i in range(T.int64(64)):
+ with T.block("scale"):
+ v_i = T.axis.spatial(T.int64(64), i)
+ T.reads(max_abs_value[v_i])
+ T.writes(scale[v_i])
+ scale[v_i] = T.max(
+ T.Cast("float32", max_abs_value[v_i]),
T.float32(0.0001)
+ ) * T.float32(0.0078125)
+ for j, i in T.grid(T.int64(64), T.int64(64)):
+ with T.block("w_gathered"):
+ v_j, v_i = T.axis.remap("SS", [j, i])
+ T.reads(A[v_i, v_j], scale[v_i])
+ T.writes(w_gathered[v_j, v_i])
+ w_gathered[v_j, v_i] = T.Cast(
+ "int8",
+ T.min(
+ T.max(
+ T.round(T.Cast("float32", A[v_i, v_j]) /
scale[v_i]),
+ T.float32(-128),
+ ),
+ T.float32(127),
+ ),
+ )
+ for i0 in range(T.int64(64)):
+ with T.block("compute"):
+ v_i0 = T.axis.spatial(T.int64(64), i0)
+ T.reads(scale[v_i0])
+ T.writes(compute[v_i0])
+ compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+ @R.function
+ def main(
+ x: R.Tensor((64, 64), dtype="float16"),
+ y: R.Tensor((64, 64), dtype="float16"),
+ bias: R.Tensor((64, 64), dtype="float16"),
+ ) -> R.Tensor((64, 64), dtype="float16"):
+ R.func_attr({"num_input": 1})
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.encode,
+ (y,),
+ out_sinfo=[R.Tensor((64, 64), dtype="int8"),
R.Tensor((64,), dtype="float16")],
+ )
+ lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
+ lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
+ "cutlass.ft_preprocess_weight",
+ lv1,
+ R.prim_value(80),
+ R.prim_value(0),
+ sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+ )
+ lv3: R.Tensor((64,), dtype="float16") = lv[1]
+ lv4: R.Tensor((64, 64), dtype="int8") =
R.builtin.stop_lift_params(lv2)
+ lv5: R.Tensor((64,), dtype="float16") =
R.builtin.stop_lift_params(lv3)
+ lv6 = R.call_tir(
+ cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), dtype="float16")
+ )
+ lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6,
out_dtype="float16")
+ lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1,
bias)
+ lv2_2: R.Tensor((64, 128), dtype="float16") = R.nn.gelu(lv2_1)
+ R.output(lv2_2)
+ return lv2_2
+
+ x_shape = (64, 64)
+ y_shape = (64, 64)
+
+ mod = partition_for_cutlass(Module)
+ func_names = [name.name_hint for (name, _) in mod.functions.items()]
+ assert "fused_decode_relax_matmul_relax_add_relax_nn_gelu_cutlass" in
func_names
+
+ mod = relax.transform.RunCodegen(
+ {"cutlass": {"sm": 80, "find_first_valid": False}},
+ )(mod)
+
+ x = np.random.randn(*x_shape).astype("float16")
+ y = np.random.normal(0, 0.002, size=y_shape).astype("float16")
+ bias = np.random.randn(x_shape[0], y_shape[0]).astype("float16")
+
+ mod = relax.pipeline.get_pipeline()(mod)
+ mod = relax.transform.LiftTransformParams()(mod)
+
+ mod_transform, mod_deploy, transform_func_name = split_transform_deploy_mod(mod)
+
+ ex = relax.build(mod_transform, target="llvm")
+ vm = relax.vm.VirtualMachine(ex, tvm.cpu(0))
+
+ packed_weight, scales, bias_trans = vm[transform_func_name](
+ (tvm.nd.array(y), tvm.nd.array(bias))
+ )
+
+ dev = tvm.device("cuda", 0)
+ ex = relax.build(mod_deploy, target="cuda")
+ vm = relax.vm.VirtualMachine(ex, dev)
+
+ x_nd = tvm.nd.array(x, dev)
+ params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
+ inp = [x_nd, params]
+ out = vm["main"](*inp).numpy()
+
+ def gelu_fp16(x):
+ erf_inp = x * (0.5**0.5)
+ from scipy.special import erf
+
+ erf_out = erf(erf_inp.astype("float32")).astype("float16")
+ return x * 0.5 * (1.0 + erf_out)
+
+ ref = gelu_fp16(np.dot(x, y.transpose()) + bias)
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
def test_rms_norm():
@I.ir_module
class Module: