This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c861351bae [Unity][OP] Sync `rms_norm` with main (#15355)
c861351bae is described below

commit c861351baef3d6755920284270dedb0a160252cd
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Jul 19 00:50:36 2023 -0700

    [Unity][OP] Sync `rms_norm` with main (#15355)
    
    This PR fixes the conflict between the `rms_norm` implementations on the unity and main branches.
---
 include/tvm/relax/attrs/nn.h                       |  11 +
 python/tvm/relax/op/nn/nn.py                       |  44 ++++
 python/tvm/relax/transform/legalize_ops/nn.py      |  12 +
 src/relax/op/nn/nn.cc                              |  60 +++++
 src/relax/op/nn/nn.h                               |   3 +
 .../python/relax/test_transform_legalize_ops_nn.py | 257 +++++++++++++++++++++
 6 files changed, 387 insertions(+)
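
For orientation, here is a minimal sketch of how the new operator is exercised end to end. The module, shapes, and axis choice below are illustrative assumptions, not code from this commit:

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Example:
    @R.function
    def main(
        x: R.Tensor((2, 4, 8), "float32"),
        weight: R.Tensor((8,), "float32"),
        bias: R.Tensor((8,), "float32"),
    ) -> R.Tensor((2, 4, 8), "float32"):
        # Normalize over the last axis; weight and bias match the normalized dims.
        gv: R.Tensor((2, 4, 8), "float32") = R.nn.rms_norm(x, weight, bias, axes=[-1])
        return gv

# LegalizeOps rewrites the high-level op into a TIR prim_func (see the tests below).
mod = relax.transform.LegalizeOps()(Example)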

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3368b66983..a59cf5e71f 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -296,6 +296,17 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
   }
 };  // struct GroupNormAttrs
 
+/*! \brief Attributes used in rms_norm operator */
+struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
+  Array<Integer> axes;
+  double epsilon;
+
+  TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") {
+    TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to the square mean to avoid dividing by zero.");
+  }
+};  // struct RMSNormAttrs
+
 /*! \brief Attributes used in nll_loss operator */
 struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
   String reduction;
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9c4044636c..edba3c505f 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -925,6 +925,50 @@ def group_norm(
     )
 
 
+def rms_norm(
+    data: Expr,
+    weight: Expr,
+    bias: Expr,
+    axes: Union[int, List[int]],
+    epsilon: float = 1e-5,
+) -> Expr:
+    r"""
+    Root mean square normalization (Biao Zhang et al., 2019).
+    Applies root mean square normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+        out = \frac{data}{\sqrt{mean(data^2, axis)+\epsilon}} * weight + bias
+
+    Parameters
+    ----------
+    data : relax.Expr
+        Input to which rms_norm will be applied.
+
+    weight : relax.Expr
+        The scale factor.
+
+    bias : relax.Expr
+        The offset factor.
+
+    axes : Union[int, List[int]]
+        The axes along which the normalization is applied.
+
+    epsilon : float
+        Small float added to square mean to avoid dividing by zero.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(axes, int):
+        axes = [axes]
+    return _ffi_api.rms_norm(data, weight, bias, axes, epsilon)  # type: ignore
+
+
 def dropout(data: Expr, rate: float = 0.5) -> Expr:
     """Applies the dropout operation to the input tensor.
 
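A NumPy reference for the docstring formula above can help sanity-check the semantics. This is a sketch, not part of the commit, and rms_norm_ref is a hypothetical helper name:

import numpy as np

def rms_norm_ref(data, weight, bias, axes, epsilon=1e-5):
    # Mean of squares over the normalized axes; keepdims so the result
    # broadcasts back against the input.
    mean_sq = np.mean(data * data, axis=tuple(axes), keepdims=True)
    return data / np.sqrt(mean_sq + epsilon) * weight + bias
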
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 85986f0240..b66b2d7b5d 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -334,6 +334,18 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.rms_norm")
+def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.rms_norm,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        axis=call.attrs.axes,
+        epsilon=call.attrs.epsilon,
+    )
+
+
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
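
The legalization above lowers the Relax op onto `topi.nn.rms_norm` through `bb.call_te`. A standalone TE-level sketch of the same call, with assumed placeholder shapes:

import tvm
from tvm import te, topi

# The compute targeted by call_te above: data, weight, bias, plus the
# reduction axes and epsilon (shapes here are illustrative).
data = te.placeholder((2, 4, 8), name="data", dtype="float32")
weight = te.placeholder((8,), name="weight", dtype="float32")
bias = te.placeholder((8,), name="bias", dtype="float32")
out = topi.nn.rms_norm(data, weight, bias, axis=[2], epsilon=1e-5)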
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b0d5b822d2..79ea8650dd 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -437,6 +437,66 @@ TVM_REGISTER_OP("relax.nn.group_norm")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.rms_norm */
+TVM_REGISTER_NODE_TYPE(RMSNormAttrs);
+
+Expr rms_norm(Expr data, Expr weight, Expr bias, Array<Integer> axes, double epsilon) {
+  ObjectPtr<RMSNormAttrs> attrs = make_object<RMSNormAttrs>();
+  attrs->axes = std::move(axes);
+  attrs->epsilon = epsilon;
+
+  static const Op& op = Op::Get("relax.nn.rms_norm");
+  return Call(op, {std::move(data), std::move(weight), std::move(bias)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm);
+
+StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<RMSNormAttrs>();
+  bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+
+  return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
+                       : input_sinfo[0];
+}
+
+InferLayoutOutput InferLayoutRMSNorm(const Call& call,
+                                     const Map<String, Array<String>>& desired_layouts,
+                                     const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  std::vector<NLayout> initial_layouts;
+  for (size_t i = 0; i < 3; ++i) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+    initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+  }
+  const auto* attrs = call->attrs.as<RMSNormAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
+  std::vector<Integer> new_axis;
+  for (const auto& axis : attrs->axes) {
+    new_axis.push_back(FindAxis(layout->layout, axis->value));
+  }
+  new_attrs->axes = std::move(new_axis);
+  return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
+                           Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.rms_norm")
+    .set_attrs_type<RMSNormAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "Input to which rms_norm will be applied.")
+    .add_argument("weight", "Tensor", "The scale factor.")
+    .add_argument("bias", "Tensor", "The offset factor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRMSNorm)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRMSNorm)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);
 
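The TVM_REGISTER_OP / TVM_REGISTER_GLOBAL pair above exposes the op to Python. One quick way to confirm the registration (the commented values follow from the registration itself, not a captured run):

import tvm

op = tvm.ir.Op.get("relax.nn.rms_norm")
print(op.num_inputs)      # 3: data, weight, bias
print(op.attrs_type_key)  # relax.attrs.RMSNormAttrs
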
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index ce6b369b23..624cfe9078 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -78,6 +78,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double ep
 Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
                 Array<Integer> axes, double epsilon, bool center, bool scale);
 
+/*! \brief Compute root mean square normalization. */
+Expr rms_norm(Expr data, Expr weight, Expr bias, Array<Integer> axes, double epsilon);
+
 /*!
  * \brief Applies the dropout operation to the input tensor.
  * \param data The input data to the operator.
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 27c67e728d..66da2bae55 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2742,6 +2742,263 @@ def test_group_norm_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_rms_norm():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32"), bias: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
+            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, bias, axes=[-2, -1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
+            B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+            C: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+            T_rms_norm: T.Buffer(
+                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        A[v_ax0, v_ax1, v_ax2, v_ax3],
+                        B[v_ax2, v_ax3],
+                        T_multiply_red[v_ax0, v_ax1],
+                        C[v_ax2, v_ax3],
+                    )
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3]
+                        * B[v_ax2, v_ax3]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
+                            + T.float32(1e-5)
+                        )
+                        + C[v_ax2, v_ax3]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+            weight: R.Tensor((4, 5), dtype="float32"),
+            bias: R.Tensor((4, 5), dtype="float32"),
+        ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm,
+                (x, weight, bias),
+                out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"),
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_fp16():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16"), bias: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
+            gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight, bias, axes=[-2, -1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+            B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+            C: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+            T_rms_norm: T.Buffer(
+                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer(
+                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
+            )
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)), "float16")
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float16(0)
+                    T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        A[v_ax0, v_ax1, v_ax2, v_ax3],
+                        B[v_ax2, v_ax3],
+                        T_multiply_red[v_ax0, v_ax1],
+                        C[v_ax2, v_ax3],
+                    )
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3]
+                        * B[v_ax2, v_ax3]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0, v_ax1] / (T.float16(4) * T.float16(5))
+                            + T.float16(1e-5)
+                        )
+                        + C[v_ax2, v_ax3]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4, 5), dtype="float16"),
+            weight: R.Tensor((4, 5), dtype="float16"),
+            bias: R.Tensor((4, 5), dtype="float16"),
+        ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm,
+                (x, weight, bias),
+                out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"),
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: 
R.Tensor(("s", "f"), "float32"), bias: R.Tensor(("s", "f"), "float32")) -> 
R.Tensor(("n", "s", "f"), "float32"):
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
+            gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, bias, axes=[1, 2])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(
+            var_A: T.handle, var_B: T.handle, var_C: T.handle, var_T_rms_norm: T.handle
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n, s, f = T.int64(), T.int64(), T.int64()
+            A = T.match_buffer(var_A, (n, s, f))
+            B = T.match_buffer(var_B, (s, f))
+            C = T.match_buffer(var_C, (s, f))
+            T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer((n, s, f))
+            T_multiply_red = T.alloc_buffer((n,))
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = (
+                        A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
+                    )
+            for ax0, k1, k2 in T.grid(n, s, f):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+                    T.reads(T_multiply[v_ax0, v_k1, v_k2])
+                    T.writes(T_multiply_red[v_ax0])
+                    with T.init():
+                        T_multiply_red[v_ax0] = T.float32(0)
+                    T_multiply_red[v_ax0] = (
+                        T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
+                    )
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(
+                        A[v_ax0, v_ax1, v_ax2],
+                        B[v_ax1, v_ax2],
+                        T_multiply_red[v_ax0],
+                        C[v_ax1, v_ax2],
+                    )
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = (
+                        A[v_ax0, v_ax1, v_ax2]
+                        * B[v_ax1, v_ax2]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0]
+                            / (T.Cast("float32", s) * T.Cast("float32", f))
+                            + T.float32(1e-5)
+                        )
+                        + C[v_ax1, v_ax2]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor(("n", "s", "f"), dtype="float32"),
+            weight: R.Tensor(("s", "f"), dtype="float32"),
+            bias: R.Tensor(("s", "f"), dtype="float32"),
+        ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm,
+                (x, weight, bias),
+                out_sinfo=R.Tensor((n, s, f), dtype="float32"),
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_attention():
     # fmt: off
     @tvm.script.ir_module

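To close the loop, the legalized module can be built and run on the Relax VM and checked against the NumPy reference sketched earlier. This is an assumed workflow, not something this commit exercises; `mod` and `rms_norm_ref` come from the earlier sketches:

import numpy as np
import tvm
from tvm import relax

ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
x = np.random.randn(2, 4, 8).astype("float32")
w = np.random.randn(8).astype("float32")
b = np.random.randn(8).astype("float32")
out = vm["main"](tvm.nd.array(x), tvm.nd.array(w), tvm.nd.array(b))
np.testing.assert_allclose(out.numpy(), rms_norm_ref(x, w, b, axes=[-1]), rtol=1e-5, atol=1e-5)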