This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ab8ac58241 [Unity][CUTLASS] Offload RMS norm (#15288)
ab8ac58241 is described below

commit ab8ac58241884d6ec11fba14d59a80c70c50b88a
Author: masahi <[email protected]>
AuthorDate: Wed Jul 12 02:22:05 2023 +0900

    [Unity][CUTLASS] Offload RMS norm (#15288)
    
    I've added a simple RMS norm kernel to CUTLASS in 
https://github.com/NVIDIA/cutlass/pull/979. For the llama workload, it is about 2x 
faster than MS-tuned kernels.
---
 3rdparty/cutlass                                 |  2 +-
 python/tvm/contrib/cutlass/build.py              | 13 +++++
 python/tvm/contrib/cutlass/gen_tensor_op.py      | 14 +++++
 python/tvm/contrib/cutlass/rms_norm_operation.py | 43 +++++++++++++++
 python/tvm/relax/backend/contrib/cutlass.py      | 21 ++++++++
 python/tvm/relax/backend/patterns.py             | 22 +++++++-
 tests/python/relax/test_codegen_cutlass.py       | 66 ++++++++++++++++++++++++
 7 files changed, 178 insertions(+), 3 deletions(-)

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index 92ebbf1dc4..f679663224 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit 92ebbf1dc4612bf838ace6f2e6d262919f0abd63
+Subproject commit f679663224ef5a67c33dc94f89619128a53221c1
diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index a0aae539cf..2d99cca8f9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -922,6 +922,17 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         attrs["data_type"] = {"float32": "float", "float16": 
"cutlass::half_t"}[str(dtype)]
         return f.with_attrs(attrs)
 
+    def handle_rms_norm(self, f, _):
+        """Annotate an RMS norm op."""
+        signature = _extract_relax_function_signature(f)
+        attrs = {}
+        attrs["batch_rank"] = len(signature["arg0_shape"][:-1])
+        attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
+        attrs["N"] = signature["arg0_shape"][-1]
+        dtype = signature["arg0_dtype"]
+        attrs["data_type"] = {"float32": "float", "float16": 
"cutlass::half_t"}[str(dtype)]
+        return f.with_attrs(attrs)
+
     def visit_function_(self, f):
         if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
@@ -939,6 +950,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             return self.handle_attention(f, op_type)
         elif "layer_norm" in op_type:
             return self.handle_layer_norm(f, op_type)
+        elif "rms_norm" in op_type:
+            return self.handle_rms_norm(f, op_type)
 
         raise ValueError("Unsupported composite {}".format(op_type))
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 3ca99244fc..2988f9a8a2 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -33,6 +33,7 @@ from .attention_operation import 
instantiate_attention_template
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template, emit_fp16A_int4B_matmul
 from .layer_norm_operation import instantiate_layer_norm_template
+from .rms_norm_operation import instantiate_rms_norm_template
 from .library import (
     DataType,
     DataTypeSize,
@@ -791,5 +792,18 @@ def instantiate_template(func_name, annotations, 
func_args):
         attrs.update(dict(annotations))
         code = instantiate_layer_norm_template(attrs)
         return CodegenResult(code, headers)
+    elif "rms_norm" in func_name:
+        headers.append("cutlass/util/device_rmsnorm.h")
+        headers.append("cutlass/layout/matrix.h")
+        attrs = {"input": func_args[0], "weight": func_args[1]}
+        attrs.update(dict(annotations))
+
+        if isinstance(attrs["M"], tvm.tir.Var):
+            attrs["M"] = " * ".join(
+                ["{}->shape[{}]".format(func_args[0], i) for i in 
range(int(attrs["batch_rank"]))]
+            )
+
+        code = instantiate_rms_norm_template(attrs)
+        return CodegenResult(code, headers)
 
     raise ValueError(f"Do not have a template for {func_name}")
diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py 
b/python/tvm/contrib/cutlass/rms_norm_operation.py
new file mode 100644
index 0000000000..e24d6bc39a
--- /dev/null
+++ b/python/tvm/contrib/cutlass/rms_norm_operation.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Generator for CUTLASS rms norm kernels."""
+from .library import substitute_template
+
+
+def instantiate_rms_norm_template(attrs):
+    """
+    Return CUTLASS host code for rms norm based on
+    a template and the provided attribute map.
+    """
+    template = """
+    using data_type = ${data_type};
+    using namespace cutlass::layout;
+
+    int M = ${M};
+    int N = ${N};
+    cutlass::MatrixCoord size(M, N);
+    auto layout_2D = RowMajor::packed(size);
+    auto layout_channels = RowMajor::packed({1, N});
+
+    cutlass::TensorRef<data_type, RowMajor> _input((data_type*)${input}->data, 
layout_2D);
+    cutlass::TensorRef<data_type, RowMajor> 
_weight((data_type*)${weight}->data, layout_channels);
+    cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, 
layout_2D);
+
+    cutlass::rmsnorm(size, _output, _input, _weight, nullptr);
+    """
+    return substitute_template(template, attrs)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 5380b1a9b5..5cb5a6f3d7 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -42,6 +42,7 @@ from ..patterns import (
     make_attention_rewrite_pattern,
     make_fused_bias_activation_pattern,
     make_layer_norm_pattern,
+    make_rms_norm_pattern,
     make_matmul_pattern,
     make_residual_block_pattern,
     make_stacked_attention_pattern,
@@ -472,6 +473,25 @@ def layer_norm_pattern():
     ]
 
 
+def _check_rms_norm(ctx: PatternCheckContext) -> bool:
+    rms_norm = ctx.annotated_expr["rms_norm"]
+    if "rms_norm" not in rms_norm.args[0].name_hint:
+        return False
+
+    return True
+
+
+def rms_norm_pattern():
+    """Create a RMS norm pattern for CUTLASS."""
+    return [
+        (
+            "cutlass.rms_norm",
+            *make_rms_norm_pattern(),
+            _check_rms_norm,
+        ),
+    ]
+
+
 def attention_rewrite_patterns():
     """
     Returns a list of all attention rewriting patterns in cutlass BYOC backend.
@@ -495,6 +515,7 @@ register_patterns(
         *residual_block_patterns(),
         *attention_patterns(),
         *layer_norm_pattern(),
+        *rms_norm_pattern(),
     ]
 )
 
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index f7e8dd0406..7bdd64176e 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -14,12 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name
 """Common patterns used in BYOC"""
 
 from typing import Dict, Mapping, Tuple, Union
 from tvm.script import relax as R, tir as T
-from tvm.relax.dpl.pattern import DFPattern, is_const, is_op, 
is_tuple_get_item, wildcard
+from tvm.relax.dpl.pattern import (
+    DFPattern,
+    is_const,
+    is_op,
+    is_tuple_get_item,
+    wildcard,
+    GlobalVarPattern,
+    TuplePattern,
+)
 
 
 def _with_bias_activation_pattern(
@@ -263,6 +271,16 @@ def make_layer_norm_pattern():
     return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
 
 
+def make_rms_norm_pattern():
+    """Create an RMS norm pattern."""
+    inp = wildcard()
+    weight = wildcard()
+    gv = GlobalVarPattern()
+    out = is_op("relax.call_tir")(gv, TuplePattern([inp, weight]))
+    annotations = {"gv": gv, "inp": inp, "rms_norm": out}
+    return out, annotations
+
+
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool
 ):
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 87c12c4838..e1ce46ecb0 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1488,5 +1488,71 @@ def test_fp16A_int4B_gemm():
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_rms_norm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def rms_norm(
+            A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
+            B: T.Buffer((T.int64(4096),), "float16"),
+            rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 
"float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
+            for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+                with T.block("Ared_temp"):
+                    v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
+                    T.reads(A[v_bsz, v_i, v_k])
+                    T.writes(Ared_temp[v_bsz, v_i])
+                    with T.init():
+                        Ared_temp[v_bsz, v_i] = T.float32(0)
+                    Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast(
+                        "float32", A[v_bsz, v_i, v_k]
+                    ) * T.Cast("float32", A[v_bsz, v_i, v_k])
+            for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+                with T.block("rms_norm"):
+                    v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
+                    T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
+                    T.writes(rms_norm[v_bsz, v_i, v_k])
+                    rms_norm[v_bsz, v_i, v_k] = T.Cast(
+                        "float16",
+                        T.Cast("float32", B[v_k])
+                        * (
+                            T.Cast("float32", A[v_bsz, v_i, v_k])
+                            / T.sqrt(
+                                Ared_temp[v_bsz, v_i] * 
T.float32(0.000244140625)
+                                + T.float32(9.9999999999999995e-07)
+                            )
+                        ),
+                    )
+
+        @R.function
+        def main(
+            input: R.Tensor((1, 1, 4096), dtype="float16"),
+            weight: R.Tensor((4096,), dtype="float16"),
+        ) -> R.Tensor((1, 1, 4096), dtype="float16"):
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.rms_norm, (input, weight), out_sinfo=R.Tensor((1, 1, 
4096), dtype="float16")
+                )
+                R.output(lv)
+            return lv
+
+    data_shape = (1, 1, 4096)
+    dtype = "float16"
+    mod = partition_for_cutlass(Module)
+
+    mod = relax.transform.RunCodegen()(mod)
+
+    inp = np.random.randn(*data_shape).astype(dtype)
+    weight = np.random.randn(data_shape[-1]).astype(dtype)
+    out = build_and_run(mod, [inp, weight], "cuda")
+    ref = build_and_run(Module, [inp, weight], "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to