This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f16ec25ac [Relax][PyTorch] Add support for where, cumprod and reciprocal ops (#17788)
3f16ec25ac is described below
commit 3f16ec25aca5f479df60258e790949bf43e783bc
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Apr 1 02:08:24 2025 +0800
[Relax][PyTorch] Add support for where, cumprod and reciprocal ops (#17788)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
---
.../frontend/torch/base_fx_graph_translator.py | 17 ++++++
python/tvm/relax/frontend/torch/fx_translator.py | 7 +++
tests/python/relax/test_frontend_from_fx.py | 65 ++++++++++++++++++++++
3 files changed, 89 insertions(+)
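For orientation before the diff: a minimal end-to-end sketch (not part of the commit) of how the three new conversions can be exercised through the existing `from_fx` entry point of the TVM FX frontend. The module name, shapes, and dtypes below are illustrative only.

    import torch
    from torch import fx, nn
    from tvm.relax.frontend.torch import from_fx

    class WhereCumprodReciprocal(nn.Module):
        def forward(self, cond, x, y):
            z = torch.where(cond, x, y)  # converted via the new _where
            z = torch.cumprod(z, 0)      # converted via the new _cumprod
            return torch.reciprocal(z)   # lowered to R.divide(R.const(1.0), z)

    graph_module = fx.symbolic_trace(WhereCumprodReciprocal())
    input_info = [((5, 3), "bool"), ((5, 3), "float32"), ((5, 3), "float32")]
    mod = from_fx(graph_module, input_info)
    mod.show()  # prints the imported Relax IRModule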
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 71554a8a5b..fe0ae412a2 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -949,6 +949,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _where(self, node: fx.Node) -> relax.Var:
+        condition = self.env[node.args[0]]
+        x = self.env[node.args[1]]
+        y = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.where(condition, x, y))
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
@@ -967,6 +973,17 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
         )
 
+    def _cumprod(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+
+        return self.block_builder.emit(relax.op.cumprod(x, dim, dtype))
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 022a7bffea..c4008a9396 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -62,6 +62,10 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Unary Ops ##########
 
+    def _reciprocal(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -708,6 +712,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
@@ -784,11 +789,13 @@ class TorchFXImporter(BaseFXGraphImporter):
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
+            "where": self._where,
             # tensor manipulation
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
+            "cumprod": self._cumprod,
             "cumsum": self._cumsum,
             "expand": self._expand,
             "expand_as.default": self._expand_as,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 726ff6f8e8..b8d7f0b14e 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2339,6 +2339,27 @@ def test_extended_unary_ops():
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
 
+    # reciprocal
+    class Reciprocal(Module):
+        def forward(self, input):
+            return torch.reciprocal(input)
+
+    @tvm.script.ir_module
+    class expected_reciprocal:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    R.const(1.0, "float32"), input_1
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Reciprocal(), input_info, {}, expected_reciprocal)
+
     # relu
     class ReLU0(Module):
         def __init__(self):
@@ -4315,5 +4336,49 @@ def test_prod():
     verify_model(Prod(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_cumprod():
+    class Cumprod(Module):
+        def forward(self, x):
+            return torch.cumprod(x, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_where():
+    class Where(Module):
+        def forward(self, condition, x, y):
+            return torch.where(condition, x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="bool"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()