This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ab8ac58241 [Unity][CUTLASS] Offload RMS norm (#15288)
ab8ac58241 is described below
commit ab8ac58241884d6ec11fba14d59a80c70c50b88a
Author: masahi <[email protected]>
AuthorDate: Wed Jul 12 02:22:05 2023 +0900
[Unity][CUTLASS] Offload RMS norm (#15288)
I've added a simple RMS norm kernel to CUTLASS in
https://github.com/NVIDIA/cutlass/pull/979. For the llama workload, it is about 2x
faster than MetaSchedule (MS)-tuned kernels.
---
3rdparty/cutlass | 2 +-
python/tvm/contrib/cutlass/build.py | 13 +++++
python/tvm/contrib/cutlass/gen_tensor_op.py | 14 +++++
python/tvm/contrib/cutlass/rms_norm_operation.py | 43 +++++++++++++++
python/tvm/relax/backend/contrib/cutlass.py | 21 ++++++++
python/tvm/relax/backend/patterns.py | 22 +++++++-
tests/python/relax/test_codegen_cutlass.py | 66 ++++++++++++++++++++++++
7 files changed, 178 insertions(+), 3 deletions(-)
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index 92ebbf1dc4..f679663224 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
+Subproject commit f679663224ef5a67c33dc94f89619128a53221c1
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index a0aae539cf..2d99cca8f9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -922,6 +922,17 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
attrs["data_type"] = {"float32": "float", "float16":
"cutlass::half_t"}[str(dtype)]
return f.with_attrs(attrs)
+ def handle_rms_norm(self, f, _):
+ """Annotate an RMS norm op."""
+ signature = _extract_relax_function_signature(f)
+ attrs = {}
+ attrs["batch_rank"] = len(signature["arg0_shape"][:-1])
+ attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
+ attrs["N"] = signature["arg0_shape"][-1]
+ dtype = signature["arg0_dtype"]
+ attrs["data_type"] = {"float32": "float", "float16":
"cutlass::half_t"}[str(dtype)]
+ return f.with_attrs(attrs)
+
def visit_function_(self, f):
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
@@ -939,6 +950,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
return self.handle_attention(f, op_type)
elif "layer_norm" in op_type:
return self.handle_layer_norm(f, op_type)
+ elif "rms_norm" in op_type:
+ return self.handle_rms_norm(f, op_type)
raise ValueError("Unsupported composite {}".format(op_type))
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 3ca99244fc..2988f9a8a2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -33,6 +33,7 @@ from .attention_operation import
instantiate_attention_template
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template, emit_fp16A_int4B_matmul
from .layer_norm_operation import instantiate_layer_norm_template
+from .rms_norm_operation import instantiate_rms_norm_template
from .library import (
DataType,
DataTypeSize,
@@ -791,5 +792,18 @@ def instantiate_template(func_name, annotations,
func_args):
attrs.update(dict(annotations))
code = instantiate_layer_norm_template(attrs)
return CodegenResult(code, headers)
+ elif "rms_norm" in func_name:
+ headers.append("cutlass/util/device_rmsnorm.h")
+ headers.append("cutlass/layout/matrix.h")
+ attrs = {"input": func_args[0], "weight": func_args[1]}
+ attrs.update(dict(annotations))
+
+ if isinstance(attrs["M"], tvm.tir.Var):
+ attrs["M"] = " * ".join(
+ ["{}->shape[{}]".format(func_args[0], i) for i in
range(int(attrs["batch_rank"]))]
+ )
+
+ code = instantiate_rms_norm_template(attrs)
+ return CodegenResult(code, headers)
raise ValueError(f"Do not have a template for {func_name}")
diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py
b/python/tvm/contrib/cutlass/rms_norm_operation.py
new file mode 100644
index 0000000000..e24d6bc39a
--- /dev/null
+++ b/python/tvm/contrib/cutlass/rms_norm_operation.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Generator for CUTLASS rms norm kernels."""
+from .library import substitute_template
+
+
+def instantiate_rms_norm_template(attrs):
+ """
+ Return CUTLASS host code for rms norm based on
+ a template and the provided attribute map.
+ """
+ template = """
+ using data_type = ${data_type};
+ using namespace cutlass::layout;
+
+ int M = ${M};
+ int N = ${N};
+ cutlass::MatrixCoord size(M, N);
+ auto layout_2D = RowMajor::packed(size);
+ auto layout_channels = RowMajor::packed({1, N});
+
+ cutlass::TensorRef<data_type, RowMajor> _input((data_type*)${input}->data,
layout_2D);
+ cutlass::TensorRef<data_type, RowMajor>
_weight((data_type*)${weight}->data, layout_channels);
+ cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data,
layout_2D);
+
+ cutlass::rmsnorm(size, _output, _input, _weight, nullptr);
+ """
+ return substitute_template(template, attrs)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 5380b1a9b5..5cb5a6f3d7 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -42,6 +42,7 @@ from ..patterns import (
make_attention_rewrite_pattern,
make_fused_bias_activation_pattern,
make_layer_norm_pattern,
+ make_rms_norm_pattern,
make_matmul_pattern,
make_residual_block_pattern,
make_stacked_attention_pattern,
@@ -472,6 +473,25 @@ def layer_norm_pattern():
]
+def _check_rms_norm(ctx: PatternCheckContext) -> bool:
+ rms_norm = ctx.annotated_expr["rms_norm"]
+ if "rms_norm" not in rms_norm.args[0].name_hint:
+ return False
+
+ return True
+
+
+def rms_norm_pattern():
+ """Create an RMS norm pattern for CUTLASS."""
+ return [
+ (
+ "cutlass.rms_norm",
+ *make_rms_norm_pattern(),
+ _check_rms_norm,
+ ),
+ ]
+
+
def attention_rewrite_patterns():
"""
Returns a list of all attention rewriting patterns in cutlass BYOC backend.
@@ -495,6 +515,7 @@ register_patterns(
*residual_block_patterns(),
*attention_patterns(),
*layer_norm_pattern(),
+ *rms_norm_pattern(),
]
)
diff --git a/python/tvm/relax/backend/patterns.py
b/python/tvm/relax/backend/patterns.py
index f7e8dd0406..7bdd64176e 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -14,12 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+# pylint: disable=invalid-name
"""Common patterns used in BYOC"""
from typing import Dict, Mapping, Tuple, Union
from tvm.script import relax as R, tir as T
-from tvm.relax.dpl.pattern import DFPattern, is_const, is_op,
is_tuple_get_item, wildcard
+from tvm.relax.dpl.pattern import (
+ DFPattern,
+ is_const,
+ is_op,
+ is_tuple_get_item,
+ wildcard,
+ GlobalVarPattern,
+ TuplePattern,
+)
def _with_bias_activation_pattern(
@@ -263,6 +271,16 @@ def make_layer_norm_pattern():
return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
+def make_rms_norm_pattern():
+ """Create an RMS norm pattern."""
+ inp = wildcard()
+ weight = wildcard()
+ gv = GlobalVarPattern()
+ out = is_op("relax.call_tir")(gv, TuplePattern([inp, weight]))
+ annotations = {"gv": gv, "inp": inp, "rms_norm": out}
+ return out, annotations
+
+
def make_attention_rewrite_pattern(
qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
):
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 87c12c4838..e1ce46ecb0 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1488,5 +1488,71 @@ def test_fp16A_int4B_gemm():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_rms_norm():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def rms_norm(
+ A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
+ B: T.Buffer((T.int64(4096),), "float16"),
+ rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)),
"float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("Ared_temp"):
+ v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
+ T.reads(A[v_bsz, v_i, v_k])
+ T.writes(Ared_temp[v_bsz, v_i])
+ with T.init():
+ Ared_temp[v_bsz, v_i] = T.float32(0)
+ Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
+ "float32", A[v_bsz, v_i, v_k]
+ ) * T.Cast("float32", A[v_bsz, v_i, v_k])
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("rms_norm"):
+ v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
+ T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
+ T.writes(rms_norm[v_bsz, v_i, v_k])
+ rms_norm[v_bsz, v_i, v_k] = T.Cast(
+ "float16",
+ T.Cast("float32", B[v_k])
+ * (
+ T.Cast("float32", A[v_bsz, v_i, v_k])
+ / T.sqrt(
+ Ared_temp[v_bsz, v_i] *
T.float32(0.000244140625)
+ + T.float32(9.9999999999999995e-07)
+ )
+ ),
+ )
+
+ @R.function
+ def main(
+ input: R.Tensor((1, 1, 4096), dtype="float16"),
+ weight: R.Tensor((4096,), dtype="float16"),
+ ) -> R.Tensor((1, 1, 4096), dtype="float16"):
+ cls = Module
+ with R.dataflow():
+ lv = R.call_tir(
+ cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1,
4096), dtype="float16")
+ )
+ R.output(lv)
+ return lv
+
+ data_shape = (1, 1, 4096)
+ dtype = "float16"
+ mod = partition_for_cutlass(Module)
+
+ mod = relax.transform.RunCodegen()(mod)
+
+ inp = np.random.randn(*data_shape).astype(dtype)
+ weight = np.random.randn(data_shape[-1]).astype(dtype)
+ out = build_and_run(mod, [inp, weight], "cuda")
+ ref = build_and_run(Module, [inp, weight], "llvm", legalize=True)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()