This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8ad9bffb11 [Relax][PyTorch] Add support for lerp, select and clone ops (#17760)
8ad9bffb11 is described below
commit 8ad9bffb1124c1553c1bc9c3e4ae5d8a84826afa
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Mar 18 05:13:05 2025 +0800
[Relax][PyTorch] Add support for lerp, select and clone ops (#17760)
This PR adds support for the PyTorch `lerp`, `select`, and `clone` ops to the Relax FX frontend.
---
python/tvm/relax/frontend/torch/fx_translator.py | 20 ++++++++
tests/python/relax/test_frontend_from_fx.py | 65 ++++++++++++++++++++++++
2 files changed, 85 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 952fb6f971..9b835b0eee 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -376,6 +376,16 @@ class TorchFXImporter(BaseFXGraphImporter):
return self._max_pool2d_impl(x, kernel_size, stride, padding,
dilation, ceil_mode)
+    ########## Linear Interpolation ##########
+
+    def _lerp(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.lerp(start, end, weight)`` to ``start + weight * (end - start)``.
+
+        ``start`` and ``end`` are fx.Nodes already converted in ``self.env``.
+        ``weight`` may be either an fx.Node (tensor weight) or a plain Python
+        scalar, since torch.lerp accepts both forms.
+        """
+        start = self.env[node.args[0]]
+        end = self.env[node.args[1]]
+        weight = node.args[2]
+        if isinstance(weight, fx.Node):
+            weight = self.env[weight]
+        else:
+            # A scalar weight is never stored in self.env; materialize it as a
+            # constant with the same dtype as `start` so it matches the tensors.
+            weight = relax.const(weight, start.struct_info.dtype)
+        return self.block_builder.emit(
+            relax.op.add(start, relax.op.multiply(weight, relax.op.subtract(end, start)))
+        )
+
########## Manipulation ##########
def _chunk(self, node: fx.Node) -> relax.Var:
@@ -414,6 +424,12 @@ class TorchFXImporter(BaseFXGraphImporter):
shape = self.shape_of(x)
return relax.const(reduce(lambda x, y: x * y, [s.value for s in
shape]), "int32")
+    def _select(self, node: fx.Node) -> relax.Var:
+        """Convert ``torch.select(x, dim, index)`` to a ``take`` along ``dim``.
+
+        The index is emitted as a 0-d int64 constant, so ``take`` drops ``dim``
+        from the result. A negative ``index`` counts back from the end of
+        ``dim`` (as torch allows), so it is resolved against the static extent
+        of that axis before emitting ``take``.
+
+        Raises:
+            ValueError: if ``index`` is negative and the extent of ``dim`` is
+                not a static integer.
+        """
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = node.args[2]
+        if index < 0:
+            from tvm import tir
+
+            dim_size = self.shape_of(x)[dim]
+            if not isinstance(dim_size, tir.IntImm):
+                raise ValueError(
+                    f"select with negative index {index} requires a static size "
+                    f"for dim {dim}, got {dim_size}"
+                )
+            index += dim_size.value
+        return self.block_builder.emit(relax.op.take(x, relax.const(index, "int64"), dim))
+
def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -737,6 +753,8 @@ class TorchFXImporter(BaseFXGraphImporter):
"scaled_dot_product_attention": self._scaled_dot_product_attention,
"stochastic_depth": lambda node: self.env[node.args[0]],
"unbind": self._unbind,
+ # linear interpolation
+ "lerp": self._lerp,
# statistical
"mean": self._mean,
"sum": self._sum,
@@ -759,6 +777,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"repeat": self._repeat,
"reshape": self._reshape,
"scatter": self._scatter,
+ "select": self._select,
"size": self._size,
"split": self._split,
"squeeze": self._squeeze,
@@ -772,6 +791,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"view": self._reshape,
# tensor creation
"arange": self._arange,
+ "clone": lambda node: self.env[node.args[0]],
"empty": self._empty,
"empty_like": self._empty_like,
"fill_": self._inplace_fill,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index fbea8b7388..f06ce7a753 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4125,5 +4125,70 @@ def test_numel():
verify_model(Numel(), [([5, 3], "float32")], {}, Expected)
+def test_select():
+    """torch.select with a positive index becomes R.take with a 0-d index."""
+
+    class SelectRow(Module):
+        def forward(self, x):
+            return torch.select(x, 0, 1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((3,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
+                gv: R.Tensor((3,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([5, 3], "float32")]
+    verify_model(SelectRow(), input_info, {}, Expected)
+
+
+def test_clone():
+    """Tensor.clone is forwarded as its input; no extra op is emitted."""
+
+    class CloneInput(Module):
+        def forward(self, data):
+            return data.clone()
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((5, 3), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    input_info = [([5, 3], "float32")]
+    verify_model(CloneInput(), input_info, {}, Expected)
+
+
+def test_lerp():
+    """torch.lerp with a tensor weight becomes start + weight * (end - start)."""
+
+    class LerpModel(Module):
+        def forward(self, lhs, rhs, alpha):
+            return torch.lerp(lhs, rhs, alpha)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.add(
+                    inp_0, R.multiply(inp_2, R.subtract(inp_1, inp_0))
+                )
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # start, end and weight all share the same shape and dtype.
+    input_info = [([5, 3], "float32")] * 3
+    verify_model(LerpModel(), input_info, {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()