This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2abff889af [Relax][Pytorch] Add masked_fill op support in ExportedProgram (#17850)
2abff889af is described below
commit 2abff889af7dac9e5f4a85f55567a73df5f7f0b8
Author: kavin-mcw <[email protected]>
AuthorDate: Fri Apr 18 07:27:10 2025 +0530
[Relax][Pytorch] Add masked_fill op support in ExportedProgram (#17850)
* Add masked_fill support in exportedProgram
* Fix lint issues
---
.../frontend/torch/base_fx_graph_translator.py | 17 +++++++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 17 ---------------
.../relax/test_frontend_from_exported_program.py | 24 ++++++++++++++++++++++
4 files changed, 42 insertions(+), 17 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 21cbd14d7e..7b380f9876 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1349,6 +1349,23 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.take(x, index, dim))
+ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ mask = self.env[node.args[1]]
+ value = node.args[2]
+ rx_value = relax.const(value)
+ values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+ output = self.block_builder.emit(relax.op.where(mask, values, x))
+ self.env[node.args[0]] = output
+ return output
+
+ def _masked_fill(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ mask = self.env[node.args[1]]
+ rx_value = relax.const(node.args[2])
+ values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+ return self.block_builder.emit(relax.op.where(mask, values, x))
+
def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2c9e255f29..a6f9cafa65 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -448,6 +448,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"full_like.default": self._full_like,
"index_select.default": self._index_select,
"lift_fresh_copy.default": self._to_copy,
+ "masked_fill.Scalar": self._masked_fill,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a26185ce3c..534d398bea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -476,23 +476,6 @@ class TorchFXImporter(BaseFXGraphImporter):
self.env[node.args[0]] = filled
return filled
- def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- mask = self.env[node.args[1]]
- value = node.args[2]
- rx_value = relax.const(value)
- values = self.block_builder.emit(relax.op.full_like(x, rx_value))
- output = self.block_builder.emit(relax.op.where(mask, values, x))
- self.env[node.args[0]] = output
- return output
-
- def _masked_fill(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- mask = self.env[node.args[1]]
- rx_value = relax.const(node.args[2])
- values = self.block_builder.emit(relax.op.full_like(x, rx_value))
- return self.block_builder.emit(relax.op.where(mask, values, x))
-
def _masked_scatter(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e4694efa56..e78bd339d2 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3260,6 +3260,30 @@ def test_fill():
verify_model(Fill(), example_args, {}, Expected)
+def test_masked_fill():
+ class Masked_Fill(Module):
+ def forward(self, input: torch.Tensor, mask: torch.Tensor):
+ return torch.masked_fill(input, mask, 0)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128,
128), dtype="bool")
+ ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+ input, R.const(0, "int32"), dtype="void"
+ )
+ lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv,
input)
+ gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(128, 128, dtype=torch.float32),
torch.rand(128, 128) < 0.5)
+ verify_model(Masked_Fill(), example_args, {}, Expected)
+
+
def test_new_ones():
class NewOnes(Module):
def forward(self, x):