This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 276b4cedbd [Unity][Fix] Fix `topi.rms_norm` with float32 upscale (#16099)
276b4cedbd is described below
commit 276b4cedbd71308ed7d43f65375bf42a080c8a01
Author: Yaxing Cai <[email protected]>
AuthorDate: Thu Nov 9 11:23:06 2023 -0800
[Unity][Fix] Fix `topi.rms_norm` with float32 upscale (#16099)
This PR mirrors #16091.
---
include/tvm/topi/nn/rms_norm.h | 17 ++++++++++-------
python/tvm/topi/testing/rms_norm_python.py | 9 +++++----
tests/python/topi/python/test_topi_rms_norm.py | 14 ++++++--------
3 files changed, 21 insertions(+), 19 deletions(-)
diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
index 55dac39b71..ba2f7e49ac 100644
--- a/include/tvm/topi/nn/rms_norm.h
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -54,15 +54,18 @@ inline Tensor rms_norm(const Tensor& data, const Tensor&
weight, const Array<Int
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the
same type";
- auto square = multiply(data, data);
+ const auto& data_fp32 = cast(data, DataType::Float(32));
+ const auto& weight_fp32 = cast(weight, DataType::Float(32));
+
+ auto square = multiply(data_fp32, data_fp32);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
- auto ndim = data->shape.size();
+ auto ndim = data_fp32->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
- auto reduce_extent = make_const(data->dtype, 1);
+ auto reduce_extent = make_const(data_fp32->dtype, 1);
for (int i : real_axis) {
- reduce_extent *= data->shape[i];
+ reduce_extent *= data_fp32->shape[i];
}
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
@@ -74,12 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor&
weight, const Array<Int
}
}
auto output =
- data(indices) * weight(reduce_indices) *
+ data_fp32(indices) * weight_fp32(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent +
make_const(data_type, epsilon));
return output;
};
- auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
- return rms_norm;
+ auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
+ return cast(rms_norm, data_type);
}
} // namespace nn
diff --git a/python/tvm/topi/testing/rms_norm_python.py
b/python/tvm/topi/testing/rms_norm_python.py
index 7fad5d57ce..651f6f8843 100644
--- a/python/tvm/topi/testing/rms_norm_python.py
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -19,7 +19,7 @@
import numpy as np
-def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator in Python.
Parameters
@@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
+ dtype = data.dtype
+ data = data.astype("float32")
+ weight = weight.astype("float32")
square_mean = np.mean(np.square(data), axis, keepdims=True)
result = data * weight / np.sqrt(square_mean + epsilon)
- if bias is not None:
- result += bias
- return result
+ return result.astype(dtype)
diff --git a/tests/python/topi/python/test_topi_rms_norm.py
b/tests/python/topi/python/test_topi_rms_norm.py
index 35a1485afa..c8c1b8795f 100644
--- a/tests/python/topi/python/test_topi_rms_norm.py
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -34,7 +34,8 @@ _rms_norm_schedule = {
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize(
- "shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b",
16)], (1,))]
+ "shape,axis",
+ [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)),
([2, 8192], (1,))],
)
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3,
atol=1e-4):
@@ -42,25 +43,22 @@ def test_rms_norm(target, dev, shape, axis, dtype,
episilon=1e-5, rtol=5e-3, ato
scale_shape_te = [shape_te[dim] for dim in axis]
data = te.placeholder(shape_te, dtype=dtype, name="data")
weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
- bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
- B = topi.nn.rms_norm(data, weight, bias, axis, episilon)
+ B = topi.nn.rms_norm(data, weight, axis, episilon)
shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
scale_shape_np = [shape_np[dim] for dim in axis]
data_np = np.random.uniform(size=shape_np).astype(dtype)
weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
- bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
- b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis,
episilon)
+ b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon)
with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
weight_tvm = tvm.nd.array(weight_np, dev)
- bias_tvm = tvm.nd.array(bias_np, dev)
b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
- f = tvm.build(s, [data, weight, bias, B], target)
- f(data_tvm, weight_tvm, bias_tvm, b_tvm)
+ f = tvm.build(s, [data, weight, B], target)
+ f(data_tvm, weight_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)