This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4704f2288 [Relax][PyTorch] Add torch.outer Op Support for Exported Program and FX graph  (#17930)
f4704f2288 is described below

commit f4704f2288b14c05e99de9f74bcb9b530c2dd7e6
Author: Deivanayaki S <[email protected]>
AuthorDate: Sat May 10 12:23:34 2025 +0530

    [Relax][PyTorch] Add torch.outer Op Support for Exported Program and FX graph  (#17930)
    
    * add torch.outer op support into exported program and fx translator
    
    * fix lint issues
    
    * fix cpp lints
    
    * update the format of input info n fx test script
    
    ---------
    
    Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
---
 .../frontend/torch/exported_program_translator.py  |  3 ++
 python/tvm/relax/frontend/torch/fx_translator.py   |  3 ++
 python/tvm/relax/op/__init__.py                    |  2 +-
 python/tvm/relax/op/linear_algebra.py              | 27 +++++++++++++++
 .../relax/transform/legalize_ops/linear_algebra.py | 19 +++++++++++
 python/tvm/script/ir_builder/relax/ir.py           |  2 ++
 src/relax/op/tensor/linear_algebra.cc              | 38 ++++++++++++++++++++++
 src/relax/op/tensor/linear_algebra.h               |  8 +++++
 .../relax/test_frontend_from_exported_program.py   | 24 ++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 21 ++++++++++++
 10 files changed, 146 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 87508c9fea..d69d5bcfa1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -404,6 +404,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
             "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
             "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
+            "outer.default": lambda node: self.block_builder.emit(
+                relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+            ),
             "pow.Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0e8814dd97..b2a1f5eae1 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -832,6 +832,9 @@ class TorchFXImporter(BaseFXGraphImporter):
             "mod": self._binary_op(relax.op.floor_mod, operator.mod),
             "mul": self._binary_op(relax.op.multiply, operator.mul),
             "ne": self._binary_op(relax.op.not_equal, operator.ne),
+            "outer": lambda node: self.block_builder.emit(
+                relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+            ),
             "pow": self._binary_op(relax.op.power, operator.pow),
             "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
             "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index bfc0a997df..0a2f0980fd 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -84,7 +84,7 @@ from .create import (
 )
 from .datatype import astype, wrap_param
 from .index import dynamic_strided_slice, strided_slice, take
-from .linear_algebra import einsum, linear, matmul
+from .linear_algebra import einsum, linear, matmul, outer
 from .manipulate import (
     broadcast_to,
     collapse_sum_like,
diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py
index efb5085c78..9b09119576 100644
--- a/python/tvm/relax/op/linear_algebra.py
+++ b/python/tvm/relax/op/linear_algebra.py
@@ -110,3 +110,30 @@ def einsum(operands, subscripts):
         operands = RxTuple(operands)
 
     return _ffi_api.einsum(operands, subscripts)  # type: ignore
+
+
+def outer(x1: Expr, x2: Expr) -> Expr:
+    """
+    Computes the outer product of two input expressions.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input expression.
+
+    x2 : relax.Expr
+        The second input expression.
+
+    Notes
+    -----
+    This operation computes the outer product between two expressions,
+    resulting in a tensor where each element is the product of elements
+    from `x1` and `x2`. It is commonly used in tensor and matrix operations
+    to expand lower-dimensional inputs into higher-dimensional representations.
+
+    Returns
+    -------
+    result : relax.Expr
+        The resulting expression representing the outer product.
+    """
+    return _ffi_api.outer(x1, x2)
diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index 318c9521f3..154afa9dff 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -115,3 +115,22 @@ def _einsum(bb: BlockBuilder, call: Call) -> Expr:
         t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
     )
     return bb.call_te(topi.einsum, call.attrs.subscripts, *fields)
+
+
+@register_legalize("relax.outer")
+def _outer(bb: BlockBuilder, call: Call) -> Expr:
+    def te_outer(a: te.Tensor, b: te.Tensor) -> te.Tensor:
+        a_shape = list(a.shape)
+        b_shape = list(b.shape)
+        assert len(a_shape) == 1 and len(b_shape) == 1, "outer requires 1D tensors"
+
+        n = a_shape[0]
+        m = b_shape[0]
+
+        def compute_fn(i, j):
+            return a[i] * b[j]
+
+        return te.compute((n, m), compute_fn, name="outer")
+
+    lhs, rhs = call.args
+    return bb.call_te(te_outer, lhs, rhs, primfunc_name_hint="outer")
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index d1e86cc7f4..b696d73031 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@ from tvm.relax.op import (
     ones,
     ones_like,
     one_hot,
+    outer,
     permute_dims,
     power,
     print,
@@ -826,6 +827,7 @@ __all__ = [
     "one_hot",
     "opencl",
     "output",
+    "outer",
     "permute_dims",
     "power",
     "prim_value",
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
index 0fdbee1c6a..4ca42bffec 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -251,5 +251,43 @@ TVM_REGISTER_OP("relax.einsum")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEinsum)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.outer */
+
+Expr outer(Expr x1, Expr x2) {
+  static const Op& op = Op::Get("relax.outer");
+  return Call(op, {std::move(x1), std::move(x2)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer);
+
+StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) {
+  auto input_sinfo = GetInputTensorStructInfo(call, ctx);
+  auto x1_sinfo = input_sinfo[0];
+  auto x2_sinfo = input_sinfo[1];
+
+  // Ensure both inputs are 1D tensors
+  if (x1_sinfo->ndim != 1 || x2_sinfo->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "torch.outer requires both inputs to be 1D tensors.");
+  }
+
+  // Determine output shape
+  auto x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  auto x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  if (!x1_shape || !x2_shape) {
+    return TensorStructInfo(x1_sinfo->dtype, 2);
+  }
+  Array<PrimExpr> output_shape = {x1_shape->values[0], x2_shape->values[0]};
+  return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.outer")
+    .set_num_inputs(2)
+    .add_argument("x1", "Tensor", "The first input tensor.")
+    .add_argument("x2", "Tensor", "The second input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOuter)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h
index 638e5af8f8..eb003fed1c 100644
--- a/src/relax/op/tensor/linear_algebra.h
+++ b/src/relax/op/tensor/linear_algebra.h
@@ -51,6 +51,14 @@ Expr matmul(Expr x1, Expr x2, Optional<DataType> out_dtype);
  */
 Expr einsum(Expr operands, String subscripts);
 
+/*!
+ * \brief Compute the outer product of two input expressions.
+ * \param x1 The first input expression.
+ * \param x2 The second input expression.
+ * \return The resulting expression representing the outer product.
+ */
+Expr outer(Expr x1, Expr x2);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 75c745a213..4cb9e903a1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2423,6 +2423,30 @@ def test_einsum():
     verify_model(Einsum2(), example_args, {}, Expected2)
 
 
+def test_outer():
+    class Outer(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(3, dtype=torch.float32),
+        torch.randn(4, dtype=torch.float32),
+    )
+    verify_model(Outer(), example_args, {}, expected)
+
+
 def test_embedding():
     class Embedding(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 48e12dfe49..7928975301 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -874,6 +874,27 @@ def test_einsum():
     verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
 
 
+def test_outer():
+    class Outer(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_infos = [([3], "float32"), ([4], "float32")]
+    verify_model(Outer(), input_infos, {}, expected)
+
+
 @tvm.testing.requires_gpu
 def test_softplus():
     import torch

Reply via email to