This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ee1eb3dcf6 [Bug] Fix core dump in InferLayoutRMSNorm and fix typo (#18210)
ee1eb3dcf6 is described below
commit ee1eb3dcf61fc6aabb47625eed26cf44ecef862e
Author: chenxinli <[email protected]>
AuthorDate: Fri Aug 15 20:28:34 2025 +0800
[Bug] Fix core dump in InferLayoutRMSNorm and fix typo (#18210)
Fix core dump in InferLayoutRMSNorm and fix typo
---
python/tvm/relax/op/nn/nn.py | 5 +----
src/relax/op/nn/nn.cc | 9 ++++-----
2 files changed, 5 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5834cf14d2..a38b31c9bb 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1813,7 +1813,7 @@ def rms_norm(
.. math::
- out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
+ out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight
Parameters
----------
@@ -1823,9 +1823,6 @@ def rms_norm(
weight : relax.Expr
The scale factor.
- bias : relax.Expr
- The offset factor.
-
axes : Union[int, List[int]]
The axes that along which the normalization is applied.
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 344c9bc7a3..3597b16a5b 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -848,13 +848,12 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call,
LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
- std::vector<Integer> new_axis;
+ std::vector<Integer> new_axes;
for (const auto& axis : attrs->axes) {
- new_axis.push_back(FindAxis(layout->layout, axis->value));
+ new_axes.push_back(FindAxis(layout->layout, axis->value));
}
- new_attrs->axes = std::move(new_axis);
- return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
- Attrs(new_attrs));
+ new_attrs->axes = std::move(new_axes);
+ return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
}
TVM_REGISTER_OP("relax.nn.rms_norm")