This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4747a92827 [Python][Relax] Fix YaRN correction dim calculation (#18661)
4747a92827 is described below
commit 4747a9282728614bf751f9c96b2c91a41a3abb1d
Author: Akaash Parthasarathy <[email protected]>
AuthorDate: Fri Jan 30 16:07:40 2026 -0500
[Python][Relax] Fix YaRN correction dim calculation (#18661)
Precompute `inv_theta_log_scale`
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 19 ++++++++++++++
.../relax/frontend/nn/llm/position_embedding.py | 30 +++++++++++++++-------
2 files changed, 40 insertions(+), 9 deletions(-)
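The fix is an algebraic rewrite of the YaRN correction-dim formula: instead of dividing by 2 * log(theta) inside the per-dimension expression, the new `_prepare_yarn_rope_scaling` helper precomputes `inv_theta_log_scale = 1.0 / (2 * math.log(theta))` once and the formula multiplies by it. A minimal sketch of the equivalence in plain Python; the theta, dimension, and rotation values below are made up for illustration only:

```python
import math

# Hypothetical values, not taken from the commit.
theta = 10000.0
max_position_embeddings = 4096
num_rotations = 32   # e.g. beta_fast
d = 128              # rotary dimension

# Before: divide by 2 * log(theta) inside the correction-dim expression.
before = (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
    2 * math.log(theta)
)

# After: precompute the reciprocal once, then multiply.
inv_theta_log_scale = 1.0 / (2 * math.log(theta))
after = d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale

assert math.isclose(before, after)
```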
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 6b6029630d..d2df939c3d 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -297,6 +297,23 @@ class PagedKVCache(Object): # pylint: disable=too-few-public-methods
# pylint: enable=protected-access
+def _prepare_yarn_rope_scaling(
+    rope_scaling: Optional[Dict[str, Any]],
+    rope_theta: Optional[float],
+) -> Optional[Dict[str, Any]]:
+    """Ensure Yarn-specific scaling configs include the theta metadata."""
+    if rope_scaling is None:
+        return None
+    if rope_scaling.get("rope_type") != "yarn":
+        return rope_scaling
+
+    rope_scaling_updated = dict(rope_scaling)
+    if "inv_theta_log_scale" not in rope_scaling_updated and rope_theta is not None:
+        theta_value = float(rope_theta)
+        rope_scaling_updated["inv_theta_log_scale"] = 1.0 / (2 * math.log(theta_value))
+    return rope_scaling_updated
+
+
class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
"""Paged KV cache using FlashInfer (CUDA) kernels."""
@@ -372,6 +389,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me
Whether to enable disaggregation in the KV cache.
"""
assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode."
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
@@ -561,6 +579,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods
target : Target
The target to build the model to.
"""
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
attn_kind_single = "mha"
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index b90b4bfecd..ee2a356299 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -19,7 +19,7 @@
import math
from functools import partial
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, Union
from tvm import tir
from tvm.relax.frontend.nn import Tensor, op
@@ -180,12 +180,12 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
def yarn_find_correction_dim(
num_rotations: int,
d: tir.Var,
-    theta: float,
max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
"""Inverse dim formula to find dim based on number of rotations"""
-    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(theta)
+    return (
+        d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale
)
@@ -193,12 +193,16 @@ def yarn_find_correction_range(
low_rot: int,
high_rot: int,
d: tir.Var,
-    theta: float,
max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
"""Find the correction range based on the number of rotations"""
-    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
-    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
+    low = yarn_find_correction_dim(
+        low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
+    high = yarn_find_correction_dim(
+        high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
return tir.max(low, 0), tir.min(high, d - 1)
@@ -206,12 +210,13 @@ def rope_freq_yarn(
s: tir.Var,
d: tir.Var,
d_range: int,
-    theta: float,
+    theta: Union[float, tir.PrimExpr],
dtype: str,
original_max_position_embeddings: int,
scaling_factor: float,
beta_fast: int,
beta_slow: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
): # pylint: disable=too-many-arguments, too-many-locals
"""Compute the inverse frequency of RoPE for yarn RoPE scaling."""
@@ -221,7 +226,11 @@ def rope_freq_yarn(
freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
+        beta_fast,
+        beta_slow,
+        d_range,
+        original_max_position_embeddings,
+        inv_theta_log_scale=inv_theta_log_scale,
)
high = tir.if_then_else(low == high, high + 0.001, high)
inv_freq_mask = tir.const(1, "float32") - tir.max(
@@ -266,12 +275,15 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "yarn":
+        inv_theta_log_scale = rope_scaling.get("inv_theta_log_scale")
+        assert inv_theta_log_scale is not None, "inv_theta_log_scale must be precomputed for YaRN"
return partial(
rope_freq_yarn,
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
scaling_factor=rope_scaling["factor"],
beta_fast=rope_scaling["beta_fast"],
beta_slow=rope_scaling["beta_slow"],
+            inv_theta_log_scale=inv_theta_log_scale,
)
raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
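For context, a rough usage sketch of the new helper as the KV cache constructors use it. The config values below are hypothetical, and importing the private helper directly is done here only for illustration; it assumes a TVM build that includes this patch:

```python
import math

from tvm.relax.frontend.nn.llm.kv_cache import _prepare_yarn_rope_scaling

# Hypothetical YaRN scaling config; keys mirror the ones the patch reads.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "beta_fast": 32,
    "beta_slow": 1,
    "original_max_position_embeddings": 4096,
}

updated = _prepare_yarn_rope_scaling(rope_scaling, rope_theta=10000.0)

# The returned copy carries the precomputed scale that switch_rope_freq_func
# asserts on before building the YaRN partial.
assert math.isclose(updated["inv_theta_log_scale"], 1.0 / (2 * math.log(10000.0)))
```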