This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 57912395a8 [Relax][PyTorch] Decompose integer pow into repeated multiplication (#19660)
57912395a8 is described below

commit 57912395a8f99c4b12e28190b3d99c12f2638e63
Author: Javier De Jesus <[email protected]>
AuthorDate: Wed Jun 3 23:15:47 2026 +0200

    [Relax][PyTorch] Decompose integer pow into repeated multiplication (#19660)
    
    `torch.pow` on an integer tensor returns an integer result, but the
    PyTorch frontend lowered it to `relax.op.power`, which fails
    `LegalizeOps` with `power only applies to float` (TOPI `power` /
    `tvm::pow` requires a floating-point input).
    
    This decomposes an integer base raised to a constant non-negative
    integer exponent into repeated multiplication, so the result stays
    integral and matches PyTorch. Float bases and non-constant or tensor
    exponents keep using `relax.op.power` unchanged. The ONNX frontend
    already uses the same decomposition (`x**3 = x*x*x`).
    
    Added structural tests covering both the FX and ExportedProgram import
    paths.
    
    Fixes #19550
---
 .../frontend/torch/base_fx_graph_translator.py     | 22 ++++++++++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  2 +-
 python/tvm/relax/frontend/torch/fx_translator.py   |  2 +-
 .../relax/test_frontend_from_exported_program.py   | 22 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 22 ++++++++++++++++++++++
 5 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a2ebed0480..581475ebd8 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -22,6 +22,7 @@
 
 import abc
 import math
+import operator
 from collections.abc import Callable
 from functools import reduce
 
@@ -523,6 +524,27 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return convert
 
+    def _pow(self, node: fx.Node) -> relax.Var:
+        """Convert pow. relax.op.power legalizes to TOPI power, which is float-only, so an
+        integer-tensor base with a constant non-negative int exponent is expanded into repeated
+        multiplication (ones_like for 0); all other cases fall back to relax.op.power."""
+        lhs, rhs = self.retrieve_args(node)
+        if (
+            isinstance(lhs, relax.Expr)
+            and isinstance(lhs.struct_info, relax.TensorStructInfo)
+            and "int" in lhs.struct_info.dtype
+            and isinstance(rhs, int)
+            and not isinstance(rhs, bool)
+            and rhs >= 0
+        ):
+            if rhs == 0:
+                return self.block_builder.emit(relax.op.ones_like(lhs))
+            result = lhs
+            for _ in range(rhs - 1):
+                result = self.block_builder.emit(relax.op.multiply(result, lhs))
+            return result
+        return self._binary_op(relax.op.power, operator.pow)(node)
+
     def _div(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         inp_1 = args[0]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 26f5a5918c..976c9d45b6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1645,7 +1645,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
             ),
             "pow.Scalar": self._binary_op(relax.op.power, operator.pow),
-            "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
+            "pow.Tensor_Scalar": self._pow,
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
             "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
             "sub.Scalar": self._binary_op(relax.op.subtract, operator.sub),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 9d27f62b42..867407193a 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -929,7 +929,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "outer": lambda node: self.block_builder.emit(
                 relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
             ),
-            "pow": self._binary_op(relax.op.power, operator.pow),
+            "pow": self._pow,
             "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
             "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
             "rsub": self._rsub,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index d1bdad7578..86471d8924 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1085,6 +1085,28 @@ def test_logical_not():
     verify_model(LogicalNot(), example_args, {}, expected)
 
 
+def test_pow_integer():
+    class Pow(Module):
+        def forward(self, input):
+            return input.pow(4)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(input: R.Tensor((4,), dtype="int64")) -> R.Tuple(R.Tensor((4,), dtype="int64")):
+            # int64 ** 4 is lowered to chained multiplies, never to R.power
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int64") = R.multiply(input, input)
+                lv1: R.Tensor((4,), dtype="int64") = R.multiply(lv, input)
+                lv2: R.Tensor((4,), dtype="int64") = R.multiply(lv1, input)
+                gv: R.Tuple(R.Tensor((4,), dtype="int64")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.tensor([-1, 1, 2, 3], dtype=torch.int64),)
+    verify_model(Pow(), example_args, {}, expected)
+
+
 def test_logsoftmax():
     class LogSoftmax(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 1bf71fb6eb..abfb18cf41 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3527,6 +3527,28 @@ def test_extended_unary_ops():
     verify_model(Trunc(), input_info, {}, expected_trunc)
 
 
+def test_pow_integer():
+    input_info = [([4], "int64")]  # integer dtype: pow must not lower to R.power
+
+    class Pow(Module):
+        def forward(self, input):
+            return input.pow(4)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(inp_0: R.Tensor((4,), dtype="int64")) -> R.Tensor((4,), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int64") = R.multiply(inp_0, inp_0)
+                lv1: R.Tensor((4,), dtype="int64") = R.multiply(lv, inp_0)
+                lv2: R.Tensor((4,), dtype="int64") = R.multiply(lv1, inp_0)
+                gv: R.Tensor((4,), dtype="int64") = lv2
+                R.output(gv)
+            return gv
+
+    verify_model(Pow(), input_info, {}, expected)
+
+
 def test_interpolate():
     input_info = [([1, 3, 10, 10], "float32")]
 

Reply via email to