This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2abff889af [Relax][Pytorch] Add masked_fill op support in ExportedProgram (#17850)
2abff889af is described below

commit 2abff889af7dac9e5f4a85f55567a73df5f7f0b8
Author: kavin-mcw <[email protected]>
AuthorDate: Fri Apr 18 07:27:10 2025 +0530

    [Relax][Pytorch] Add masked_fill op support in ExportedProgram (#17850)
    
    * Add masked_fill support in exportedProgram
    
    * Fix lint issues
---
 .../frontend/torch/base_fx_graph_translator.py     | 17 +++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   | 17 ---------------
 .../relax/test_frontend_from_exported_program.py   | 24 ++++++++++++++++++++++
 4 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 21cbd14d7e..7b380f9876 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1349,6 +1349,23 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.take(x, index, dim))
 
+    def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        value = node.args[2]
+        rx_value = relax.const(value)
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        output = self.block_builder.emit(relax.op.where(mask, values, x))
+        self.env[node.args[0]] = output
+        return output
+
+    def _masked_fill(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        mask = self.env[node.args[1]]
+        rx_value = relax.const(node.args[2])
+        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+        return self.block_builder.emit(relax.op.where(mask, values, x))
+
     def _new_ones(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         self_var = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2c9e255f29..a6f9cafa65 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -448,6 +448,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
+            "masked_fill.Scalar": self._masked_fill,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
             "ones.default": self._ones,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a26185ce3c..534d398bea 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -476,23 +476,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         self.env[node.args[0]] = filled
         return filled
 
-    def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        mask = self.env[node.args[1]]
-        value = node.args[2]
-        rx_value = relax.const(value)
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
-        output = self.block_builder.emit(relax.op.where(mask, values, x))
-        self.env[node.args[0]] = output
-        return output
-
-    def _masked_fill(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        mask = self.env[node.args[1]]
-        rx_value = relax.const(node.args[2])
-        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
-        return self.block_builder.emit(relax.op.where(mask, values, x))
-
     def _masked_scatter(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index e4694efa56..e78bd339d2 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3260,6 +3260,30 @@ def test_fill():
     verify_model(Fill(), example_args, {}, Expected)
 
 
+def test_masked_fill():
+    class Masked_Fill(Module):
+        def forward(self, input: torch.Tensor, mask: torch.Tensor):
+            return torch.masked_fill(input, mask, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
+        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
+                    input, R.const(0, "int32"), dtype="void"
+                )
+                lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
+                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
+    verify_model(Masked_Fill(), example_args, {}, Expected)
+
+
 def test_new_ones():
     class NewOnes(Module):
         def forward(self, x):

Reply via email to