This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1bb7833757 [Relax][PyTorch] Add Logaddexp op support for exported program  (#17803)
1bb7833757 is described below

commit 1bb7833757b82624cf9deff1ad6790c9354d1745
Author: AishwaryaElango <[email protected]>
AuthorDate: Thu Apr 17 12:53:16 2025 +0530

    [Relax][PyTorch] Add Logaddexp op support for exported program  (#17803)
    
    * Add support for logaddexp core operator
    
    * Add test script for logaddexp
    
    * Add fix for lint issues
    
    * Adjust trailing spaces
    
    * Adjust leading whitespace
    
    * Add fix for lint issues
    
    * Add fix for logaddexp test script
    
    * Fix lint issues
    
    * decomposition at op level
    
    * unity check
    
    ---------
    
    Co-authored-by: Pratheesh <[email protected]>
---
 include/tvm/tir/op.h                               | 11 +++++++++
 include/tvm/topi/broadcast.h                       | 16 +++++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/op/__init__.py                    |  1 +
 python/tvm/relax/op/binary.py                      | 19 ++++++++++++++++
 python/tvm/relax/transform/legalize_ops/binary.py  |  1 +
 python/tvm/script/ir_builder/relax/ir.py           |  2 ++
 python/tvm/te/__init__.py                          |  2 +-
 python/tvm/tir/__init__.py                         |  2 +-
 python/tvm/tir/op.py                               | 22 ++++++++++++++++++
 python/tvm/topi/broadcast.py                       | 19 ++++++++++++++++
 src/relax/op/tensor/binary.cc                      |  1 +
 src/relax/op/tensor/binary.h                       |  3 +++
 src/tir/op/op.cc                                   | 10 +++++++++
 src/topi/broadcast.cc                              |  1 +
 .../relax/test_frontend_from_exported_program.py   | 26 ++++++++++++++++++++++
 16 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index cfbd445295..ce7a425c94 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -394,6 +394,15 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span());
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
+/*!
+ * \brief Compute log(exp(a) + exp(b)).
+ *
+ * \param a Left operand.
+ * \param b Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span = Span());
 /*!
  * \brief compute ceil(a / b)
  *
@@ -404,6 +413,7 @@ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
+
 TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
 /*!
  * \brief compute the remainder of floordiv
@@ -1071,6 +1081,7 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(logaddexp);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod);
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift);  // NOLINT(*)
 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift);   // NOLINT(*)
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index d27b6f1a3c..9be7256b44 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -257,6 +257,22 @@ TOPI_DEFINE_BCAST_OP(floor_divide, {
   }
 });
 
+/*!
+ * \fn log_add_exp
+ * \brief Compute log(exp(A) + exp(B)) with auto-broadcasting.
+ *
+ * This operation is useful for numerically stable log-sum-exp computations,
+ * which frequently appear in probabilistic and statistical models.
+ *
+ * \param A The first input tensor, or Expr.
+ * \param B The second input tensor, or Expr.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return The computed log-sum-exp result.
+ */
+TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); });
+
 /*!
  * \fn trunc divide
  * \brief Compute trunc(A / B) with auto-broadcasting.
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 8f6418891b..c82a5e2b11 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -327,6 +327,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
             "eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
             "floor_divide.default": self._binary_op(relax.op.floor_divide, 
operator.floordiv),
+            "logaddexp.default": self._binary_op(relax.op.log_add_exp, 
torch.logaddexp),
             "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge),
             "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge),
             "gt.Scalar": self._binary_op(relax.op.greater, operator.gt),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 97f18a2396..ddfdfc2b05 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -50,6 +50,7 @@ from .binary import (
     divide,
     equal,
     floor_divide,
+    log_add_exp,
     floor_mod,
     greater,
     greater_equal,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 7a41c8b095..d18aac8635 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -85,6 +85,25 @@ def floor_divide(x1: Expr, x2: Expr) -> Expr:
     return _ffi_api.floor_divide(x1, x2)  # type: ignore
 
 
+def log_add_exp(x1: Expr, x2: Expr) -> Expr:
+    """
+    Compute the log of the sum of exponentials of the inputs, element-wise.
+
+    Parameters
+    ----------
+    x1 : Expr
+        The first input tensor.
+    x2 : Expr
+        The second input tensor.
+
+    Returns
+    -------
+    Expr
+        The element-wise log-sum-exp of `x1` and `x2`.
+    """
+    return _ffi_api.log_add_exp(x1, x2)
+
+
 def multiply(x1: Expr, x2: Expr) -> Expr:
     """Multiplication with numpy-style broadcasting.
 
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index 41e317f1e0..1acbddb219 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -44,6 +44,7 @@ def _binary(te_func: TEFunc) -> LegalizeFunc:
 register_legalize("relax.add", _binary(topi.add))
 register_legalize("relax.divide", _binary(topi.divide))
 register_legalize("relax.floor_divide", _binary(topi.floor_divide))
+register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
 register_legalize("relax.multiply", _binary(topi.multiply))
 register_legalize("relax.power", _binary(topi.power))
 register_legalize("relax.subtract", _binary(topi.subtract))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index ddc534cf60..6fa3cc61cb 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -112,6 +112,7 @@ from tvm.relax.op import (
     less_equal,
     linear,
     log,
+    log_add_exp,
     logical_and,
     logical_not,
     logical_or,
@@ -794,6 +795,7 @@ __all__ = [
     "less_equal",
     "linear",
     "log",
+    "log_add_exp",
     "logical_and",
     "logical_not",
     "logical_or",
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index b31853bea6..362419bebf 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -24,7 +24,7 @@ from tvm.tir import sinh, cosh, log2, log10
 from tvm.tir import asin, asinh, acos, acosh, atan, atanh
 from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
 from tvm.tir import isnan, isfinite, isinf
-from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
+from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, logaddexp
 from tvm.tir import comm_reducer, min, max, sum
 from tvm.tir import add, subtract, multiply
 
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 4f56ec3c15..5ceb481270 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -90,7 +90,7 @@ from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor
 from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
 from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else
 from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
+from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
 from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 53c92fff86..3770a8be5f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3221,6 +3221,28 @@ def floordiv(a, b, span=None):
     return _ffi_api._OpFloorDiv(a, b, span)  # type: ignore
 
 
+def logaddexp(a, b, span=None):
+    """Compute the logaddexp of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand
+
+    b : PrimExpr
+        The right hand operand
+
+    span : Optional[Span]
+        The location of this operator in the source.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpLogAddExp(a, b, span)  # type: ignore
+
+
 def floormod(a, b, span=None):
     """Compute the floormod of two expressions.
 
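
A minimal usage sketch of the new TIR binding (variable names are illustrative):

    from tvm import tir

    a = tir.Var("a", "float32")
    b = tir.Var("b", "float32")
    c = tir.logaddexp(a, b)  # builds the PrimExpr log(exp(a) + exp(b))
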
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index 2b350ff817..e2982ecfc2 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -135,6 +135,25 @@ def floor_divide(lhs, rhs):
     return _cpp.floor_divide(lhs, rhs)
 
 
+def log_add_exp(lhs, rhs):
+    """Log-sum-exp operation with auto-broadcasting.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The first input tensor or expression.
+    rhs : tvm.te.Tensor or Expr
+        The second input tensor or expression.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns an Expr if both operands are Expr.
+        Otherwise, returns a Tensor.
+    """
+    return _cpp.log_add_exp(lhs, rhs)
+
+
 def mod(lhs, rhs):
     """Modulus with auto-broadcasting
 
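
A short te-level sketch of the auto-broadcasting behavior (shapes are illustrative):

    from tvm import te, topi

    A = te.placeholder((4, 3), name="A", dtype="float32")
    B = te.placeholder((3,), name="B", dtype="float32")
    C = topi.log_add_exp(A, B)  # B is broadcast along the leading axis of A
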
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 4a63993d50..e7fab8f166 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -193,6 +193,7 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b66eb96f84..6b106f760d 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -70,6 +70,9 @@ Expr divide(Expr x1, Expr x2);
 /*! \brief Floor division with numpy-style broadcasting. */
 Expr floor_divide(Expr x1, Expr x2);
 
+/*! \brief Log Add Exponent with numpy-style broadcasting. */
+Expr log_add_exp(Expr x1, Expr x2);
+
 /*! \brief Multiplication with numpy-style broadcasting. */
 Expr multiply(Expr x1, Expr x2);
 
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 46c15cb3df..47aecf4809 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -507,6 +507,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
   return tir::FloorDiv(a, b, span);
 }
 
+PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) {
+  ICHECK(a.dtype().is_float()) << a;
+  ICHECK(b.dtype().is_float()) << b;
+  BinaryOpMatchTypes(a, b, span);
+  PrimExpr exp_sum = add(exp(a), exp(b));
+  PrimExpr log_exp_sum = log(exp_sum);
+  return log_exp_sum;
+}
+
 PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
   ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
@@ -1134,6 +1143,7 @@ REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
 REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
 REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
 REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpLogAddExp, logaddexp);
 REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
 REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
 REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
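
For reference, the op-level decomposition above expands directly to log(exp(a) + exp(b)); the max-shifted identity commonly used for numerical stability is mathematically equivalent, as a quick NumPy check (not part of this commit) shows:

    import numpy as np

    a, b = 3.0, 4.0
    direct = np.log(np.exp(a) + np.exp(b))
    # overflow-safe rewriting: log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|))
    shifted = max(a, b) + np.log1p(np.exp(-abs(a - b)))
    assert np.isclose(direct, shifted)
    assert np.isclose(direct, np.logaddexp(a, b))
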
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index f6a28c7722..2105172aed 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -52,6 +52,7 @@ TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
 TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
 TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
 TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
+TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp);
 TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
 TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
 TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 284544be50..26d3d3f7bd 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -585,6 +585,32 @@ def test_leakyrelu():
     verify_model(LeakyReLU1(), example_args, {}, expected)
 
 
+def test_logaddexp():
+    class LogAddExp(Module):
+        def forward(self, input1, input2):
+            return torch.logaddexp(input1, input2)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            input_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log_add_exp(input_1, input_2)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(LogAddExp(), example_args, {}, expected)
+
+
 def test_logsoftmax():
     class LogSoftmax(Module):
         def __init__(self):
