This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 299ef81f0f [BugFix][Relax][Pytorch] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend (#17875)
299ef81f0f is described below

commit 299ef81f0f7e4bae93826141815dfc97c0ea0a42
Author: kavin-mcw <[email protected]>
AuthorDate: Wed Apr 23 09:15:34 2025 +0530

    [BugFix][Relax][Pytorch] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend (#17875)
---
 python/tvm/relax/frontend/torch/fx_translator.py | 50 +++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 548320bd85..c3bf8f0454 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -132,6 +132,54 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return convert
 
+    ########## Binary Ops ##############
+
+    def _binary_op_inplace(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            def promote_binary_op_args(lhs, rhs):
+                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+                    return lhs, rhs
+                elif isinstance(lhs, relax.Expr):
+                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
+                elif isinstance(rhs, relax.Expr):
+                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+                    return relax.const(lhs, rhs.struct_info.dtype), rhs
+                else:
+                    assert False
+
+            def call_binary_op(op, lhs, rhs):
+                lhs, rhs = promote_binary_op_args(lhs, rhs)
+                return self.block_builder.emit(op(lhs, rhs))
+
+            lhs, rhs = self.retrieve_args(node)
+            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+                output = call_binary_op(relax_op, lhs, rhs)
+                self.env[node.args[0]] = output
+                return output
+
+            elif isinstance(lhs, relax.expr.Constant):
+                output = call_binary_op(
+                    relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
+                )
+                self.env[node.args[0]] = output
+                return output
+
+            elif isinstance(rhs, relax.expr.Constant):
+                output = call_binary_op(
+                    relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
+                )
+                self.env[node.args[0]] = output
+                return output
+
+            output = intrinsic_op(lhs, rhs)
+            self.env[node.args[0]] = output
+            return output
+
+        return convert
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -679,7 +727,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             # binary
             "add": self._binary_op(relax.op.add, operator.add),
             "and_": self._binary_op(relax.op.bitwise_and, operator.and_),
-            "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_),
+            "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, 
operator.or_),
             "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
             "eq": self._binary_op(relax.op.equal, operator.eq),
             "floordiv": self._binary_op(relax.op.floor_divide, 
operator.floordiv),

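For context, here is a minimal sketch (not part of the commit) of the situation the fix targets: an in-place tensor method such as bitwise_or_ traced through torch.fx. With the new _binary_op_inplace handler, the importer rebinds the first argument (node.args[0]) in its environment, so later uses of that tensor observe the updated value rather than the original input. The module, shapes, and dtypes below are illustrative assumptions, not taken from the patch or its tests.

    # Illustrative sketch; module name, shapes, and dtypes are assumptions.
    import torch
    from tvm.relax.frontend.torch import from_fx

    class InplaceOr(torch.nn.Module):
        def forward(self, x, y):
            x.bitwise_or_(y)   # traced as call_method "bitwise_or_"; x must be rebound
            return x + y       # should see the updated x, not the original input

    # Trace with torch.fx and import through the Relax FX-based frontend.
    traced = torch.fx.symbolic_trace(InplaceOr())
    input_info = [((2, 2), "int32"), ((2, 2), "int32")]
    mod = from_fx(traced, input_info)
    print(mod)

Before this change, "bitwise_or_" went through the ordinary _binary_op handler, which emits the result but does not write it back to node.args[0], so a subsequent use of x would read the stale, pre-update value.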