This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8218b18da3 [Relax] Add mod operator support (#18559)
8218b18da3 is described below

commit 8218b18da331f887934f72ab4f4b4a5f2c0dc082
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 9 20:11:48 2025 +0800

    [Relax] Add mod operator support (#18559)
    
    ## How
    
    - Resolve the TODO in `ExprWithOp.__mod__`: call `_op_ffi_api.mod` instead of raising a `ValueError`
    - Add `relax.op.mod` and `relax.op.floor_mod` to `test_op_correctness` and to the parametrized binary-op test cases
---
 python/tvm/relax/expr.py             | 3 +--
 tests/python/relax/test_op_binary.py | 4 ++++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 8dd4eff5c7..e9bc9a7a3e 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -185,8 +185,7 @@ class ExprWithOp(Expr, Scriptable):
         return _binary_rhs_helper(other)
 
     def __mod__(self, other: Expr) -> "ExprWithOp":
-        # TODO(siyuan): Support it after mod operator is supported in relax
-        raise ValueError("relax.mod is not supported yet.")
+        return _binary_op_helper(self, other, _op_ffi_api.mod)  # type: ignore
 
     def __rmod__(self, other: Expr) -> "ExprWithOp":
         return _binary_rhs_helper(other)
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 20c111495d..3376569bf3 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -33,6 +33,8 @@ def test_op_correctness():
     assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
     assert relax.op.power(x, y).op == Op.get("relax.power")
     assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
+    assert relax.op.mod(x, y).op == Op.get("relax.mod")
+    assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
 
     assert relax.op.equal(x, y).op == Op.get("relax.equal")
     assert relax.op.greater(x, y).op == Op.get("relax.greater")
@@ -70,6 +72,8 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
     (relax.op.subtract, tir.Sub),
     (relax.op.maximum, tir.Max),
     (relax.op.minimum, tir.Min),
+    (relax.op.mod, tir.Mod),
+    (relax.op.floor_mod, tir.FloorMod),
 )
 
 

Reply via email to