This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b555a964f Adjusted Longrope embedding function to match Huggingface Implementation (#18422)
4b555a964f is described below
commit 4b555a964f39b519eac13d6350cab30f88466fb3
Author: Sidharth N. Babu <[email protected]>
AuthorDate: Wed Nov 12 13:31:53 2025 -0500
Adjusted Longrope embedding function to match Huggingface Implementation (#18422)
This updated implementation of longrope takes both `long_factors` and
`short_factors` into account; these are per-dimension scaling factors
provided via HF configs for MSFT's Phi3+ models. In the canonical HF
implementation of longrope, once the sequence length exceeds a
pre-configured threshold (`original_max_position_embeddings`), a
different set of `ext_factors` must be used. This patch enables that by
packing both sets of scaling factors into one argument and selecting
which set to use dynamically within the returned `prim_func`.
The HF implementation of this can be found here:
https://github.com/huggingface/transformers/blob/7b325cd573e40bbb12951b8446176c96e8b1afaa/src/transformers/modeling_rope_utils.py#L521
The link above points directly to the switching logic between long
and short factors, which has been replicated in this PR.
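For illustration, a minimal sketch (not part of this commit; the config
values are hypothetical, following the Phi-3-style HF `rope_scaling`
convention) of how a caller might pack the two factor lists into the
single flat buffer the new `prim_func` expects, and of the switching
rule it applies:

```python
import numpy as np

# Hypothetical Phi-3-style rope_scaling config; HF configs carry
# per-dimension "long_factor"/"short_factor" lists, each of length
# rotary_dim // 2.
rope_scaling = {
    "rope_type": "longrope",
    "original_max_position_embeddings": 4096,
    "long_factor": [1.03, 1.05, 1.07, 1.09],
    "short_factor": [1.05, 1.05, 1.05, 1.05],
}

# Pack both sets into one float32 buffer of length rotary_dim: the
# prim_func views [0, rotary_dim // 2) as long_factors and the rest as
# short_factors via elem_offset.
ext_factors = np.concatenate([
    np.asarray(rope_scaling["long_factor"], dtype="float32"),
    np.asarray(rope_scaling["short_factor"], dtype="float32"),
])

def selected_factors(seq_len: int) -> np.ndarray:
    # Mirrors the branch inside the generated prim_func: use the long
    # factors once the sequence length exceeds the threshold.
    half = ext_factors.shape[0] // 2
    if seq_len > rope_scaling["original_max_position_embeddings"]:
        return ext_factors[:half]
    return ext_factors[half:]
```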
---
.../relax/frontend/nn/llm/position_embedding.py | 107 +++++++++++++++------
1 file changed, 75 insertions(+), 32 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 6fda4b0bca..35eeb4f5f3 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -464,6 +464,10 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
         rotary_dim = head_dim
     scale = tir.const(scale, "float32")
     is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+    if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
+        original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    else:
+        original_max_position_embeddings = 0

    def _rope(  # pylint: disable=too-many-arguments
        x: T.Buffer,
@@ -546,7 +550,7 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
        var_q: T.handle,
        var_k: T.handle,
        var_v: T.handle,
-        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim,), "float32"),  # type: ignore
    ):
        T.func_attr(
            {
@@ -563,37 +567,76 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
        position_map = T.match_buffer(
            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
        )
-        for iters in T.grid(seq_len, fused_heads, head_dim):
-            with T.block("llama_fused_rope"):
-                s, h, d = T.axis.remap("SSS", iters)
-                if h < num_q_heads:
-                    q[s, h, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                elif h < num_q_heads + num_kv_heads:
-                    k[s, h - num_q_heads, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                else:
-                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        # long factors is the first half, short factors is the second half
+        long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
+        short_factors = T.Buffer(
+            (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
+        )
+
+        if seq_len > original_max_position_embeddings:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        else:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
    if is_longrope_scaling:
        return fused_rope_longrope_scaling