gemini-code-assist[bot] commented on code in PR #18661:
URL: https://github.com/apache/tvm/pull/18661#discussion_r2688257669
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -180,38 +180,39 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
 def yarn_find_correction_dim(
     num_rotations: int,
     d: tir.Var,
-    theta: float,
+    theta: Union[float, tir.PrimExpr],
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Inverse dim formula to find dim based on number of rotations"""
-    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(theta)
-    )
+    return d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale
 def yarn_find_correction_range(
     low_rot: int,
     high_rot: int,
     d: tir.Var,
-    theta: float,
+    theta: Union[float, tir.PrimExpr],
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Find the correction range based on the number of rotations"""
-    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
-    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
+    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
+    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
     return tir.max(low, 0), tir.min(high, d - 1)
Review Comment:

The `theta` parameter is no longer used in `yarn_find_correction_dim`, and
`yarn_find_correction_range` only passes it through. It can be removed from
both function signatures to simplify the code.
```python
def yarn_find_correction_dim(
    num_rotations: int,
    d: tir.Var,
    max_position_embeddings: int,
    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
    """Inverse dim formula to find dim based on number of rotations"""
    return d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale


def yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    d: tir.Var,
    max_position_embeddings: int,
    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
    """Find the correction range based on the number of rotations"""
    low = yarn_find_correction_dim(low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
    high = yarn_find_correction_dim(high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
    return tir.max(low, 0), tir.min(high, d - 1)
```
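Dropping `theta` is only safe if `inv_theta_log_scale` carries the same information. The removed lines divided by `2 * math.log(theta)`, so the new formula reproduces the old one when `inv_theta_log_scale == 1 / (2 * math.log(theta))`; that relationship is inferred from the diff, not stated in it. The sketch below checks the equivalence with plain Python floats standing in for `tir` expressions, and the helper names `old_correction_dim`/`new_correction_dim` are hypothetical. It also makes the scale a required argument, since once `theta` is gone a `None` default has nothing to fall back on and the multiplication would fail.

```python
import math


def old_correction_dim(num_rotations: int, d: float, theta: float, max_pos: int) -> float:
    # Pre-PR formula: divides by 2 * ln(theta).
    return (d * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(theta))


def new_correction_dim(
    num_rotations: int, d: float, max_pos: int, inv_theta_log_scale: float
) -> float:
    # Post-PR formula: multiplies by a precomputed scale instead of dividing by 2 * ln(theta).
    # Hypothetical: the scale is required here because, with theta removed, there is no fallback.
    return d * math.log(max_pos / (num_rotations * 2 * math.pi)) * inv_theta_log_scale


theta = 10000.0
# Assumed meaning of the new parameter, inferred from the removed division.
inv_theta_log_scale = 1.0 / (2.0 * math.log(theta))

for rotations in (1, 32):
    old = old_correction_dim(rotations, 128, theta, 4096)
    new = new_correction_dim(rotations, 128, 4096, inv_theta_log_scale)
    print(rotations, old, new, math.isclose(old, new))
```

The two results agree up to floating-point rounding, which is why the callers can stop forwarding `theta` as long as they pass the precomputed scale.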
##########
python/tvm/relax/frontend/nn/llm/position_embedding.py:
##########
@@ -221,7 +222,7 @@ def rope_freq_yarn(
     freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
     low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
+        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
Review Comment:

Following the removal of the `theta` parameter from
`yarn_find_correction_range`, this call site should be updated to no longer
pass it.
```suggestion
        beta_fast, beta_slow, d_range, original_max_position_embeddings,
        inv_theta_log_scale=inv_theta_log_scale
```
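Without `theta`, the range helper needs only the two rotation bounds, the dimension, the original context length, and the precomputed scale, which is what the suggested call site passes. Below is a minimal float-only sketch of that signature and call. It substitutes Python's built-in `max`/`min` for `tir.max`/`tir.min`, the helper names `correction_dim`/`correction_range` are hypothetical, and the numbers (`beta_fast=32`, `beta_slow=1`, `d_range=128`, a 4096-token original context, `theta=10000`) are illustrative assumptions rather than values from the PR.

```python
import math


def correction_dim(num_rotations, d, max_position_embeddings, inv_theta_log_scale):
    # Inverse-dim formula: real-valued index i at which theta**(-2*i/d)
    # makes `num_rotations` full turns over `max_position_embeddings` positions.
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return d * math.log(ratio) * inv_theta_log_scale


def correction_range(low_rot, high_rot, d, max_position_embeddings, inv_theta_log_scale):
    # Same shape as the suggested signature: no theta, scale forwarded by keyword.
    low = correction_dim(low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
    high = correction_dim(high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale)
    # Clamp to valid dimension indices [0, d - 1].
    return max(low, 0), min(high, d - 1)


beta_fast, beta_slow, d_range = 32, 1, 128
original_max_position_embeddings = 4096
inv_theta_log_scale = 1.0 / (2.0 * math.log(10000.0))

# Mirrors the suggested call site: theta is no longer passed.
low, high = correction_range(
    beta_fast, beta_slow, d_range, original_max_position_embeddings,
    inv_theta_log_scale=inv_theta_log_scale,
)
print(low, high)
```

Because `beta_fast` is the larger rotation count, it yields the smaller index, so it maps to `low` and `beta_slow` maps to `high`.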
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]