This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 432ccfa7a5 [Relax][PyTorch] Add support for and_, lshift, min, or_, rshift, xor ops (#17668)
432ccfa7a5 is described below
commit 432ccfa7a5b40a6d8e2cac948c6f912d67687c45
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Feb 23 20:04:03 2025 +0800
[Relax][PyTorch] Add support for and_, lshift, min, or_, rshift, xor ops (#17668)
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |   6 +
 tests/python/relax/test_frontend_from_fx.py      | 228 +++++++++++++++++++++++
 2 files changed, 234 insertions(+)
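
For context, a minimal sketch of how the newly mapped operators flow through
the Relax FX frontend. The module, shapes, and dtypes below are illustrative
only and are not part of this commit:

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    # Toy module exercising a few of the ops this commit maps
    # (operator.and_, operator.xor, operator.rshift).
    class BitOps(nn.Module):
        def forward(self, lhs, rhs):
            return (lhs & rhs) ^ (lhs >> 1)

    graph_model = fx.symbolic_trace(BitOps())
    # input_info pairs each input's shape with its dtype, mirroring the
    # input_info3 list used in the tests below.
    mod = from_fx(graph_model, [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")])
    mod.show()  # print the imported Relax IRModule
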
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d49cfa6893..dffe2b60eb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -660,23 +660,29 @@ class TorchFXImporter(BaseFXGraphImporter):
"triu": self._tril_triu(relax.op.triu),
# binary
"add": self._binary_op(relax.op.add, operator.add),
+ "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide,
operator.floordiv),
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
"gt": self._binary_op(relax.op.greater, operator.gt),
"iadd": self._binary_op(relax.op.add, operator.add),
"le": self._binary_op(relax.op.less_equal, operator.le),
+ "lshift": self._binary_op(relax.op.left_shift, operator.lshift),
"lt": self._binary_op(relax.op.less, operator.lt),
"matmul": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"),
operator.matmul
),
"max": self._binary_op(relax.op.maximum, max),
+ "min": self._binary_op(relax.op.minimum, min),
"mod": self._binary_op(relax.op.mod, operator.mod),
"mul": self._binary_op(relax.op.multiply, operator.mul),
"ne": self._binary_op(relax.op.not_equal, operator.ne),
"pow": self._binary_op(relax.op.power, operator.pow),
+ "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
+ "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
"sub": self._binary_op(relax.op.subtract, operator.sub),
"truediv": self._binary_op(relax.op.divide, operator.truediv),
+ "xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
# neural network
"adaptive_avg_pool2d": self._adaptive_avg_pool2d,
"addmm": self._addmm,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 371343b60a..8b4ea5c8cc 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1485,6 +1485,8 @@ def test_groupnorm():
 def test_binary():
     input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
     input_info2 = [([1, 3, 10, 10], "float32")]
+    input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
+    input_info4 = [([1, 3, 10, 10], "int32")]
     # Add
     class Add1(Module):
@@ -1962,6 +1964,211 @@ def test_binary():
     verify_model(Ne1(), input_info1, {}, expected23)
     verify_model(Ne2(), input_info2, {}, expected24)
+    # Lshift
+    class LShift1(Module):
+        def forward(self, lhs, rhs):
+            return lhs << rhs
+
+    @tvm.script.ir_module
+    class expected25:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class LShift2(Module):
+        def forward(self, lhs):
+            return lhs << 1
+
+    @tvm.script.ir_module
+    class expected26:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(LShift1(), input_info3, {}, expected25)
+    verify_model(LShift2(), input_info4, {}, expected26)
+
+    # Rshift
+    class RShift1(Module):
+        def forward(self, lhs, rhs):
+            return lhs >> rhs
+
+    @tvm.script.ir_module
+    class expected27:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class RShift2(Module):
+        def forward(self, lhs):
+            return lhs >> 1
+
+    @tvm.script.ir_module
+    class expected28:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(RShift1(), input_info3, {}, expected27)
+    verify_model(RShift2(), input_info4, {}, expected28)
+
+    # Bitwise and
+    class BitwiseAnd1(Module):
+        def forward(self, lhs, rhs):
+            return lhs & rhs
+
+    @tvm.script.ir_module
+    class expected29:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseAnd2(Module):
+        def forward(self, lhs):
+            return lhs & 1
+
+    @tvm.script.ir_module
+    class expected30:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseAnd1(), input_info3, {}, expected29)
+    verify_model(BitwiseAnd2(), input_info4, {}, expected30)
+
+    # Bitwise or
+    class BitwiseOr1(Module):
+        def forward(self, lhs, rhs):
+            return lhs | rhs
+
+    @tvm.script.ir_module
+    class expected31:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseOr2(Module):
+        def forward(self, lhs):
+            return lhs | 1
+
+    @tvm.script.ir_module
+    class expected32:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseOr1(), input_info3, {}, expected31)
+    verify_model(BitwiseOr2(), input_info4, {}, expected32)
+
+    # Bitwise xor
+    class BitwiseXor1(Module):
+        def forward(self, lhs, rhs):
+            return lhs ^ rhs
+
+    @tvm.script.ir_module
+    class expected33:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    class BitwiseXor2(Module):
+        def forward(self, lhs):
+            return lhs ^ 1
+
+    @tvm.script.ir_module
+    class expected34:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, R.const(1))
+                gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(BitwiseXor1(), input_info3, {}, expected33)
+    verify_model(BitwiseXor2(), input_info4, {}, expected34)
+
 
 
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]
@@ -3745,6 +3952,27 @@ def test_max():
     verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
 
 
+def test_min():
+    class Min(Module):
+        def forward(self, x, y):
+            return torch.min(x, y)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((256, 256), dtype="float32") = R.minimum(inp_0, inp_1)
+                gv: R.Tensor((256, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
+
+
 def test_attention():
     @I.ir_module
     class Expected1: