This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c861351bae [Unity][OP] Sync `rms_norm` with main (#15355)
c861351bae is described below
commit c861351baef3d6755920284270dedb0a160252cd
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Jul 19 00:50:36 2023 -0700
[Unity][OP] Sync `rms_norm` with main (#15355)
This PR resolves the conflicting implementations of `rms_norm` between the
unity branch and the main branch by syncing the operator with main.
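
For reference, a minimal sketch of how the synced operator is exercised at the
Relax level (this mirrors the tests added below; the shapes are illustrative and
a build containing this change is assumed):

    import tvm
    from tvm.relax.transform import LegalizeOps
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((2, 3, 4, 5), "float32"),
            weight: R.Tensor((4, 5), "float32"),
            bias: R.Tensor((4, 5), "float32"),
        ) -> R.Tensor((2, 3, 4, 5), "float32"):
            # Normalize over the last two axes; epsilon defaults to 1e-5.
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, bias, axes=[-2, -1])
            return gv

    # LegalizeOps lowers the call into a TIR prim_func via topi.nn.rms_norm,
    # which is what the new tests check structurally.
    lowered = LegalizeOps()(Module)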
---
include/tvm/relax/attrs/nn.h | 11 +
python/tvm/relax/op/nn/nn.py | 44 ++++
python/tvm/relax/transform/legalize_ops/nn.py | 12 +
src/relax/op/nn/nn.cc | 60 +++++
src/relax/op/nn/nn.h | 3 +
.../python/relax/test_transform_legalize_ops_nn.py | 257 +++++++++++++++++++++
6 files changed, 387 insertions(+)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3368b66983..a59cf5e71f 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -296,6 +296,17 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
}
}; // struct GroupNormAttrs
+/*! \brief Attributes used in rms_norm operator */
+struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
+ Array<Integer> axes;
+ double epsilon;
+
+ TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") {
+  TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
+  TVM_ATTR_FIELD(epsilon).describe("Small float added to the square mean to avoid dividing by zero.");
+ }
+}; // struct RMSNormAttrs
+
/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
String reduction;
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9c4044636c..edba3c505f 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -925,6 +925,50 @@ def group_norm(
)
+def rms_norm(
+ data: Expr,
+ weight: Expr,
+ bias: Expr,
+ axes: Union[int, List[int]],
+ epsilon: float = 1e-5,
+) -> Expr:
+ r"""
+    Root mean square normalization (Biao Zhang et al., 2019).
+ Applies root mean square normalization to the n-dimensional input array.
+ This operator takes an n-dimensional input array and normalizes
+    the input along the given axes:
+
+ .. math::
+
+        out = \frac{data}{\sqrt{mean(data^2, axes)+\epsilon}} * weight + bias
+
+ Parameters
+ ----------
+ data : relax.Expr
+ Input to which rms_norm will be applied.
+
+ weight : relax.Expr
+ The scale factor.
+
+ bias : relax.Expr
+ The offset factor.
+
+ axes : Union[int, List[int]]
+        The axes along which the normalization is applied.
+
+ epsilon : float
+ Small float added to square mean to avoid dividing by zero.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(axes, int):
+ axes = [axes]
+ return _ffi_api.rms_norm(data, weight, bias, axes, epsilon) # type: ignore
+
+
def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 85986f0240..b66b2d7b5d 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -334,6 +334,18 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.rms_norm")
+def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.nn.rms_norm,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ axis=call.attrs.axes,
+ epsilon=call.attrs.epsilon,
+ )
+
+
@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
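
The legalization above forwards the three tensor arguments plus `axes` and
`epsilon` to `topi.nn.rms_norm` through `bb.call_te`. A small sketch of driving
that compute directly from TE, assuming the TOPI operator takes exactly the
arguments forwarded above:

    import tvm
    from tvm import te, topi

    # Placeholders matching the shapes used in the new tests.
    data = te.placeholder((2, 3, 4, 5), name="data", dtype="float32")
    weight = te.placeholder((4, 5), name="weight", dtype="float32")
    bias = te.placeholder((4, 5), name="bias", dtype="float32")

    # Same argument order that _nn_rms_norm passes via call_te.
    out = topi.nn.rms_norm(data, weight, bias, axis=[-2, -1], epsilon=1e-5)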
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b0d5b822d2..79ea8650dd 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -437,6 +437,66 @@ TVM_REGISTER_OP("relax.nn.group_norm")
    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.rms_norm */
+TVM_REGISTER_NODE_TYPE(RMSNormAttrs);
+
+Expr rms_norm(Expr data, Expr weight, Expr bias, Array<Integer> axes, double epsilon) {
+ ObjectPtr<RMSNormAttrs> attrs = make_object<RMSNormAttrs>();
+ attrs->axes = std::move(axes);
+ attrs->epsilon = epsilon;
+
+ static const Op& op = Op::Get("relax.nn.rms_norm");
+  return Call(op, {std::move(data), std::move(weight), std::move(bias)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm);
+
+StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+ const auto* attrs = call->attrs.as<RMSNormAttrs>();
+  bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+
+  return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
+ : input_sinfo[0];
+}
+
+InferLayoutOutput InferLayoutRMSNorm(const Call& call,
+                                     const Map<String, Array<String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ std::vector<NLayout> initial_layouts;
+ for (size_t i = 0; i < 3; ++i) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ const auto* attrs = call->attrs.as<RMSNormAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
+ std::vector<Integer> new_axis;
+ for (const auto& axis : attrs->axes) {
+ new_axis.push_back(FindAxis(layout->layout, axis->value));
+ }
+ new_attrs->axes = std::move(new_axis);
+  return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
+ Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.rms_norm")
+ .set_attrs_type<RMSNormAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "Input to which rms_norm will be applied.")
+ .add_argument("weight", "Tensor", "The scale factor.")
+ .add_argument("bias", "Tensor", "The offset factor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRMSNorm)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRMSNorm)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index ce6b369b23..624cfe9078 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -78,6 +78,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double ep
Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale);
+/*! \brief Compute root mean square normalization. */
+Expr rms_norm(Expr data, Expr weight, Expr bias, Array<Integer> axes, double epsilon);
+
/*!
* \brief Applies the dropout operation to the input tensor.
* \param data The input data to the operator.
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 27c67e728d..66da2bae55 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2742,6 +2742,263 @@ def test_group_norm_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_rms_norm():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32"), bias: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
+            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, bias, axes=[-2, -1])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
+ B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+ C: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+ T_rms_norm: T.Buffer(
+ (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+ T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ for ax0, ax1, ax2, ax3 in T.grid(
+ T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+ ):
+ with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(
+ T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+ ):
+ with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ A[v_ax0, v_ax1, v_ax2, v_ax3],
+ B[v_ax2, v_ax3],
+ T_multiply_red[v_ax0, v_ax1],
+ C[v_ax2, v_ax3],
+ )
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3]
+ * B[v_ax2, v_ax3]
+ * T.rsqrt(
+ T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
+ + T.float32(1e-5)
+ )
+ + C[v_ax2, v_ax3]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+ weight: R.Tensor((4, 5), dtype="float32"),
+ bias: R.Tensor((4, 5), dtype="float32"),
+ ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm,
+ (x, weight, bias),
+ out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"),
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_fp16():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16"), bias: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
+            gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight, bias, axes=[-2, -1])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+ B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+ C: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+ T_rms_norm: T.Buffer(
+ (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ T_multiply = T.alloc_buffer(
+ (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
+ )
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)), "float16")
+ for ax0, ax1, ax2, ax3 in T.grid(
+ T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+ ):
+ with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float16(0)
+ T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(
+ T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+ ):
+ with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ A[v_ax0, v_ax1, v_ax2, v_ax3],
+ B[v_ax2, v_ax3],
+ T_multiply_red[v_ax0, v_ax1],
+ C[v_ax2, v_ax3],
+ )
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3]
+ * B[v_ax2, v_ax3]
+ * T.rsqrt(
+                            T_multiply_red[v_ax0, v_ax1] / (T.float16(4) * T.float16(5))
+ + T.float16(1e-5)
+ )
+ + C[v_ax2, v_ax3]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4, 5), dtype="float16"),
+ weight: R.Tensor((4, 5), dtype="float16"),
+ bias: R.Tensor((4, 5), dtype="float16"),
+ ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm,
+ (x, weight, bias),
+ out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"),
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+        def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: R.Tensor(("s", "f"), "float32"), bias: R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
+ n = T.int64()
+ s = T.int64()
+ f = T.int64()
+            gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, bias, axes=[1, 2])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(
+            var_A: T.handle, var_B: T.handle, var_C: T.handle, var_T_rms_norm: T.handle
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n, s, f = T.int64(), T.int64(), T.int64()
+ A = T.match_buffer(var_A, (n, s, f))
+ B = T.match_buffer(var_B, (s, f))
+ C = T.match_buffer(var_C, (s, f))
+ T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+ # with T.block("root"):
+ T_multiply = T.alloc_buffer((n, s, f))
+ T_multiply_red = T.alloc_buffer((n,))
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+ T_multiply[v_ax0, v_ax1, v_ax2] = (
+ A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
+ )
+ for ax0, k1, k2 in T.grid(n, s, f):
+ with T.block("T_multiply_red"):
+ v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+ T.reads(T_multiply[v_ax0, v_k1, v_k2])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0)
+ T_multiply_red[v_ax0] = (
+ T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
+ )
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(
+ A[v_ax0, v_ax1, v_ax2],
+ B[v_ax1, v_ax2],
+ T_multiply_red[v_ax0],
+ C[v_ax1, v_ax2],
+ )
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = (
+ A[v_ax0, v_ax1, v_ax2]
+ * B[v_ax1, v_ax2]
+ * T.rsqrt(
+ T_multiply_red[v_ax0]
+ / (T.Cast("float32", s) * T.Cast("float32", f))
+ + T.float32(1e-5)
+ )
+ + C[v_ax1, v_ax2]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor(("n", "s", "f"), dtype="float32"),
+ weight: R.Tensor(("s", "f"), dtype="float32"),
+ bias: R.Tensor(("s", "f"), dtype="float32"),
+ ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+ n = T.int64()
+ s = T.int64()
+ f = T.int64()
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm,
+ (x, weight, bias),
+ out_sinfo=R.Tensor((n, s, f), dtype="float32"),
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
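
For reference, the new cases can be selected with a standard pytest filter, e.g.:

    pytest tests/python/relax/test_transform_legalize_ops_nn.py -k rms_norm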
def test_attention():
# fmt: off
@tvm.script.ir_module