This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 65d99bd751 [Unity][CUTLASS] Add layer norm support (#14731)
65d99bd751 is described below

commit 65d99bd7512cf2f5a74446789ed71810701de80d
Author: masahi <[email protected]>
AuthorDate: Sat Apr 29 06:12:35 2023 +0900

    [Unity][CUTLASS] Add layer norm support (#14731)
    
    Support offloading layer norm to cutlass
---
 python/tvm/contrib/cutlass/attention_operation.py  |  4 +-
 python/tvm/contrib/cutlass/build.py                | 14 ++++++-
 python/tvm/contrib/cutlass/gen_tensor_op.py        |  8 ++++
 python/tvm/contrib/cutlass/layer_norm_operation.py | 44 +++++++++++++++++++++
 python/tvm/relax/backend/contrib/cutlass.py        | 32 +++++++++++++++
 python/tvm/relax/backend/patterns.py               |  9 +++++
 tests/python/relax/test_codegen_cutlass.py         | 46 ++++++++++++++++++++++
 7 files changed, 154 insertions(+), 3 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 0e507d3be1..57c9ef4f91 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -14,9 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+# pylint: disable=invalid-name
 """Generator for CUTLASS attention kernels."""
-from .library import *
+from .library import substitute_template
 
 
 def instantiate_attention_template(attrs):
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index bf55e8d15b..3c9e8a9e0f 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -776,7 +776,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         )
 
     def handle_attention(self, f, op_type):
-        """Tune and annotate a dense op."""
+        """Annotate an attention op."""
         signature = _extract_relax_function_signature(f)
         if _get_call_node(f.body, "relax.nn.attention") is not None:
             op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
@@ -841,6 +841,16 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             }
         )
 
+    def handle_layer_norm(self, f, _):
+        """Annotate a layer norm op."""
+        signature = _extract_relax_function_signature(f)
+        attrs = {}
+        attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
+        attrs["N"] = signature["arg0_shape"][-1]
+        dtype = signature["arg0_dtype"]
+        attrs["data_type"] = {"float32": "float", "float16": "cutlass::half_t"}[str(dtype)]
+        return f.with_attrs(attrs)
+
     def visit_function_(self, f):
         if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
@@ -854,6 +864,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             return self.handle_matmul(f, op_type)
         elif "attention" in op_type:
             return self.handle_attention(f, op_type)
