This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fe5364992d [Python][Relax] Update Rotary positional embedding scaling (#17506)
fe5364992d is described below
commit fe5364992d29900e91b26dbfdb5cf76a3d66f09c
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Nov 15 23:03:36 2024 +0800
[Python][Relax] Update Rotary positional embedding scaling (#17506)
This PR introduces two more styles of RoPE scaling:
the GPT-J ("gptj") style and the YaRN ("yarn") style.
---
.../relax/frontend/nn/llm/position_embedding.py | 113 +++++++++++++++++++--
1 file changed, 103 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 4373395e32..f5a2831382 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -66,6 +66,15 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype:
return cos_freq, sin_freq, {freq_var: freq}
+def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype:
str):
+ """Compute the inverse frequency of RoPE for gptj RoPE scaling."""
+ freq = s / tir.power(theta, 2 * (d // 2) % d_range / tir.const(d_range,
"float32"))
+ freq_var = tir.Var("freq", "float32")
+ cos_freq = tir.cos(freq_var).astype(dtype)
+ sin_freq = tir.sin(freq_var).astype(dtype)
+ return cos_freq, sin_freq, {freq_var: freq}
+
+
def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
@@ -123,12 +132,74 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
return cos_freq, sin_freq, {freq_var: freq}
+def yarn_find_correction_dim(
+ num_rotations: int,
+ d: tir.Var,
+ theta: float,
+ max_position_embeddings: int,
+):
+ """Inverse dim formula to find dim based on number of rotations"""
+ return (d * math.log(max_position_embeddings / (num_rotations * 2 *
math.pi))) / (
+ 2 * math.log(theta)
+ )
+
+
+def yarn_find_correction_range(
+ low_rot: int,
+ high_rot: int,
+ d: tir.Var,
+ theta: float,
+ max_position_embeddings: int,
+):
+ """Find the correction range based on the number of rotations"""
+ low = tir.floor(yarn_find_correction_dim(low_rot, d, theta,
max_position_embeddings))
+ high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta,
max_position_embeddings))
+ return tir.max(low, 0), tir.min(high, d - 1)
+
+
+def rope_freq_yarn(
+ s: tir.Var,
+ d: tir.Var,
+ d_range: int,
+ theta: float,
+ dtype: str,
+ original_max_position_embeddings: int,
+ scaling_factor: float,
+ beta_fast: int,
+ beta_slow: int,
+): # pylint: disable=too-many-arguments, too-many-locals
+ """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
+ freq_extra = tir.const(1, "float32") / tir.power(
+ theta, d * 2 % d_range / tir.const(d_range, "float32")
+ )
+
+ freq_inter = tir.const(1, "float32") / tir.power(
+ scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
+ )
+
+ low, high = yarn_find_correction_range(
+ beta_fast, beta_slow, d, theta, original_max_position_embeddings
+ )
+ high = tir.if_then_else(low == high, high + 0.001, high)
+ inv_freq_mask = tir.const(1, "float32") - tir.max(
+ tir.min((d - low) / (high - low), 1.0), 0.0
+ ).astype("float32")
+ inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+ freq = s * inv_freq
+ freq_var = tir.Var("freq", "float32")
+ cos_freq = tir.cos(freq_var).astype(dtype)
+ sin_freq = tir.sin(freq_var).astype(dtype)
+ return cos_freq, sin_freq, {freq_var: freq}
+
+
def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
"""Return the RoPE inverse frequency computation function based
on the given RoPE scaling.
"""
if "rope_type" not in rope_scaling:
return rope_freq_default
+ if rope_scaling["rope_type"] == "gptj":
+ return rope_freq_gptj
if rope_scaling["rope_type"] == "llama3":
return partial(
rope_freq_llama3,
@@ -143,6 +214,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
max_position_embeddings=rope_scaling["max_position_embeddings"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
+ if rope_scaling["rope_type"] == "yarn":
+ return partial(
+ rope_freq_yarn,
+
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+ scaling_factor=rope_scaling["factor"],
+ beta_fast=rope_scaling["beta_fast"],
+ beta_slow=rope_scaling["beta_slow"],
+ )
raise ValueError(f'Unsupported RoPE scaling type:
{rope_scaling["rope_type"]}')
@@ -220,11 +299,18 @@ def llama_rope( # pylint: disable=too-many-arguments
(s + offset) * scale, d, rotary_dim, theta, dtype
)
cos = cos_freq * x[b, s, h, d]
- sin = sin_freq * tir.if_then_else(
- d < rotary_dim // 2,
- -x[b, s, h, d + rotary_dim // 2],
- x[b, s, h, d - rotary_dim // 2],
- )
+ if rope_scaling["rope_type"] == "gptj":
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[b, s, h, d + 1],
+ x[b, s, h, d - 1],
+ )
+ else:
+ sin = sin_freq * tir.if_then_else(
+ d < rotary_dim // 2,
+ -x[b, s, h, d + rotary_dim // 2],
+ x[b, s, h, d - rotary_dim // 2],
+ )
expr = cos + sin
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
@@ -341,11 +427,18 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
pos * scale, d, rotary_dim, theta, "float32", **kwargs
)
cos = cos_freq * x[s, h, d].astype("float32")
- sin = sin_freq * tir.if_then_else(
- d < rotary_dim // 2,
- -x[s, h, d + rotary_dim // 2],
- x[s, h, d - rotary_dim // 2],
- ).astype("float32")
+ if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+ sin = sin_freq * tir.if_then_else(
+ d % 2 == 0,
+ -x[s, h, d + 1],
+ x[s, h, d - 1],
+ ).astype("float32")
+ else:
+ sin = sin_freq * tir.if_then_else(
+ d < rotary_dim // 2,
+ -x[s, h, d + rotary_dim // 2],
+ x[s, h, d - rotary_dim // 2],
+ ).astype("float32")
expr = (cos + sin).astype(dtype)
for var, value in var_map.items():
expr = tir.Let(var, value, expr)