Liberatedwinner opened a new pull request, #18576:
URL: https://github.com/apache/tvm/pull/18576

   ## Summary
   Fixed the frequency calculation for RoPE (YaRN) scaling and the correction-range finding.
   
   ## Description
   Greetings:
   
   This PR corrects the mathematical formulation of [YaRN](https://arxiv.org/abs/2309.00071) RoPE scaling.
   I have verified that this change eliminates the discrepancy observed when comparing against a PyTorch baseline (an implementation of `gpt-oss`).
   
   ### In `yarn_find_correction_range()`
   #### `low`, `high`
   Removed the `tir.floor` and `tir.ceil` operations around `yarn_find_correction_dim()`.
   The YaRN paper does not apply a floor or ceiling when computing these values.
   The `gpt-oss` implementation keeps these thresholds as floating-point values so that the ramp function interpolates smoothly.
   Rounding them introduced quantization errors in the ramp mask.
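A minimal pure-Python sketch of the threshold computation, assuming the standard YaRN correction-dimension formula. The function names, the `rounded` flag, and the example values (`dim=64`, `base=150000`, `max_pos=4096`, `beta_fast=32`, `beta_slow=1`) are hypothetical and chosen only for illustration; this is not the TVM code. It contrasts float thresholds with the old floor/ceil variant and shows how the ramp mask differs for one pair index.

```python
import math


def find_correction_dim(num_rotations: float, dim: int, base: float, max_pos: int) -> float:
    """Pair index whose rotary pair completes `num_rotations` turns over `max_pos` tokens."""
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(beta_fast, beta_slow, dim, base, max_pos, rounded=False):
    """Return (low, high) ramp thresholds; rounded=True mimics the old floor/ceil behaviour."""
    low = find_correction_dim(beta_fast, dim, base, max_pos)
    high = find_correction_dim(beta_slow, dim, base, max_pos)
    if rounded:
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0), min(high, dim - 1)


def extrapolation_mask(i: int, low: float, high: float) -> float:
    """1.0 keeps the original frequency, 0.0 uses the interpolated one, linear in between."""
    if high == low:
        high += 0.001  # avoid division by zero
    ramp = (i - low) / (high - low)
    return 1.0 - min(max(ramp, 0.0), 1.0)


low_f, high_f = find_correction_range(32, 1, 64, 150000.0, 4096)
low_r, high_r = find_correction_range(32, 1, 64, 150000.0, 4096, rounded=True)
print(f"float thresholds:   low={low_f:.4f}, high={high_f:.4f}")
print(f"rounded thresholds: low={low_r}, high={high_r}")
print(f"mask at i=10: float={extrapolation_mask(10, low_f, high_f):.4f}, "
      f"rounded={extrapolation_mask(10, low_r, high_r):.4f}")
```

With these example values the float thresholds fall between integers, so rounding widens the ramp to integer boundaries and changes the mask for every pair inside it.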
   
   ### In `rope_freq_yarn()`
   #### `freq_inter`
   Currently, the implementation calculates the inverse frequency as:
   ```
   freq_inter = tir.const(1, "float32") / tir.power(
       scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
   )
   ```
   
   This places `scaling_factor` inside the power, so it is raised to the same exponent as `theta`. The effective scaling therefore varies across dimensions instead of being uniform.
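Writing $e_i = 2i/D$ for rotary pair $i$ of a $D$-dimensional head (this matches the code's `d * 2 % d_range / d_range` when $i$ is the pair index), $s$ for `scaling_factor`, and $\theta$ for `theta` (notation introduced here), the two formulations expand as:

```math
f_i^{\text{old}} = \frac{1}{(s\,\theta)^{e_i}} = s^{-e_i}\,\theta^{-e_i},
\qquad
f_i^{\text{new}} = \frac{1}{s\,\theta^{e_i}} = s^{-1}\,\theta^{-e_i},
\qquad
0 \le e_i < 1.
```

Relative to the unscaled frequency $\theta^{-e_i}$, the old form divides by $s^{e_i}$, which grows from $1$ at $i = 0$ toward $s$ for the last pair, whereas the new form divides every pair by exactly $s$.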
   
   According to the YaRN method (and the `gpt-oss` implementation), the scaling factor should be applied as a single multiplicative factor outside the power:
   ```
   exponent = d * 2 % d_range / tir.const(d_range, "float32")
   freq_power = tir.power(theta, exponent)
   freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
   ```
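A small numerical check of the difference in plain Python, with hypothetical example values (`DIM=64`, `THETA=150000`, `SCALE=32`) that are not taken from any particular config. `effective_scale` is a helper introduced here; it measures how much each formula shrinks a frequency relative to unscaled RoPE.

```python
DIM, THETA, SCALE = 64, 150000.0, 32.0  # example values only


def inv_freq_old(i: int) -> float:
    """Previous form: scaling_factor is raised to the exponent together with theta."""
    return 1.0 / (SCALE * THETA) ** (2 * i / DIM)


def inv_freq_new(i: int) -> float:
    """Corrected form: scaling_factor divides the frequency once, outside the power."""
    return 1.0 / (SCALE * THETA ** (2 * i / DIM))


def effective_scale(i: int, inv_freq) -> float:
    """Ratio of the unscaled RoPE frequency to the scaled one for pair index i."""
    unscaled = 1.0 / THETA ** (2 * i / DIM)
    return unscaled / inv_freq(i)


for i in (0, 8, 16, 31):
    print(f"pair {i:2d}: old scale={effective_scale(i, inv_freq_old):7.3f}, "
          f"new scale={effective_scale(i, inv_freq_new):7.3f}")
```

The old form's scale starts at 1 for the first pair and approaches `scaling_factor` only for the last pairs, while the corrected form applies `scaling_factor` uniformly.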
   
   #### `d_range`
   `yarn_find_correction_range()` was incorrectly called with the current dimension index `d` when computing the thresholds.
   As a result, the ramp boundaries shifted from one dimension to the next.
   It now receives the total dimension size (`d_range`), so every dimension uses the same frequency thresholds.
   
   Before: 
   ```
   yarn_find_correction_range(..., d, ...)
   ```
   
   After: 
   ```
   yarn_find_correction_range(..., d_range, ...)
   ```
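To see why the argument matters, this plain-Python sketch evaluates the correction thresholds both ways, assuming the standard YaRN correction-dimension formula. `find_correction_dim`, `thresholds`, and the example values (`D_RANGE=64`, `THETA=150000`, `MAX_POS=4096`, `BETA_FAST=32`, `BETA_SLOW=1`) are hypothetical stand-ins, not TVM code.

```python
import math

D_RANGE, THETA, MAX_POS = 64, 150000.0, 4096  # example values only
BETA_FAST, BETA_SLOW = 32, 1


def find_correction_dim(num_rotations: float, dim: int) -> float:
    """Standard YaRN inverse formula: pair index completing `num_rotations` turns in MAX_POS tokens."""
    return dim * math.log(MAX_POS / (num_rotations * 2 * math.pi)) / (2 * math.log(THETA))


def thresholds(dim: int) -> tuple:
    """(low, high) ramp boundaries when `dim` is passed as the dimension size."""
    return find_correction_dim(BETA_FAST, dim), find_correction_dim(BETA_SLOW, dim)


# Before: the element index d is passed, so every d gets different boundaries.
before = {d: thresholds(d) for d in (8, 32, 63)}
# After: d_range is passed, so all elements share one pair of boundaries.
after = thresholds(D_RANGE)

for d, (low, high) in before.items():
    print(f"before, d={d:2d}: low={low:6.3f}, high={high:6.3f}")
print(f"after (any d):  low={after[0]:6.3f}, high={after[1]:6.3f}")
```

Because the thresholds scale linearly with the size argument, passing `d` makes the ramp boundaries grow with the element index, whereas passing `d_range` fixes them for the whole head.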
   
   Thank you very much for reading.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