+        elif "layer_norm" in op_type:
+            return self.handle_layer_norm(f, op_type)
 
         raise ValueError("Unsupported composite {}".format(op_type))
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 392a7bf7e2..5e5ac621ef 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -32,6 +32,7 @@ from . import _ffi_api as ffi
 from .attention_operation import instantiate_attention_template
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template
+from .layer_norm_operation import instantiate_layer_norm_template
 from .library import (
     DataType,
     DataTypeSize,
@@ -764,5 +765,12 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["kSupportsBias"] = attrs["scale"] < 0
         code = instantiate_attention_template(attrs)
         return CodegenResult(code, headers)
+    elif "layer_norm" in func_name:
+        headers.append("cutlass/util/device_layernorm.h")
+        headers.append("cutlass/layout/matrix.h")
+        attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]}
+        attrs.update(dict(annotations))
+        code = instantiate_layer_norm_template(attrs)
+        return CodegenResult(code, headers)
 
     raise ValueError("Do not have a template for {}".format(func_name))
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
new file mode 100644
index 0000000000..589f559e93
--- /dev/null
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Generator for CUTLASS layer norm kernels."""
+from .library import substitute_template
+
+
+def instantiate_layer_norm_template(attrs):
+    """
+    Return CUTLASS host code for layer norm based on
+    a template and the provided attribute map.
+    """
+    template = """
+    using data_type = ${data_type};
+    using namespace cutlass::layout;
+
+    auto M = ${M};
+    auto N = ${N};
+    cutlass::MatrixCoord size(M, N);
+    auto layout_2D = RowMajor::packed(size);
+    auto layout_channels = RowMajor::packed({1, N});
+
+    cutlass::TensorRef<data_type, RowMajor> _input((data_type*)${input}->data, layout_2D);
+    cutlass::TensorRef<data_type, RowMajor> _gamma((data_type*)${gamma}->data, layout_channels);
+    cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
+    cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
+
+    cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL);
+    """
+    return substitute_template(template, attrs)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 7c22ec3dda..86eab773cc 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -30,6 +30,7 @@ from ..patterns import (
     make_matmul_pattern,
     make_residual_block_pattern,
     make_stacked_attention_pattern,
+    make_layer_norm_pattern,
 )
 
 
@@ -315,12 +316,43 @@ def attention_patterns():
     ]
 
 
+def _check_layer_norm(context: PatternCheckContext) -> bool:
+    attrs = context.matched_expr.attrs
+
+    if not attrs.center or not attrs.scale:
+        return False
+
+    if len(attrs.axes) != 1:
+        # Contiguous inner-most axes can be supported, but reject it for now for simplicity.
+        return False
+
+    axis = int(attrs.axes[0])
+    rank = len(context.matched_expr.struct_info.shape)
+
+    if axis < 0:
+        axis += rank
+
+    return axis == rank - 1
+
+
+def layer_norm_pattern():
+    """Create a layer norm pattern for CUTLASS."""
+    return [
+        (
+            "cutlass.layer_norm",
+            *make_layer_norm_pattern(),
+            _check_layer_norm,
+        ),
+    ]
+
+
 register_patterns(
     [
         *conv2d_patterns(),
         *matmul_patterns(),
         *residual_block_patterns(),
         *attention_patterns(),
+        *layer_norm_pattern(),
     ]
 )
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 7119c6c4b0..7242cb3f0d 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -252,3 +252,12 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
     else:
         out = is_op("relax.nn.attention")(query, key, value)
     return out, annotations
+
+
+def make_layer_norm_pattern():
+    """Create a layer norm pattern."""
+    inp = wildcard()
+    gamma = wildcard()
+    beta = wildcard()
+
+    return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index d5d3142cab..85f96b5e96 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -792,5 +792,51 @@ def test_invalid_residual():
     assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
 
 
[email protected](
+    "data_shape, dtype, axes",
+    [
+        ((2, 128, 64), "float16", [-1]),
+        ((128, 30), "float32", [-1]),
+        ((2, 128, 64), "float32", [1]),
+        ((2, 128, 64), "float32", [1, 2]),
+    ],
+)
+def test_layer_norm(data_shape, dtype, axes):
+    def get_mod(data_shape, dtype, axes):
+        reduced_shape = [data_shape[axis] for axis in axes]
+        with IRBuilder() as builder:
+            with relax_builder.function():
+                R.func_name("main")
+                inp = R.arg("input", R.Tensor(data_shape, dtype))
+                gamma = R.arg("gamma", R.Tensor(reduced_shape, dtype))
+                beta = R.arg("beta", R.Tensor(reduced_shape, dtype))
+
+                with R.dataflow() as frame:
+                    output = R.emit(R.nn.layer_norm(inp, gamma, beta, axes))
+                    R.output(output)
+
+                R.func_ret_value(frame.output_vars[0])
+
+        func = builder.get()
+        return tvm.IRModule({"main": func})
+
+    Module = get_mod(data_shape, dtype, axes)
+    mod = partition_for_cutlass(Module)
+
+    if len(axes) != 1 or (axes[0] != -1 and axes[0] != len(data_shape) - 1):
+        tvm.ir.assert_structural_equal(mod, Module)
+        return
+
+    mod = relax.transform.RunCodegen()(mod)
+
+    inp = np.random.randn(*data_shape).astype(dtype)
+    gamma = np.random.randn(data_shape[-1]).astype(dtype)
+    beta = np.random.randn(data_shape[-1]).astype(dtype)
+    out = build_and_run(mod, [inp, gamma, beta], "cuda")
+    ref = build_and_run(Module, [inp, gamma, beta], "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to