This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e3f5ac1c6b [Relax] Correct YaRN RoPE frequency scaling formula to
align with the original paper (#18576)
e3f5ac1c6b is described below
commit e3f5ac1c6bccebc4bf1c35c9a1d81cf4c0a1740d
Author: Yeongjae Jang <[email protected]>
AuthorDate: Thu Jan 1 05:49:16 2026 +0900
[Relax] Correct YaRN RoPE frequency scaling formula to align with the
original paper (#18576)
## Summary
Fix the frequency calculation for YaRN RoPE scaling and correct how the
correction range is computed.
## Description
Greetings:
This PR corrects the mathematical formulation of the
[YaRN](https://arxiv.org/abs/2309.00071) RoPE scaling.
I have verified that this change eliminates the discrepancy observed
when comparing against a PyTorch baseline (an implementation of
`gpt-oss`).
### In `yarn_find_correction_range()`
#### `low`, `high`
Removed the `tir.floor` and `tir.ceil` calls that were applied to the
results of `yarn_find_correction_dim()`.
The YaRN paper applies no floor or ceiling when computing these values.
The `gpt-oss` implementation likewise keeps these thresholds as
floating-point values so that the ramp function interpolates smoothly.
Rounding them quantized the ramp boundaries and introduced errors in the
ramp mask.
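To illustrate the effect of rounding, here is a minimal plain-Python sketch (not TVM code). It assumes the standard YaRN correction-dimension formula, a linear ramp over the rotary pair index, and example values (`dim=64`, `theta=10000`, max position 4096, `beta_fast=32`, `beta_slow=1`). All helper names are hypothetical.

```python
import math


def correction_dim(num_rotations: float, dim: int, base: float, max_pos: int) -> float:
    """Dimension at which a rotary pair completes `num_rotations` turns over
    `max_pos` positions (standard YaRN inverse formula, assumed here)."""
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def correction_range(beta_fast, beta_slow, dim, base, max_pos, rounded):
    """Return (low, high) ramp bounds; `rounded=True` mimics the old floor/ceil."""
    low = correction_dim(beta_fast, dim, base, max_pos)
    high = correction_dim(beta_slow, dim, base, max_pos)
    if rounded:
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0), min(high, dim - 1)


def ramp(i: float, low: float, high: float) -> float:
    """Linear ramp: 0 at or below `low`, 1 at or above `high`."""
    if low == high:
        high += 0.001  # avoid division by zero, as the TVM code does
    return min(max((i - low) / (high - low), 0.0), 1.0)


exact = correction_range(32, 1, 64, 10000.0, 4096, rounded=False)
rounded = correction_range(32, 1, 64, 10000.0, 4096, rounded=True)
print(exact)    # roughly (10.47, 22.51)
print(rounded)  # (10, 23)
# The ramp value at the same pair index differs once the bounds are rounded.
print(ramp(11, *exact), ramp(11, *rounded))
```

Rounding widens the ramp by up to one dimension at each end, which shifts every ramp value in between.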
### In `rope_freq_yarn()`
#### `freq_inter`
The previous implementation computed the interpolated inverse frequency
as:
```python
freq_inter = tir.const(1, "float32") / tir.power(
    scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
)
```
This raises `scaling_factor` to the same per-dimension exponent as
`theta`, leading to non-uniform scaling across dimensions.
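Written out with illustrative notation (s = `scaling_factor`, θ = `theta`, D = `d_range`, and i = d mod D/2 the rotary pair index, so the exponent `d * 2 % d_range / d_range` equals 2i/D), the old and the intended interpolated frequencies are:

$$
\underbrace{\frac{1}{(s\,\theta)^{2i/D}}}_{\text{before}} = \frac{1}{s^{2i/D}\,\theta^{2i/D}}
\qquad\text{vs.}\qquad
\underbrace{\frac{1}{s\,\theta^{2i/D}}}_{\text{after}}
$$

The old form divides by s^{2i/D}, which is 1 at i = 0 and only approaches s for the highest pairs.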
According to the YaRN method (and the `gpt-oss` implementation), the
interpolated frequency should be the original frequency divided
uniformly by `scaling_factor`:
```python
exponent = d * 2 % d_range / tir.const(d_range, "float32")
freq_power = tir.power(theta, exponent)
freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
```
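A plain-Python sketch comparing the two formulas numerically (not TVM code; the example values `dim=64`, `theta=10000`, `s=32` and the helper names are hypothetical). It measures how much each formula actually divides the unscaled frequency by.

```python
import math


def freq_extra(i: int, dim: int, theta: float) -> float:
    """Unscaled RoPE inverse frequency of rotary pair i: theta^(-2i/dim)."""
    return 1.0 / theta ** (2 * i / dim)


def freq_inter_old(i: int, dim: int, theta: float, s: float) -> float:
    """Pre-fix formula: the scaling factor is raised to the exponent too."""
    return 1.0 / (s * theta) ** (2 * i / dim)


def freq_inter_new(i: int, dim: int, theta: float, s: float) -> float:
    """Fixed formula: the unscaled frequency divided uniformly by s."""
    return 1.0 / (s * theta ** (2 * i / dim))


dim, theta, s = 64, 10000.0, 32.0
for i in (0, 16, 31):
    base = freq_extra(i, dim, theta)
    old_scale = base / freq_inter_old(i, dim, theta, s)  # s ** (2i/dim), varies with i
    new_scale = base / freq_inter_new(i, dim, theta, s)  # s for every i
    print(i, round(old_scale, 3), round(new_scale, 3))
```

With the old formula the effective scale starts at 1 for the first pair and grows with i; with the fixed formula it is `s` everywhere.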
#### `d_range`
The call to `yarn_find_correction_range()` in `rope_freq_yarn()`
incorrectly passed the current dimension index `d` as the dimension size
used to compute the thresholds.
This caused the ramp boundaries to shift from one dimension to the next.
It now passes the total dimension size (`d_range`), so every dimension
uses the same frequency thresholds.
Before:
```
yarn_find_correction_range(..., d, ...)
```
After:
```
yarn_find_correction_range(..., d_range, ...)
```
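A plain-Python sketch of why this argument matters (not TVM code). It assumes the standard YaRN correction-dimension formula and example values (`d_range=64`, `theta=10000`, max position 4096, `beta_fast=32`, `beta_slow=1`); the helper names are hypothetical.

```python
import math


def correction_dim(num_rotations, dim, base, max_pos):
    # Standard YaRN inverse formula (assumed to match yarn_find_correction_dim)
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def correction_range(beta_fast, beta_slow, dim, base, max_pos):
    # Float bounds, clamped as in yarn_find_correction_range()
    low = correction_dim(beta_fast, dim, base, max_pos)
    high = correction_dim(beta_slow, dim, base, max_pos)
    return max(low, 0.0), min(high, dim - 1.0)


d_range, theta, max_pos = 64, 10000.0, 4096
# After: every dimension index shares the thresholds computed from d_range.
after = {correction_range(32, 1, d_range, theta, max_pos) for _ in range(d_range)}
# Before: the per-dimension index d was passed, so the thresholds move with d.
before = {correction_range(32, 1, d, theta, max_pos) for d in range(d_range)}
print(len(after), len(before))  # 1 64
```

Because the correction dimension is linear in its `dim` argument, passing `d` produces a different (low, high) pair for every dimension, whereas `d_range` yields a single shared pair.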
Thank you very much for reading.
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
python/tvm/relax/frontend/nn/llm/position_embedding.py | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 35eeb4f5f3..60808a6b35 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -197,8 +197,8 @@ def yarn_find_correction_range(
     max_position_embeddings: int,
 ):
     """Find the correction range based on the number of rotations"""
-    low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings))
-    high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings))
+    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
+    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
     return tir.max(low, 0), tir.min(high, d - 1)
@@ -214,16 +214,14 @@ def rope_freq_yarn(
     beta_slow: int,
 ): # pylint: disable=too-many-arguments, too-many-locals
     """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
-    freq_extra = tir.const(1, "float32") / tir.power(
-        theta, d * 2 % d_range / tir.const(d_range, "float32")
-    )
-    freq_inter = tir.const(1, "float32") / tir.power(
-        scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
-    )
+    exponent = d * 2 % d_range / tir.const(d_range, "float32")
+    freq_power = tir.power(theta, exponent)
+    freq_extra = tir.const(1, "float32") / freq_power
+    freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
     low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d, theta, original_max_position_embeddings
+        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
     )
     high = tir.if_then_else(low == high, high + 0.001, high)
     inv_freq_mask = tir.const(1, "float32") - tir.max(