This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 299ef81f0f [BugFix][Relax][Pytorch] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend (#17875)
299ef81f0f is described below
commit 299ef81f0f7e4bae93826141815dfc97c0ea0a42
Author: kavin-mcw <[email protected]>
AuthorDate: Wed Apr 23 09:15:34 2025 +0530
[BugFix][Relax][Pytorch] Incorrect Handling of In-Place Ops in FX-Based TVM Frontend (#17875)
---
python/tvm/relax/frontend/torch/fx_translator.py | 50 +++++++++++++++++++++++-
1 file changed, 49 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 548320bd85..c3bf8f0454 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -132,6 +132,54 @@ class TorchFXImporter(BaseFXGraphImporter):
return convert
+ ########## Binary Ops ##############
+
+ def _binary_op_inplace(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.Node) -> relax.Var:
+ def promote_binary_op_args(lhs, rhs):
+ if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+ return lhs, rhs
+ elif isinstance(lhs, relax.Expr):
+ assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+ return lhs, relax.const(rhs, lhs.struct_info.dtype)
+ elif isinstance(rhs, relax.Expr):
+ assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+ return relax.const(lhs, rhs.struct_info.dtype), rhs
+ else:
+ assert False
+
+ def call_binary_op(op, lhs, rhs):
+ lhs, rhs = promote_binary_op_args(lhs, rhs)
+ return self.block_builder.emit(op(lhs, rhs))
+
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ output = call_binary_op(relax_op, lhs, rhs)
+ self.env[node.args[0]] = output
+ return output
+
+ elif isinstance(lhs, relax.expr.Constant):
+ output = call_binary_op(
+ relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)
+ )
+ self.env[node.args[0]] = output
+ return output
+
+ elif isinstance(rhs, relax.expr.Constant):
+ output = call_binary_op(
+ relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
+ )
+ self.env[node.args[0]] = output
+ return output
+
+ output = intrinsic_op(lhs, rhs)
+ self.env[node.args[0]] = output
+ return output
+
+ return convert
+
########## Neural Network ##########
def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -679,7 +727,7 @@ class TorchFXImporter(BaseFXGraphImporter):
# binary
"add": self._binary_op(relax.op.add, operator.add),
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
- "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_),
+ "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_),
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),