This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 42de91ff45 [Fix] Fix `topi.rms_norm` with float32 upscale (#16091)
42de91ff45 is described below

commit 42de91ff458eac951c3a8ac7020ca246e0442563
Author: Yaxing Cai <[email protected]>
AuthorDate: Thu Nov 9 08:07:00 2023 -0800

    [Fix] Fix `topi.rms_norm` with float32 upscale (#16091)
    
    This PR fixes `topi.rms_norm` by upcasting the computation to float32 and
    casting the result back to the input dtype. This is needed for float16
    inputs with a large reduction dimension.
---
 include/tvm/topi/nn/rms_norm.h                 | 28 +++++++++++---------------
 python/tvm/topi/nn/rms_norm.py                 |  7 ++-----
 python/tvm/topi/testing/rms_norm_python.py     |  9 +++++----
 src/topi/nn.cc                                 |  2 +-
 tests/python/topi/python/test_topi_rms_norm.py | 14 ++++++-------
 5 files changed, 26 insertions(+), 34 deletions(-)

diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
index 44d38bae6d..ba2f7e49ac 100644
--- a/include/tvm/topi/nn/rms_norm.h
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -41,32 +41,31 @@ using namespace tvm::te;
  * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
  * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == 
len(axis) and
  *               d_{axis_k} == r_k
- * \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
- *             d_{axis_k} == r_k
  * \param axis The axis to normalize over.
  * \param epsilon The epsilon value to avoid division by zero.
  * \param name The name of the operation.
  * \param tag The tag to mark the operation.
  * \return The normalized tensor, with the same shape as data.
  */
-inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& 
bias,
-                       const Array<Integer>& axis, double epsilon, std::string 
name = "T_rms_norm",
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const 
Array<Integer>& axis,
+                       double epsilon, std::string name = "T_rms_norm",
                        std::string tag = kInjective) {
   const auto& data_type = data->dtype;
   const auto& weight_type = weight.defined() ? weight->dtype : data_type;
   ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the 
same type";
-  const auto& bias_type = bias.defined() ? bias->dtype : data_type;
-  ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the 
same type";
 
-  auto square = multiply(data, data);
+  const auto& data_fp32 = cast(data, DataType::Float(32));
+  const auto& weight_fp32 = cast(weight, DataType::Float(32));
+
+  auto square = multiply(data_fp32, data_fp32);
   auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
 
-  auto ndim = data->shape.size();
+  auto ndim = data_fp32->shape.size();
   ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
-  auto reduce_extent = make_const(data->dtype, 1);
+  auto reduce_extent = make_const(data_fp32->dtype, 1);
   for (int i : real_axis) {
-    reduce_extent *= data->shape[i];
+    reduce_extent *= data_fp32->shape[i];
   }
   auto rms_norm_func = [&](const Array<Var>& indices) {
     Array<Var> reduce_indices, non_reduce_indices;
@@ -78,15 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& 
weight, const Tensor& b
       }
     }
     auto output =
-        data(indices) * weight(reduce_indices) *
+        data_fp32(indices) * weight_fp32(reduce_indices) *
         tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + 
make_const(data_type, epsilon));
-    if (bias.defined()) {
-      output += bias(reduce_indices);
-    }
     return output;
   };
-  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
-  return rms_norm;
+  auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
+  return cast(rms_norm, data_type);
 }
 
 }  // namespace nn
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
index f2f5a7e674..9284517468 100644
--- a/python/tvm/topi/nn/rms_norm.py
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -18,7 +18,7 @@
 from .. import cpp
 
 
-def rms_norm(data, weight, bias, axis, epsilon=1e-5):
+def rms_norm(data, weight, axis, epsilon=1e-5):
     """Root mean square normalization operator. The output will have the same 
data type as input.
 
     Parameters
@@ -29,9 +29,6 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
     weight: tvm.te.Tensor
         K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and 
d_{axis_k} == r_k
 
-    bias: tvm.te.Tensor
-        Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) 
and d_{axis_k} == r_k
-
     axis : list of int
         Axis over the normalization applied
 
@@ -43,4 +40,4 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
     result : tvm.te.Tensor
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
-    return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
+    return cpp.nn.rms_norm(data, weight, axis, epsilon)
diff --git a/python/tvm/topi/testing/rms_norm_python.py 
b/python/tvm/topi/testing/rms_norm_python.py
index 7fad5d57ce..651f6f8843 100644
--- a/python/tvm/topi/testing/rms_norm_python.py
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -19,7 +19,7 @@
 import numpy as np
 
 
-def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
     """Root mean square normalization operator in Python.
 
     Parameters
@@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
     result : np.ndarray
         N-D with shape (d_0, d_1, ..., d_{N-1})
     """
+    dtype = data.dtype
+    data = data.astype("float32")
+    weight = weight.astype("float32")
     square_mean = np.mean(np.square(data), axis, keepdims=True)
     result = data * weight / np.sqrt(square_mean + epsilon)
-    if bias is not None:
-        result += bias
-    return result
+    return result.astype(dtype)
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index ba88f01c68..9ce329b206 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -179,7 +179,7 @@ 
TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
 
 /* Ops from nn/rms_norm.h */
 TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* 
rv) {
-  *rv = nn::rms_norm(args[0], args[1], args[2], args[3], 
static_cast<double>(args[4]));
+  *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
 });
 
 }  // namespace topi
diff --git a/tests/python/topi/python/test_topi_rms_norm.py 
b/tests/python/topi/python/test_topi_rms_norm.py
index 35a1485afa..c8c1b8795f 100644
--- a/tests/python/topi/python/test_topi_rms_norm.py
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -34,7 +34,8 @@ _rms_norm_schedule = {
 # only test on llvm because schedule is missing
 @tvm.testing.parametrize_targets("llvm")
 @pytest.mark.parametrize(
-    "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 
16)], (1,))]
+    "shape,axis",
+    [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)), 
([2, 8192], (1,))],
 )
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, 
atol=1e-4):
@@ -42,25 +43,22 @@ def test_rms_norm(target, dev, shape, axis, dtype, 
episilon=1e-5, rtol=5e-3, ato
     scale_shape_te = [shape_te[dim] for dim in axis]
     data = te.placeholder(shape_te, dtype=dtype, name="data")
     weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
-    bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
-    B = topi.nn.rms_norm(data, weight, bias, axis, episilon)
+    B = topi.nn.rms_norm(data, weight, axis, episilon)
 
     shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
     scale_shape_np = [shape_np[dim] for dim in axis]
     data_np = np.random.uniform(size=shape_np).astype(dtype)
     weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
-    bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
-    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, 
episilon)
+    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon)
 
     with tvm.target.Target(target):
         s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
         s = s_func([B])
     data_tvm = tvm.nd.array(data_np, dev)
     weight_tvm = tvm.nd.array(weight_np, dev)
-    bias_tvm = tvm.nd.array(bias_np, dev)
     b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
-    f = tvm.build(s, [data, weight, bias, B], target)
-    f(data_tvm, weight_tvm, bias_tvm, b_tvm)
+    f = tvm.build(s, [data, weight, B], target)
+    f(data_tvm, weight_tvm, b_tvm)
     tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
 
 

Reply via email to