This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 432ccfa7a5 [Relax][PyTorch] Add support for and_, lshift, min, or_, rshift, xor ops (#17668)
432ccfa7a5 is described below

commit 432ccfa7a5b40a6d8e2cac948c6f912d67687c45
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Feb 23 20:04:03 2025 +0800

    [Relax][PyTorch] Add support for and_, lshift, min, or_, rshift, xor ops (#17668)
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update test_frontend_from_fx.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |   6 +
 tests/python/relax/test_frontend_from_fx.py      | 228 +++++++++++++++++++++++
 2 files changed, 234 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d49cfa6893..dffe2b60eb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -660,23 +660,29 @@ class TorchFXImporter(BaseFXGraphImporter):
             "triu": self._tril_triu(relax.op.triu),
             # binary
             "add": self._binary_op(relax.op.add, operator.add),
+            "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
             "eq": self._binary_op(relax.op.equal, operator.eq),
             "floordiv": self._binary_op(relax.op.floor_divide, 
operator.floordiv),
             "ge": self._binary_op(relax.op.greater_equal, operator.ge),
             "gt": self._binary_op(relax.op.greater, operator.gt),
             "iadd": self._binary_op(relax.op.add, operator.add),
             "le": self._binary_op(relax.op.less_equal, operator.le),
+            "lshift": self._binary_op(relax.op.left_shift, operator.lshift),
             "lt": self._binary_op(relax.op.less, operator.lt),
             "matmul": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
             "max": self._binary_op(relax.op.maximum, max),
+            "min": self._binary_op(relax.op.minimum, min),
             "mod": self._binary_op(relax.op.mod, operator.mod),
             "mul": self._binary_op(relax.op.multiply, operator.mul),
             "ne": self._binary_op(relax.op.not_equal, operator.ne),
             "pow": self._binary_op(relax.op.power, operator.pow),
+            "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
             "sub": self._binary_op(relax.op.subtract, operator.sub),
             "truediv": self._binary_op(relax.op.divide, operator.truediv),
+            "xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             # neural network
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d,
             "addmm": self._addmm,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 371343b60a..8b4ea5c8cc 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1485,6 +1485,8 @@ def test_groupnorm():
 def test_binary():
     input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
     input_info2 = [([1, 3, 10, 10], "float32")]
+    input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
+    input_info4 = [([1, 3, 10, 10], "int32")]
 
     # Add
     class Add1(Module):
@@ -1962,6 +1964,211 @@ def test_binary():
     verify_model(Ne1(), input_info1, {}, expected23)
     verify_model(Ne2(), input_info2, {}, expected24)
 
+    # Lshift
+    class LShift1(Module):
+        def forward(self, lhs, rhs):
+            return lhs << rhs
+
+    @tvm.script.ir_module
+    class expected25:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class LShift2(Module):
+        def forward(self, lhs):
+            return lhs << 1
+
+    @tvm.script.ir_module
+    class expected26:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(LShift1(), input_info3, {}, expected25)
+    verify_model(LShift2(), input_info4, {}, expected26)
+
+    # Rshift
+    class RShift1(Module):
+        def forward(self, lhs, rhs):
+            return lhs >> rhs
+
+    @tvm.script.ir_module
+    class expected27:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class RShift2(Module):
+        def forward(self, lhs):
+            return lhs >> 1
+
+    @tvm.script.ir_module
+    class expected28:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(RShift1(), input_info3, {}, expected27)
+    verify_model(RShift2(), input_info4, {}, expected28)
+
+    # Bitwise and
+    class BitwiseAnd1(Module):
+        def forward(self, lhs, rhs):
+            return lhs & rhs
+
+    @tvm.script.ir_module
+    class expected29:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseAnd2(Module):
+        def forward(self, lhs):
+            return lhs & 1
+
+    @tvm.script.ir_module
+    class expected30:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseAnd1(), input_info3, {}, expected29)
+    verify_model(BitwiseAnd2(), input_info4, {}, expected30)
+
+    # Bitwise or
+    class BitwiseOr1(Module):
+        def forward(self, lhs, rhs):
+            return lhs | rhs
+
+    @tvm.script.ir_module
+    class expected31:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseOr2(Module):
+        def forward(self, lhs):
+            return lhs | 1
+
+    @tvm.script.ir_module
+    class expected32:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseOr1(), input_info3, {}, expected31)
+    verify_model(BitwiseOr2(), input_info4, {}, expected32)
+
+    # Bitwise xor
+    class BitwiseXor1(Module):
+        def forward(self, lhs, rhs):
+            return lhs ^ rhs
+
+    @tvm.script.ir_module
+    class expected33:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseXor2(Module):
+        def forward(self, lhs):
+            return lhs ^ 1
+
+    @tvm.script.ir_module
+    class expected34:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseXor1(), input_info3, {}, expected33)
+    verify_model(BitwiseXor2(), input_info4, {}, expected34)
+
 
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]
@@ -3745,6 +3952,27 @@ def test_max():
     verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
 
 
+def test_min():
+    class Min(Module):
+        def forward(self, x, y):
+            return torch.min(x, y)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.minimum(inp_0, inp_1)
+                gv: R.Tensor((256, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
+
+
 def test_attention():
     @I.ir_module
     class Expected1:

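[Editor's note, not part of the commit] To try the new mappings outside the
test suite, a minimal usage sketch (assuming a TVM build that includes this
commit; `fx.symbolic_trace` and `from_fx` are the same entry points the tests
above rely on, and the `BitOps` module here is a made-up example):

    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class BitOps(torch.nn.Module):
        def forward(self, lhs, rhs):
            # Exercises the new mappings: & -> R.bitwise_and, | -> R.bitwise_or,
            # ^ -> R.bitwise_xor, << -> R.left_shift, >> -> R.right_shift
            return ((lhs & rhs) | (lhs ^ rhs)) << 1 >> 1

    graph_model = fx.symbolic_trace(BitOps())
    mod = from_fx(graph_model, [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")])
    mod.show()  # print the imported Relax IRModule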