This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1bb7833757 [Relax][PyTorch] Add Logaddexp op support for exported
program (#17803)
1bb7833757 is described below
commit 1bb7833757b82624cf9deff1ad6790c9354d1745
Author: AishwaryaElango <[email protected]>
AuthorDate: Thu Apr 17 12:53:16 2025 +0530
[Relax][PyTorch] Add Logaddexp op support for exported program (#17803)
* Add support for logaddexp core operator
* Add test script for logaddexp
* Add fix for lint issues
* Adjust trailing spaces
* Adjust leading whitespace
* Add fix for lint issues
* Add fix for logaddexp test script
* Fix lint issues
* decomposition at op level
* unity check
---------
Co-authored-by: Pratheesh <[email protected]>
---
include/tvm/tir/op.h | 11 +++++++++
include/tvm/topi/broadcast.h | 16 +++++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/binary.py | 19 ++++++++++++++++
python/tvm/relax/transform/legalize_ops/binary.py | 1 +
python/tvm/script/ir_builder/relax/ir.py | 2 ++
python/tvm/te/__init__.py | 2 +-
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/op.py | 22 ++++++++++++++++++
python/tvm/topi/broadcast.py | 19 ++++++++++++++++
src/relax/op/tensor/binary.cc | 1 +
src/relax/op/tensor/binary.h | 3 +++
src/tir/op/op.cc | 10 +++++++++
src/topi/broadcast.cc | 1 +
.../relax/test_frontend_from_exported_program.py | 26 ++++++++++++++++++++++
16 files changed, 135 insertions(+), 2 deletions(-)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index cfbd445295..ce7a425c94 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -394,6 +394,15 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span
span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span());
+/*!
+ * \brief Compute log(exp(a) + exp(b)).
+ *
+ * Both operands must have a floating-point dtype; they are brought to a
+ * common type before the result expression is built.
+ *
+ * \param a Left operand.
+ * \param b Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute ceil(a / b)
*
@@ -404,6 +413,6 @@ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span
= Span());
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span());
/*!
* \brief compute the remainder of floordiv
@@ -1071,6 +1081,7 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv);
+TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(logaddexp);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*)
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index d27b6f1a3c..9be7256b44 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -257,6 +257,22 @@ TOPI_DEFINE_BCAST_OP(floor_divide, {
}
});
+/*!
+ * \fn log_add_exp
+ * \brief Compute log(exp(A) + exp(B)) with auto-broadcasting.
+ *
+ * Each output element is computed by the scalar logaddexp(a, b) helper,
+ * so both inputs must have a floating-point dtype.
+ *
+ * \param A The first input tensor, or Expr.
+ * \param B The second input tensor, or Expr.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ *
+ * \return The element-wise log(exp(A) + exp(B)) result.
+ */
+TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); });
+
/*!
* \fn trunc divide
* \brief Compute trunc(A / B) with auto-broadcasting.
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 8f6418891b..c82a5e2b11 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -327,6 +327,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
"floor_divide.default": self._binary_op(relax.op.floor_divide,
operator.floordiv),
+ "logaddexp.default": self._binary_op(relax.op.log_add_exp,
torch.logaddexp),
"ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge),
"ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge),
"gt.Scalar": self._binary_op(relax.op.greater, operator.gt),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 97f18a2396..ddfdfc2b05 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -50,6 +50,7 @@ from .binary import (
divide,
equal,
floor_divide,
+ log_add_exp,
floor_mod,
greater,
greater_equal,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 7a41c8b095..d18aac8635 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -85,6 +85,25 @@ def floor_divide(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.floor_divide(x1, x2) # type: ignore
+def log_add_exp(x1: Expr, x2: Expr) -> Expr:
+    """Compute log(exp(x1) + exp(x2)) element-wise with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : Expr
+        The first input tensor.
+    x2 : Expr
+        The second input tensor.
+
+    Returns
+    -------
+    Expr
+        The element-wise log(exp(x1) + exp(x2)), broadcast to a common shape.
+    """
+    return _ffi_api.log_add_exp(x1, x2)  # type: ignore
+
+
def multiply(x1: Expr, x2: Expr) -> Expr:
"""Multiplication with numpy-style broadcasting.
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index 41e317f1e0..1acbddb219 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -44,6 +44,7 @@ def _binary(te_func: TEFunc) -> LegalizeFunc:
register_legalize("relax.add", _binary(topi.add))
register_legalize("relax.divide", _binary(topi.divide))
register_legalize("relax.floor_divide", _binary(topi.floor_divide))
+register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
register_legalize("relax.multiply", _binary(topi.multiply))
register_legalize("relax.power", _binary(topi.power))
register_legalize("relax.subtract", _binary(topi.subtract))
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index ddc534cf60..6fa3cc61cb 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -112,6 +112,7 @@ from tvm.relax.op import (
less_equal,
linear,
log,
+ log_add_exp,
logical_and,
logical_not,
logical_or,
@@ -794,6 +795,7 @@ __all__ = [
"less_equal",
"linear",
"log",
+ "log_add_exp",
"logical_and",
"logical_not",
"logical_or",
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index b31853bea6..362419bebf 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -24,7 +24,7 @@ from tvm.tir import sinh, cosh, log2, log10
from tvm.tir import asin, asinh, acos, acosh, atan, atanh
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod,
if_then_else
from tvm.tir import isnan, isfinite, isinf
-from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv,
floormod
+from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv,
floormod, logaddexp
from tvm.tir import comm_reducer, min, max, sum
from tvm.tir import add, subtract, multiply
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 4f56ec3c15..5ceb481270 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -90,7 +90,7 @@ from .op import bitwise_and, bitwise_not, bitwise_or,
bitwise_xor
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount,
fmod, if_then_else
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv,
floormod, ceildiv
+from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv,
floormod, ceildiv, logaddexp
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left,
shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 53c92fff86..3770a8be5f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -3221,6 +3221,28 @@ def floordiv(a, b, span=None):
return _ffi_api._OpFloorDiv(a, b, span) # type: ignore
+def logaddexp(a, b, span=None):
+    """Compute log(exp(a) + exp(b)) of two floating-point expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left hand operand. Must have a floating-point dtype.
+
+    b : PrimExpr
+        The right hand operand. Must have a floating-point dtype.
+
+    span : Optional[Span]
+        The location of this operator in the source.
+
+    Returns
+    -------
+    res : PrimExpr
+        The result expression.
+    """
+    return _ffi_api._OpLogAddExp(a, b, span) # type: ignore
+
+
def floormod(a, b, span=None):
"""Compute the floormod of two expressions.
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index 2b350ff817..e2982ecfc2 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -135,6 +135,25 @@ def floor_divide(lhs, rhs):
return _cpp.floor_divide(lhs, rhs)
+def log_add_exp(lhs, rhs):
+    """Compute log(exp(lhs) + exp(rhs)) with auto-broadcasting.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The first input tensor or expression.
+    rhs : tvm.te.Tensor or Expr
+        The second input tensor or expression.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns an Expr if both operands are Expr.
+        Otherwise, returns a Tensor.
+    """
+    return _cpp.log_add_exp(lhs, rhs)
+
+
def mod(lhs, rhs):
"""Modulus with auto-broadcasting
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 4a63993d50..e7fab8f166 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -193,6 +193,7 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b66eb96f84..6b106f760d 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -70,6 +70,9 @@ Expr divide(Expr x1, Expr x2);
/*! \brief Floor division with numpy-style broadcasting. */
Expr floor_divide(Expr x1, Expr x2);
+/*! \brief Compute log(exp(x1) + exp(x2)) with numpy-style broadcasting. */
+Expr log_add_exp(Expr x1, Expr x2);
+
/*! \brief Multiplication with numpy-style broadcasting. */
Expr multiply(Expr x1, Expr x2);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 46c15cb3df..47aecf4809 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -507,6 +507,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) {
return tir::FloorDiv(a, b, span);
}
+PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) {
+  ICHECK(a.dtype().is_float()) << a;
+  ICHECK(b.dtype().is_float()) << b;
+  BinaryOpMatchTypes(a, b, span);
+  // The naive log(exp(a) + exp(b)) overflows to inf once an input exceeds ~88 (float32)
+  // and underflows to -inf when both inputs are very negative. Use the equivalent
+  //   max(a, b) + log(1 + exp(-|a - b|)),
+  // where exp() only ever sees a non-positive argument (same form as torch.logaddexp).
+  PrimExpr one = make_const(a.dtype(), 1, span);
+  PrimExpr tail = log(add(one, exp(neg(abs(sub(a, b, span), span), span), span), span), span);
+  PrimExpr stable = add(max(a, b, span), tail, span);
+  // For a == b the exact result is a + log(2); selecting it explicitly also covers
+  // a == b == +/-inf, where a - b is NaN but the correct answer is a itself.
+  PrimExpr ln2 = make_const(a.dtype(), 0.6931471805599453, span);
+  return tir::Select(equal(a, b, span), add(a, ln2, span), stable, span);
+}
+
PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
@@ -1134,6 +1143,7 @@ REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpLogAddExp, logaddexp);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index f6a28c7722..2105172aed 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -52,6 +52,7 @@ TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
+TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp);
TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 284544be50..26d3d3f7bd 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -585,6 +585,32 @@ def test_leakyrelu():
verify_model(LeakyReLU1(), example_args, {}, expected)
+def test_logaddexp():
+    """torch.logaddexp should be imported as a single R.log_add_exp call."""
+    class LogAddExp(Module):
+        def forward(self, input1, input2):
+            return torch.logaddexp(input1, input2)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            input_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.log_add_exp(input_1, input_2)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+        torch.randn(1, 3, 10, 10, dtype=torch.float32),
+    )
+    verify_model(LogAddExp(), example_args, {}, expected)
+
+
def test_logsoftmax():
class LogSoftmax(Module):
def __init__(self):