This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fe5364992d [Python][Relax] Update Rotary positional embedding scaling (#17506)
fe5364992d is described below

commit fe5364992d29900e91b26dbfdb5cf76a3d66f09c
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Nov 15 23:03:36 2024 +0800

    [Python][Relax] Update Rotary positional embedding scaling (#17506)
    
    This PR introduces two more styles of RoPE scaling:
    the gptj style and the yarn scale.
---
 .../relax/frontend/nn/llm/position_embedding.py    | 113 +++++++++++++++++++--
 1 file changed, 103 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index 4373395e32..f5a2831382 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -66,6 +66,15 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype:
     return cos_freq, sin_freq, {freq_var: freq}
 
 
+def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: 
str):
+    """Compute the inverse frequency of RoPE for gptj RoPE scaling.
+
+    The angle is ``s / theta ** ((2 * (d // 2) % d_range) / d_range)``.  Because
+    of the ``d // 2`` term, indices ``2k`` and ``2k + 1`` share one angle.
+
+    Returns
+    -------
+    A tuple ``(cos_freq, sin_freq, var_map)``: the cosine and sine of a
+    placeholder variable ``freq`` cast to ``dtype``, and a dict mapping that
+    placeholder to the angle expression.
+    """
+    # `*`, `%` and `/` associate left to right, so the exponent is
+    # ((2 * (d // 2)) % d_range) / d_range.
+    freq = s / tir.power(theta, 2 * (d // 2) % d_range / tir.const(d_range, 
"float32"))
+    # cos/sin are built on a placeholder variable rather than on `freq` itself;
+    # the callers shown in this patch bind the placeholder with tir.Let.
+    freq_var = tir.Var("freq", "float32")
+    cos_freq = tir.cos(freq_var).astype(dtype)
+    sin_freq = tir.sin(freq_var).astype(dtype)
+    return cos_freq, sin_freq, {freq_var: freq}
+
+
 def rope_freq_llama3(  # pylint: disable=too-many-arguments,too-many-locals
     s: tir.Var,
     d: tir.Var,
@@ -123,12 +132,74 @@ def rope_freq_longrope(  # pylint: disable=too-many-arguments
     return cos_freq, sin_freq, {freq_var: freq}
 
 
+def yarn_find_correction_dim(
+    num_rotations: int,
+    d: tir.Var,
+    theta: float,
+    max_position_embeddings: int,
+):
+    """Inverse dim formula to find dim based on number of rotations
+
+    Evaluates
+    ``d * ln(max_position_embeddings / (num_rotations * 2 * pi)) / (2 * ln(theta))``.
+    Both logarithms are taken with Python's ``math.log``, so only the factor
+    ``d`` remains symbolic in the returned expression.
+
+    NOTE(review): the only caller in this patch, ``rope_freq_yarn``, passes its
+    ``d`` argument here rather than ``d_range``.  The reference YaRN formula
+    appears to use the full rotary dimension for this factor -- confirm that
+    using ``d`` is intended.
+    """
+    return (d * math.log(max_position_embeddings / (num_rotations * 2 * 
math.pi))) / (
+        2 * math.log(theta)
+    )
+
+
+def yarn_find_correction_range(
+    low_rot: int,
+    high_rot: int,
+    d: tir.Var,
+    theta: float,
+    max_position_embeddings: int,
+):
+    """Find the correction range based on the number of rotations
+
+    Returns a pair ``(low, high)``: ``low`` is the floor of the correction dim
+    for ``low_rot`` clamped to be at least 0, and ``high`` is the ceiling of the
+    correction dim for ``high_rot`` clamped to be at most ``d - 1``.
+    """
+    # Round outwards (floor for the lower bound, ceil for the upper bound).
+    low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, 
max_position_embeddings))
+    high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, 
max_position_embeddings))
+    # Clamp the pair into [0, d - 1].
+    return tir.max(low, 0), tir.min(high, d - 1)
+
+
+def rope_freq_yarn(
+    s: tir.Var,
+    d: tir.Var,
+    d_range: int,
+    theta: float,
+    dtype: str,
+    original_max_position_embeddings: int,
+    scaling_factor: float,
+    beta_fast: int,
+    beta_slow: int,
+):  # pylint: disable=too-many-arguments, too-many-locals
+    """Compute the inverse frequency of RoPE for yarn RoPE scaling.
+
+    Two inverse frequencies are computed with the same exponent
+    ``(d * 2 % d_range) / d_range``: ``freq_extra`` with base ``theta`` and
+    ``freq_inter`` with base ``scaling_factor * theta``.  They are blended by a
+    mask that ramps linearly over the correction range ``[low, high]`` and is
+    clamped to ``[0, 1]``.  The blended inverse frequency is multiplied by ``s``
+    to give the angle.
+
+    Returns
+    -------
+    A tuple ``(cos_freq, sin_freq, var_map)``: the cosine and sine of a
+    placeholder variable ``freq`` cast to ``dtype``, and a dict mapping that
+    placeholder to the angle expression.
+    """
+    # Inverse frequency with the unmodified base `theta`.
+    freq_extra = tir.const(1, "float32") / tir.power(
+        theta, d * 2 % d_range / tir.const(d_range, "float32")
+    )
+
+    # Inverse frequency with the base multiplied by `scaling_factor`.
+    # NOTE(review): this is 1 / (scaling_factor * theta) ** e; the reference
+    # YaRN interpolation appears to be 1 / (scaling_factor * theta ** e) --
+    # confirm which form is intended.
+    freq_inter = tir.const(1, "float32") / tir.power(
+        scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
+    )
+
+    low, high = yarn_find_correction_range(
+        beta_fast, beta_slow, d, theta, original_max_position_embeddings
+    )
+    # Nudge `high` when the range is empty so that `high - low` below is nonzero.
+    high = tir.if_then_else(low == high, high + 0.001, high)
+    # mask = 1 - clamp((d - low) / (high - low), 0, 1)
+    inv_freq_mask = tir.const(1, "float32") - tir.max(
+        tir.min((d - low) / (high - low), 1.0), 0.0
+    ).astype("float32")
+    # mask == 1 selects freq_extra, mask == 0 selects freq_inter.
+    inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+    freq = s * inv_freq
+    # cos/sin are built on a placeholder variable rather than on `freq` itself;
+    # the callers shown in this patch bind the placeholder with tir.Let.
+    freq_var = tir.Var("freq", "float32")
+    cos_freq = tir.cos(freq_var).astype(dtype)
+    sin_freq = tir.sin(freq_var).astype(dtype)
+    return cos_freq, sin_freq, {freq_var: freq}
+
+
 def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
     """Return the RoPE inverse frequency computation function based
     on the given RoPE scaling.
     """
     if "rope_type" not in rope_scaling:
         return rope_freq_default
