This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1bd40faede [Relax] Fix llama4_rope_with_position_map to support
partial rotary factor (#18520)
1bd40faede is described below
commit 1bd40faede42c8d4100c40c35edf02df1e756222
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Feb 13 02:46:24 2026 +0800
[Relax] Fix llama4_rope_with_position_map to support partial rotary factor
(#18520)
## Related Issue
closes #17715
## Why
- Phi-4 uses partial_rotary_factor = 0.75 (rotary_dim = 96) together with
longrope scaling
- Longrope requires both long_factors and short_factors to be packed into one
buffer
- Expected buffer size: (rotary_dim,) = (96,) total
- First half [0:48] = long_factors
- Second half [48:96] = short_factors
- llama4_rope_with_position_map still declared the old size (rotary_dim // 2,) =
(48,), which holds only one of the two factor sets
---
.../relax/frontend/nn/llm/position_embedding.py | 129 +++++++++++++++------
1 file changed, 93 insertions(+), 36 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py
b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index ee2a356299..e2a7801add 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -117,7 +117,11 @@ def rope_freq_llama4( # pylint:
disable=too-many-arguments,too-many-locals
smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
- return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq,
orig_freq_var: orig_freq}
+ return (
+ cos_freq,
+ sin_freq,
+ {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
+ )
def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
@@ -147,7 +151,11 @@ def rope_freq_llama3( # pylint:
disable=too-many-arguments,too-many-locals
smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
- return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq,
orig_freq_var: orig_freq}
+ return (
+ cos_freq,
+ sin_freq,
+ {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
+ )
def rope_freq_longrope( # pylint: disable=too-many-arguments
@@ -285,7 +293,7 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) ->
Callable:
beta_slow=rope_scaling["beta_slow"],
inv_theta_log_scale=inv_theta_log_scale,
)
- raise ValueError(f'Unsupported RoPE scaling type:
{rope_scaling["rope_type"]}')
+ raise ValueError(f"Unsupported RoPE scaling type:
{rope_scaling['rope_type']}")
# mypy: disable-error-code="attr-defined"
@@ -580,7 +588,10 @@ def llama_rope_with_position_map( # pylint:
disable=too-many-arguments
# long factors is the first half, short factors is the second half
long_factors = T.Buffer((rotary_dim // 2,), "float32",
data=ext_factors.data)
short_factors = T.Buffer(
- (rotary_dim // 2,), "float32", data=ext_factors.data,
elem_offset=(rotary_dim // 2)
+ (rotary_dim // 2,),
+ "float32",
+ data=ext_factors.data,
+ elem_offset=(rotary_dim // 2),
)
if seq_len > original_max_position_embeddings:
@@ -697,6 +708,10 @@ def llama4_rope_with_position_map( # pylint:
disable=too-many-arguments
rotary_dim = head_dim
scale = tir.const(scale, "float32")
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+ if is_longrope_scaling and "original_max_position_embeddings" in
rope_scaling:
+ original_max_position_embeddings =
rope_scaling["original_max_position_embeddings"]
+ else:
+ original_max_position_embeddings = 0
def _rope( # pylint: disable=too-many-arguments
x: T.Buffer,
@@ -780,7 +795,7 @@ def llama4_rope_with_position_map( # pylint:
disable=too-many-arguments
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
- ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
+ ext_factors: T.Buffer((rotary_dim,), "float32"), # type: ignore
):
T.func_attr(
{
@@ -797,37 +812,79 @@ def llama4_rope_with_position_map( # pylint:
disable=too-many-arguments
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32",
elem_offset=position_map_elem_offset
)
- for iters in T.grid(seq_len, fused_heads, head_dim):
- with T.sblock("llama_fused_rope"):
- s, h, d = T.axis.remap("SSS", iters)
- if h < num_q_heads:
- q[s, h, d] = T.if_then_else(
- d < rotary_dim,
- _rope(
- qkv,
- s,
- h,
- d,
- position_map[s],
- ext_factors if is_longrope_scaling else None,
- ),
- qkv[s, h, d],
- )
- elif h < num_q_heads + num_kv_heads:
- k[s, h - num_q_heads, d] = T.if_then_else(
- d < rotary_dim,
- _rope(
- qkv,
- s,
- h,
- d,
- position_map[s],
- ext_factors if is_longrope_scaling else None,
- ),
- qkv[s, h, d],
- )
- else:
- v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+ # long factors is the first half, short factors is the second half
+ long_factors = T.Buffer((rotary_dim // 2,), "float32",
data=ext_factors.data)
+ short_factors = T.Buffer(
+ (rotary_dim // 2,),
+ "float32",
+ data=ext_factors.data,
+ elem_offset=(rotary_dim // 2),
+ )
+
+ if seq_len > original_max_position_embeddings:
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.sblock("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ long_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ long_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h,
d]
+ else:
+ for iters in T.grid(seq_len, fused_heads, head_dim):
+ with T.sblock("llama_fused_rope"):
+ s, h, d = T.axis.remap("SSS", iters)
+ if h < num_q_heads:
+ q[s, h, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ short_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ elif h < num_q_heads + num_kv_heads:
+ k[s, h - num_q_heads, d] = T.if_then_else(
+ d < rotary_dim,
+ _rope(
+ qkv,
+ s,
+ h,
+ d,
+ position_map[s],
+ short_factors if is_longrope_scaling else None,
+ ),
+ qkv[s, h, d],
+ )
+ else:
+ v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h,
d]
if is_longrope_scaling:
return fused_rope_longrope_scaling