gemini-code-assist[bot] commented on code in PR #18335:
URL: https://github.com/apache/tvm/pull/18335#discussion_r2370996074
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
    return cos_freq, sin_freq, {freq_var: freq}
+def rope_freq_llama4(  # pylint: disable=too-many-arguments,too-many-locals
+    s: tir.Var,
+    d: tir.Var,
+    d_range: int,
+    theta: float,
+    dtype: str,
+    factor: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    original_max_position_embeddings: float,
+):
+    """Compute the inverse frequency of RoPE for llama3 RoPE scaling."""
+    orig_freq = tir.const(1, "float32") / tir.power(
+        theta, 2 * (d // 2) / tir.const(d_range, "float32")
+    )
+    orig_freq_var = tir.Var("orig_freq", "float32")
+
+    llama3_inv_scaling_factor = 1.0 / factor
+
+    if high_freq_factor == low_freq_factor:
+        wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
+        threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")
+
+        scaled_freq = tir.if_then_else(
+            wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
+        )
+        smoothed_freq = s * scaled_freq
+
+    else:
+        # Original smooth interpolation logic
+        inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
+
+        llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
+        llama3_beta = low_freq_factor * inv_diff_freq_factor
+        smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta))
+        smoothed_freq = s * (
+            (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var
Review Comment:

The docstring and some variable names in `rope_freq_llama4` seem to be
copied from `rope_freq_llama3`. For consistency and to avoid confusion, they
should be updated to refer to `llama4`.
- The docstring at line 89 refers to "llama3 RoPE scaling".
- Variables `llama3_inv_scaling_factor`, `llama3_alpha`, and `llama3_beta`
are used in lines 95, 110, 111, 112, and 114.
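For concreteness, the rename could look like the sketch below. Only the renamed lines are shown; the elided logic is unchanged from the diff above, so this is just an illustration of the naming, not a full replacement:
```python
def rope_freq_llama4(  # pylint: disable=too-many-arguments,too-many-locals
    s: tir.Var,
    d: tir.Var,
    d_range: int,
    theta: float,
    dtype: str,
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    original_max_position_embeddings: float,
):
    """Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
    # ... (orig_freq / orig_freq_var setup and the equal-factor branch as in the diff above)
    llama4_inv_scaling_factor = 1.0 / factor  # was: llama3_inv_scaling_factor
    # In the smooth-interpolation branch:
    llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
    llama4_beta = low_freq_factor * inv_diff_freq_factor
    smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta))
    # ... (the rest is unchanged apart from the renamed variables)
```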
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
    if is_longrope_scaling:
        return fused_rope_longrope_scaling
    return fused_rope
+
+
+def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
+    theta: float,
+    scale: float,
+    head_dim: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    dtype: str,
+    rope_scaling: Dict[str, Any],
+    rotary_dim: Optional[int] = None,
+):
+    """Return the TIR function that computes Llama-style RoPE with q position map.
+
+    Parameters
+    ----------
+    theta : float
+        The theta value, or "base" in RoPE, which controls the frequency.
+
+    scale : float
+        The RoPE scaling factor.
+
+    head_dim : int
+        The number of features on each head.
+
+    num_q_heads : int
+        The number of query heads.
+
+    num_kv_heads : int
+        The number of key/value heads. It differs from `num_q_heads` in group-query attention.
+
+    dtype : str
+        The dtype of qkv data.
+
+    rope_scaling : Dict
+        The configuration of RoPE scaling.
+
+    rotary_dim : int
+        The number of dimensions in the embedding that RoPE is applied to. By default, the
+        rotary_dim is the same as head_dim.
+    """
+    fused_heads = num_q_heads + num_kv_heads * 2
+    if rotary_dim is None:
+        rotary_dim = head_dim
+    scale = tir.const(scale, "float32")
+    is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+
+    def _rope(  # pylint: disable=too-many-arguments
+        x: T.Buffer,
+        s: tir.Var,
+        h: tir.Var,
+        d: tir.Var,
+        pos: tir.Var,
+        ext_factors: Optional[T.Buffer] = None,
+    ):
+        kwargs = {}
+        if ext_factors:
+            kwargs["ext_factors"] = ext_factors
+        cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
+            pos * scale, d, rotary_dim, theta, "float32", **kwargs
+        )
+        cos = cos_freq * x[s, h, d].astype("float32")
+        if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        else:
+            # Data layout is different for llama4 vs llama3
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        expr = (cos + sin).astype(dtype)
+        for var, value in var_map.items():
+            expr = tir.Let(var, value, expr)
+        return expr
+
+    @T.prim_func(private=True)
+    def fused_rope(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        apply_rope: T.int64,
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int32()
+        position_map_elem_offset = T.int32()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    @T.prim_func
+    def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int64()
+        position_map_elem_offset = T.int64()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    if is_longrope_scaling:
+        return fused_rope_longrope_scaling
+    return fused_rope
Review Comment:

The new function `llama4_rope_with_position_map` is almost a complete
duplicate of `llama_rope_with_position_map`. The only significant difference is
the logic for calculating `sin` inside the nested `_rope` function, which
depends on the data layout.
To improve maintainability and reduce code duplication, consider merging
`llama4_rope_with_position_map` into `llama_rope_with_position_map`. You can
add a condition based on `rope_scaling["rope_type"]` to select the correct
`sin` calculation, similar to how it's already done for `"gptj"`.
For example, you could modify the `_rope` function inside
`llama_rope_with_position_map` like this:
```python
def _rope(  # pylint: disable=too-many-arguments
    x: T.Buffer,
    s: tir.Var,
    h: tir.Var,
    d: tir.Var,
    pos: tir.Var,
    ext_factors: Optional[T.Buffer] = None,
):
    # ... (cos_freq, sin_freq, var_map calculation)
    cos = cos_freq * x[s, h, d].astype("float32")
    rope_type = rope_scaling.get("rope_type")
    if rope_type in ("gptj", "llama4"):
        sin = sin_freq * tir.if_then_else(
            d % 2 == 0,
            -x[s, h, d + 1],
            x[s, h, d - 1],
        ).astype("float32")
    else:  # Default for llama/llama3
        sin = sin_freq * tir.if_then_else(
            d < rotary_dim // 2,
            -x[s, h, d + rotary_dim // 2],
            x[s, h, d - rotary_dim // 2],
        ).astype("float32")
    expr = (cos + sin).astype(dtype)
    # ... (rest of the function)
    return expr
```
This refactoring would eliminate about 180 lines of duplicated code and make
the codebase easier to maintain.
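If other code already refers to `llama4_rope_with_position_map` by name, one possible shape after the merge is a thin compatibility wrapper that delegates to the merged function. This is purely a sketch: it assumes the two functions keep the same signature and that `rope_scaling["rope_type"]` is already `"llama4"` for these models, so the merged `_rope` picks the llama4 `sin` layout as sketched above.
```python
def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
    theta: float,
    scale: float,
    head_dim: int,
    num_q_heads: int,
    num_kv_heads: int,
    dtype: str,
    rope_scaling: Dict[str, Any],
    rotary_dim: Optional[int] = None,
):
    """Compatibility shim that delegates to the merged llama_rope_with_position_map.

    Assumes rope_scaling["rope_type"] is "llama4", which selects the llama4 data
    layout inside the merged _rope helper.
    """
    return llama_rope_with_position_map(
        theta=theta,
        scale=scale,
        head_dim=head_dim,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
        dtype=dtype,
        rope_scaling=rope_scaling,
        rotary_dim=rotary_dim,
    )
```
Alternatively, the call sites could invoke `llama_rope_with_position_map` directly and this wrapper could be dropped as well.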