This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f4704f2288 [Relax][PyTorch] Add torch.outer Op Support for Exported
Program and FX graph (#17930)
f4704f2288 is described below
commit f4704f2288b14c05e99de9f74bcb9b530c2dd7e6
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat May 10 12:23:34 2025 +0530
[Relax][PyTorch] Add torch.outer Op Support for Exported Program and FX
graph (#17930)
* add torch.outer op support into exported program and fx translator
* fix lint issues
* fix cpp lints
* update the format of input info in fx test script
---------
Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
.../frontend/torch/exported_program_translator.py | 3 ++
python/tvm/relax/frontend/torch/fx_translator.py | 3 ++
python/tvm/relax/op/__init__.py | 2 +-
python/tvm/relax/op/linear_algebra.py | 27 +++++++++++++++
.../relax/transform/legalize_ops/linear_algebra.py | 19 +++++++++++
python/tvm/script/ir_builder/relax/ir.py | 2 ++
src/relax/op/tensor/linear_algebra.cc | 38 ++++++++++++++++++++++
src/relax/op/tensor/linear_algebra.h | 8 +++++
.../relax/test_frontend_from_exported_program.py | 24 ++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 21 ++++++++++++
10 files changed, 146 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 87508c9fea..d69d5bcfa1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -404,6 +404,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
"ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
"ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
+ "outer.default": lambda node: self.block_builder.emit(
+ relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+ ),
"pow.Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 0e8814dd97..b2a1f5eae1 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -832,6 +832,9 @@ class TorchFXImporter(BaseFXGraphImporter):
"mod": self._binary_op(relax.op.floor_mod, operator.mod),
"mul": self._binary_op(relax.op.multiply, operator.mul),
"ne": self._binary_op(relax.op.not_equal, operator.ne),
+ "outer": lambda node: self.block_builder.emit(
+ relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+ ),
"pow": self._binary_op(relax.op.power, operator.pow),
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index bfc0a997df..0a2f0980fd 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -84,7 +84,7 @@ from .create import (
)
from .datatype import astype, wrap_param
from .index import dynamic_strided_slice, strided_slice, take
-from .linear_algebra import einsum, linear, matmul
+from .linear_algebra import einsum, linear, matmul, outer
from .manipulate import (
broadcast_to,
collapse_sum_like,
diff --git a/python/tvm/relax/op/linear_algebra.py
b/python/tvm/relax/op/linear_algebra.py
index efb5085c78..9b09119576 100644
--- a/python/tvm/relax/op/linear_algebra.py
+++ b/python/tvm/relax/op/linear_algebra.py
@@ -110,3 +110,30 @@ def einsum(operands, subscripts):
operands = RxTuple(operands)
return _ffi_api.einsum(operands, subscripts) # type: ignore
+
+
+def outer(x1: Expr, x2: Expr) -> Expr:
+    """Compute the outer product of two 1-D tensors.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor. Must be 1-D, with length ``n``.
+
+    x2 : relax.Expr
+        The second input tensor. Must be 1-D, with length ``m``.
+
+    Returns
+    -------
+    result : relax.Expr
+        A 2-D tensor of shape ``(n, m)`` with ``result[i, j] = x1[i] * x2[j]``.
+
+    Notes
+    -----
+    The PyTorch frontends lower ``torch.outer`` to this operator. Inputs whose
+    rank is known to differ from 1 are rejected during struct-info inference.
+    """
+    return _ffi_api.outer(x1, x2)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index 318c9521f3..154afa9dff 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -115,3 +115,22 @@ def _einsum(bb: BlockBuilder, call: Call) -> Expr:
t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for
i in range(n_field)]
)
return bb.call_te(topi.einsum, call.attrs.subscripts, *fields)
+
+
+@register_legalize("relax.outer")
+def _outer(bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.outer`` into a TE compute ``out[i, j] = lhs[i] * rhs[j]``."""
+
+    def te_outer(a: te.Tensor, b: te.Tensor) -> te.Tensor:
+        # Struct-info inference normally rejects non-1-D inputs already; re-check
+        # with an explicit exception (not ``assert``) so the guard survives ``python -O``.
+        if len(a.shape) != 1 or len(b.shape) != 1:
+            raise ValueError(
+                f"outer requires 1D tensors, got shapes {list(a.shape)} and {list(b.shape)}"
+            )
+        n, m = a.shape[0], b.shape[0]
+        return te.compute((n, m), lambda i, j: a[i] * b[j], name="outer")
+
+    lhs, rhs = call.args
+    return bb.call_te(te_outer, lhs, rhs, primfunc_name_hint="outer")
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index d1e86cc7f4..b696d73031 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@ from tvm.relax.op import (
ones,
ones_like,
one_hot,
+ outer,
permute_dims,
power,
print,
@@ -826,6 +827,7 @@ __all__ = [
"one_hot",
"opencl",
"output",
+ "outer",
"permute_dims",
"power",
"prim_value",
diff --git a/src/relax/op/tensor/linear_algebra.cc
b/src/relax/op/tensor/linear_algebra.cc
index 0fdbee1c6a..4ca42bffec 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -251,5 +251,43 @@ TVM_REGISTER_OP("relax.einsum")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEinsum)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.outer */
+
+Expr outer(Expr x1, Expr x2) {
+  static const Op& op = Op::Get("relax.outer");
+  return Call(op, {std::move(x1), std::move(x2)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer);
+
+/*!
+ * \brief Infer the struct info of relax.outer(x1, x2).
+ *
+ * Both operands must be 1-D; the result is 2-D with shape (len(x1), len(x2))
+ * when both input shapes are known, otherwise a 2-D tensor of unknown shape.
+ */
+StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) {
+  auto input_sinfo = GetInputTensorStructInfo(call, ctx);
+  auto x1_sinfo = input_sinfo[0];
+  auto x2_sinfo = input_sinfo[1];
+
+  // Reject operands whose rank is known and not 1. Operands of still-unknown
+  // rank are accepted here instead of being rejected because -1 != 1.
+  bool x1_bad_rank = !x1_sinfo->IsUnknownNdim() && x1_sinfo->ndim != 1;
+  bool x2_bad_rank = !x2_sinfo->IsUnknownNdim() && x2_sinfo->ndim != 1;
+  if (x1_bad_rank || x2_bad_rank) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "relax.outer requires both inputs to be 1D tensors, but got ndim "
+                     << x1_sinfo->ndim << " and " << x2_sinfo->ndim << ".");
+  }
+
+  // The elementwise product needs matching dtypes; report a mismatch here rather
+  // than silently adopting x1's dtype. If either dtype is unknown, so is the output's.
+  DataType out_dtype = DataType::Void();
+  if (!x1_sinfo->IsUnknownDtype() && !x2_sinfo->IsUnknownDtype()) {
+    if (x1_sinfo->dtype != x2_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "relax.outer requires both inputs to have the same dtype, but got "
+                       << x1_sinfo->dtype << " and " << x2_sinfo->dtype << ".");
+    }
+    out_dtype = x1_sinfo->dtype;
+  }
+
+  // Output shape is (len(x1), len(x2)) when both shapes are symbolic/static ShapeExprs.
+  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  if (x1_shape == nullptr || x2_shape == nullptr) {
+    return TensorStructInfo(out_dtype, /*ndim=*/2);
+  }
+  Array<PrimExpr> output_shape = {x1_shape->values[0], x2_shape->values[0]};
+  return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
+}
+
+TVM_REGISTER_OP("relax.outer")
+    .set_num_inputs(2)
+    .add_argument("x1", "Tensor", "The first input tensor.")
+    .add_argument("x2", "Tensor", "The second input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOuter)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.h
b/src/relax/op/tensor/linear_algebra.h
index 638e5af8f8..eb003fed1c 100644
--- a/src/relax/op/tensor/linear_algebra.h
+++ b/src/relax/op/tensor/linear_algebra.h
@@ -51,6 +51,14 @@ Expr matmul(Expr x1, Expr x2, Optional<DataType> out_dtype);
*/
Expr einsum(Expr operands, String subscripts);
+/*!
+ * \brief Compute the outer product of two 1-D tensors.
+ * \param x1 The first input tensor, expected to be 1-D with length n.
+ * \param x2 The second input tensor, expected to be 1-D with length m.
+ * \return A call to relax.outer whose result is a 2-D tensor of shape (n, m)
+ *         with result[i, j] = x1[i] * x2[j].
+ */
+Expr outer(Expr x1, Expr x2);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 75c745a213..4cb9e903a1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2423,6 +2423,30 @@ def test_einsum():
verify_model(Einsum2(), example_args, {}, Expected2)
+def test_outer():
+    """torch.outer on two 1-D tensors lowers to a single R.outer call."""
+
+    class Outer(Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(3, dtype=torch.float32),
+        torch.randn(4, dtype=torch.float32),
+    )
+    verify_model(Outer(), example_args, {}, Expected)
+
+
def test_embedding():
class Embedding(Module):
def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 48e12dfe49..7928975301 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -874,6 +874,27 @@ def test_einsum():
verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {},
Expected2)
+def test_outer():
+    """torch.outer on two 1-D tensors lowers to a single R.outer call (FX path)."""
+
+    class Outer(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_infos = [([3], "float32"), ([4], "float32")]
+    verify_model(Outer(), input_infos, {}, Expected)
+
+
@tvm.testing.requires_gpu
def test_softplus():
import torch