+    if rope_scaling["rope_type"] == "gptj":
+        return rope_freq_gptj
     if rope_scaling["rope_type"] == "llama3":
         return partial(
             rope_freq_llama3,
@@ -143,6 +214,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             max_position_embeddings=rope_scaling["max_position_embeddings"],
             
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
         )
+    if rope_scaling["rope_type"] == "yarn":
+        return partial(
+            rope_freq_yarn,
+            
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+            scaling_factor=rope_scaling["factor"],
+            beta_fast=rope_scaling["beta_fast"],
+            beta_slow=rope_scaling["beta_slow"],
+        )
     raise ValueError(f'Unsupported RoPE scaling type: 
{rope_scaling["rope_type"]}')
 
 
@@ -220,11 +299,18 @@ def llama_rope(  # pylint: disable=too-many-arguments
             (s + offset) * scale, d, rotary_dim, theta, dtype
         )
         cos = cos_freq * x[b, s, h, d]
-        sin = sin_freq * tir.if_then_else(
-            d < rotary_dim // 2,
-            -x[b, s, h, d + rotary_dim // 2],
-            x[b, s, h, d - rotary_dim // 2],
-        )
+        if rope_scaling["rope_type"] == "gptj":
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[b, s, h, d + 1],
+                x[b, s, h, d - 1],
+            )
+        else:
+            sin = sin_freq * tir.if_then_else(
+                d < rotary_dim // 2,
+                -x[b, s, h, d + rotary_dim // 2],
+                x[b, s, h, d - rotary_dim // 2],
+            )
         expr = cos + sin
         for var, value in var_map.items():
             expr = tir.Let(var, value, expr)
@@ -341,11 +427,18 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
             pos * scale, d, rotary_dim, theta, "float32", **kwargs
         )
         cos = cos_freq * x[s, h, d].astype("float32")
-        sin = sin_freq * tir.if_then_else(
-            d < rotary_dim // 2,
-            -x[s, h, d + rotary_dim // 2],
-            x[s, h, d - rotary_dim // 2],
-        ).astype("float32")
+        if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        else:
+            sin = sin_freq * tir.if_then_else(
+                d < rotary_dim // 2,
+                -x[s, h, d + rotary_dim // 2],
+                x[s, h, d - rotary_dim // 2],
+            ).astype("float32")
         expr = (cos + sin).astype(dtype)
         for var, value in var_map.items():
             expr = tir.Let(var, value, expr)

Reply via email to