This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 65d99bd751 [Unity][CUTLASS] Add layer norm support (#14731)
65d99bd751 is described below
commit 65d99bd7512cf2f5a74446789ed71810701de80d
Author: masahi <[email protected]>
AuthorDate: Sat Apr 29 06:12:35 2023 +0900
[Unity][CUTLASS] Add layer norm support (#14731)
Support offloading layer norm to cutlass
---
python/tvm/contrib/cutlass/attention_operation.py | 4 +-
python/tvm/contrib/cutlass/build.py | 14 ++++++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 8 ++++
python/tvm/contrib/cutlass/layer_norm_operation.py | 44 +++++++++++++++++++++
python/tvm/relax/backend/contrib/cutlass.py | 32 +++++++++++++++
python/tvm/relax/backend/patterns.py | 9 +++++
tests/python/relax/test_codegen_cutlass.py | 46 ++++++++++++++++++++++
7 files changed, 154 insertions(+), 3 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 0e507d3be1..57c9ef4f91 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+# pylint: disable=invalid-name
"""Generator for CUTLASS attention kernels."""
-from .library import *
+from .library import substitute_template
def instantiate_attention_template(attrs):
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index bf55e8d15b..3c9e8a9e0f 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -776,7 +776,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
)
def handle_attention(self, f, op_type):
- """Tune and annotate a dense op."""
+ """Annotate an attention op."""
signature = _extract_relax_function_signature(f)
if _get_call_node(f.body, "relax.nn.attention") is not None:
op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
@@ -841,6 +841,16 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
}
)
+ def handle_layer_norm(self, f, _):
+ """Annotate a layer norm op."""
+ signature = _extract_relax_function_signature(f)
+ attrs = {}
+ attrs["M"] = reduce(operator.mul, signature["arg0_shape"][:-1], 1)
+ attrs["N"] = signature["arg0_shape"][-1]
+ dtype = signature["arg0_dtype"]
+ attrs["data_type"] = {"float32": "float", "float16": "cutlass::half_t"}[str(dtype)]
+ return f.with_attrs(attrs)
+
def visit_function_(self, f):
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
@@ -854,6 +864,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
return self.handle_matmul(f, op_type)
elif "attention" in op_type:
return self.handle_attention(f, op_type)
+ elif "layer_norm" in op_type:
+ return self.handle_layer_norm(f, op_type)
raise ValueError("Unsupported composite {}".format(op_type))
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 392a7bf7e2..5e5ac621ef 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -32,6 +32,7 @@ from . import _ffi_api as ffi
from .attention_operation import instantiate_attention_template
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template
+from .layer_norm_operation import instantiate_layer_norm_template
from .library import (
DataType,
DataTypeSize,
@@ -764,5 +765,12 @@ def instantiate_template(func_name, annotations, func_args):
attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs)
return CodegenResult(code, headers)
+ elif "layer_norm" in func_name:
+ headers.append("cutlass/util/device_layernorm.h")
+ headers.append("cutlass/layout/matrix.h")
+ attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]}
+ attrs.update(dict(annotations))
+ code = instantiate_layer_norm_template(attrs)
+ return CodegenResult(code, headers)
raise ValueError("Do not have a template for {}".format(func_name))
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
new file mode 100644
index 0000000000..589f559e93
--- /dev/null
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Generator for CUTLASS layer norm kernels."""
+from .library import substitute_template
+
+
+def instantiate_layer_norm_template(attrs):
+ """
+ Return CUTLASS host code for layer norm based on
+ a template and the provided attribute map.
+ """
+ template = """
+ using data_type = ${data_type};
+ using namespace cutlass::layout;
+
+ auto M = ${M};
+ auto N = ${N};
+ cutlass::MatrixCoord size(M, N);
+ auto layout_2D = RowMajor::packed(size);
+ auto layout_channels = RowMajor::packed({1, N});
+
+ cutlass::TensorRef<data_type, RowMajor> _input((data_type*)${input}->data, layout_2D);
+ cutlass::TensorRef<data_type, RowMajor> _gamma((data_type*)${gamma}->data, layout_channels);
+ cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
+ cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
+
+ cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL);
+ """
+ return substitute_template(template, attrs)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 7c22ec3dda..86eab773cc 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -30,6 +30,7 @@ from ..patterns import (
make_matmul_pattern,
make_residual_block_pattern,
make_stacked_attention_pattern,
+ make_layer_norm_pattern,
)
@@ -315,12 +316,43 @@ def attention_patterns():
]
+def _check_layer_norm(context: PatternCheckContext) -> bool:
+ attrs = context.matched_expr.attrs
+
+ if not attrs.center or not attrs.scale:
+ return False
+
+ if len(attrs.axes) != 1:
+ # Contiguous inner-most axes can be supported, but reject it for now for simplicity.
+ return False
+
+ axis = int(attrs.axes[0])
+ rank = len(context.matched_expr.struct_info.shape)
+
+ if axis < 0:
+ axis += rank
+
+ return axis == rank - 1
+
+
+def layer_norm_pattern():
+ """Create a layer norm pattern for CUTLASS."""
+ return [
+ (
+ "cutlass.layer_norm",
+ *make_layer_norm_pattern(),
+ _check_layer_norm,
+ ),
+ ]
+
+
register_patterns(
[
*conv2d_patterns(),
*matmul_patterns(),
*residual_block_patterns(),
*attention_patterns(),
+ *layer_norm_pattern(),
]
)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 7119c6c4b0..7242cb3f0d 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -252,3 +252,12 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
else:
out = is_op("relax.nn.attention")(query, key, value)
return out, annotations
+
+
+def make_layer_norm_pattern():
+ """Create a layer norm pattern."""
+ inp = wildcard()
+ gamma = wildcard()
+ beta = wildcard()
+
+ return is_op("relax.nn.layer_norm")(inp, gamma, beta), {}
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index d5d3142cab..85f96b5e96 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -792,5 +792,51 @@ def test_invalid_residual():
assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+@pytest.mark.parametrize(
+ "data_shape, dtype, axes",
+ [
+ ((2, 128, 64), "float16", [-1]),
+ ((128, 30), "float32", [-1]),
+ ((2, 128, 64), "float32", [1]),
+ ((2, 128, 64), "float32", [1, 2]),
+ ],
+)
+def test_layer_norm(data_shape, dtype, axes):
+ def get_mod(data_shape, dtype, axes):
+ reduced_shape = [data_shape[axis] for axis in axes]
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ inp = R.arg("input", R.Tensor(data_shape, dtype))
+ gamma = R.arg("gamma", R.Tensor(reduced_shape, dtype))
+ beta = R.arg("beta", R.Tensor(reduced_shape, dtype))
+
+ with R.dataflow() as frame:
+ output = R.emit(R.nn.layer_norm(inp, gamma, beta, axes))
+ R.output(output)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
+
+ Module = get_mod(data_shape, dtype, axes)
+ mod = partition_for_cutlass(Module)
+
+ if len(axes) != 1 or (axes[0] != -1 and axes[0] != len(data_shape) - 1):
+ tvm.ir.assert_structural_equal(mod, Module)
+ return
+
+ mod = relax.transform.RunCodegen()(mod)
+
+ inp = np.random.randn(*data_shape).astype(dtype)
+ gamma = np.random.randn(data_shape[-1]).astype(dtype)
+ beta = np.random.randn(data_shape[-1]).astype(dtype)
+ out = build_and_run(mod, [inp, gamma, beta], "cuda")
+ ref = build_and_run(Module, [inp, gamma, beta], "llvm", legalize=True)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()