This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c00f52a70d [Relax][PyTorch] Add Stack Op Support for Exported Program  (#17819)
c00f52a70d is described below

commit c00f52a70d92f1d0b61a819efa3c906d00b17f7e
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Fri Apr 18 07:33:12 2025 +0530

    [Relax][PyTorch] Add Stack Op Support for Exported Program  (#17819)
    
    * add op support for stack
    
    * trailing whitespace issue fixed
    
    * fixed lint issues
    
    * fixed whitespace issue
    
    * fixed lint error
    
    * fixing lint issues
    
    * fixed whitespace issue
    
    * add test script for fx_graph
    
    * fix lint issues
    
    * fixed unity check issues
    
    * unity check
    
    * fixed unity check issues
    
    * lint issues
---
 include/tvm/relax/attrs/manipulate.h               |  13 ++
 .../frontend/torch/base_fx_graph_translator.py     |  16 +-
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  24 +++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 .../tvm/relax/transform/legalize_ops/manipulate.py |  22 +++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 python/tvm/topi/transform.py                       |  16 +-
 src/contrib/msc/framework/torch/torch_opcode.cc    |   8 +
 src/relax/op/tensor/manipulate.cc                  | 209 +++++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   |   8 +-
 tests/python/contrib/test_msc/test_graph_build.py  | 134 ++++++-------
 .../relax/test_frontend_from_exported_program.py   |  64 +++++++
 tests/python/relax/test_frontend_from_fx.py        |   7 +-
 14 files changed, 435 insertions(+), 94 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index e6c16d233a..67f99d9b41 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -119,6 +119,19 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   }
 };  // struct SqueezeAttrs
 
+/*! \brief Attributes used in stack operators */
+struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(StackAttrs, "relax.attrs.StackAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis along which to stack the input tensors. "
+        "The axis will be inserted at this position in the output, "
+        "so it must be in range [-ndim-1, ndim] where ndim is the "
+        "number of dimensions of the input tensors.");
+  }
+};  // struct StackAttrs
+
 /*! \brief Attributes used in repeat operators */
 struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
   int repeats;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a9bee11fc8..7a971c00cd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1194,21 +1194,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     def _stack(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        tensor_list = args[0]
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        in_args = args[0]
-        assert all(
-            a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
-        ), "Expect all dim at {} to be the same, get {}".format(
-            axis, [a.struct_info.shape for a in args]
-        )
-        cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
-        s_shape = []
-        for idx, s in enumerate(cat.struct_info.shape):
-            if idx == axis:
-                s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
-            else:
-                s_shape.append(s)
-        return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+        return self.block_builder.emit(relax.op.stack(tensor_list, axis=axis))
 
     def _take(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index ddfdfc2b05..3145a7c292 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -104,6 +104,7 @@ from .manipulate import (
     scatter_nd,
     split,
     squeeze,
+    stack,
     tile,
 )
 from .mask import masked_fill
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0f6e537ab3..725e58bd01 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -279,6 +279,30 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
     return _ffi_api.squeeze(x, axis)  # type: ignore
 
 
+def stack(tensors: Union[Expr, List[Expr]], axis: int = 0) -> Expr:
+    """Stack the input tensors along a new axis.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing the tensors to be stacked,
+        or a list of Tensors. All input tensors must have the same shape.
+
+    axis : int
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
+
+    Returns
+    -------
+    result: relax.Expr
+        The stacked tensor with an additional dimension compared to the input tensors.
+
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.stack(tensors, axis)  # type: ignore
+
+
 def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
     """Return a summation of data to the shape of collapse_target.
 
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 4658950f51..fda4258a09 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class SqueezeAttrs(Attrs):
     """Attributes for squeeze operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.StackAttrs")
+class StackAttrs(Attrs):
+    """Attributes for stack operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
 class LayoutTransformAttrs(Attrs):
     """Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 662d4e946b..a481d7af95 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -118,6 +118,28 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
 
 
+@register_legalize("relax.stack")
+def _stack(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[0]
+    n_field = len(t.struct_info.fields)
+
+    # Follow bindings to find the actual tuple
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))
+
+    # Extract fields from either Tuple or bound Var
+    fields = (
+        t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    )
+
+    return bb.call_te(topi.stack, fields, 0 if call.attrs.axis is None else call.attrs.axis.value)
+
+
 @register_legalize("relax.repeat")
 def _repeat(bb: BlockBuilder, call: Call) -> Expr:
     def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 6fa3cc61cb..79b1884aac 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -158,6 +158,7 @@ from tvm.relax.op import (
     sqrt,
     square,
     squeeze,
+    stack,
     std,
     strided_slice,
     subtract,
@@ -851,6 +852,7 @@ __all__ = [
     "square",
     "squeeze",
     "sqrt",
+    "stack",
     "stop_lift_params",
     "str",
     "strided_slice",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index b8605aa58a..37743e97a3 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -403,23 +403,25 @@ def concatenate(a_tuple, axis=0):
     return cpp.concatenate(a_tuple, axis)
 
 
-def stack(a, axis):
-    """Repeats the whole array multiple times.
+def stack(tensors, axis=0):
+    """Join a sequence of tensors along a new axis.
 
     Parameters
     ----------
-    a : tvm.te.Tensor
-        The tensor to be stacked.
+    tensors : tuple or list of tvm.te.Tensor
+        The tensors to be stacked. All tensors must have the same shape.
 
     axis : int, optional
-        The axis in the result array along which the input arrays are stacked.
-
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
 
     Returns
     -------
     ret : tvm.te.Tensor
+        The stacked tensor with an additional dimension compared to the input tensors.
+
     """
-    return cpp.stack(a, axis)
+    return cpp.stack(tensors, axis)
 
 
 def split(ary, indices_or_sections, axis=0):
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index f5784efe3d..9e3652f041 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -209,6 +209,13 @@ class TorchConcatCodeGen : public TorchOpCode {
   void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
 };
 
+class TorchStackCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen);
+
+ protected:
+  void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
+};
+
 class TorchConstantCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen);
 
@@ -789,6 +796,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
               std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
+  map->emplace("stack", std::make_shared<TorchStackCodeGen>("", "torch.stack"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
   map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index cb738db363..4abfe01387 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1193,6 +1193,215 @@ void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
   }
 }
 
+/* relax.stack */
+TVM_REGISTER_NODE_TYPE(StackAttrs);
+
+Expr stack(Expr tensors, Optional<Integer> axis) {
+  ObjectPtr<StackAttrs> attrs = make_object<StackAttrs>();
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.stack");
+  return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack);
+
+Optional<Array<PrimExpr>> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx,
+                                                const std::vector<Array<PrimExpr>>& shape_values,
+                                                int axis) {
+  bool shape_unknown = false;
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+
+  // Stack requires all input tensors to have identical shapes
+  for (int d = 0; d < static_cast<int>(shape_values[0].size()); ++d) {
+    for (int i = 1; i < static_cast<int>(shape_values.size()); ++i) {
+      if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "Stack expects all input tensors to have identical shapes. "
+                         << "Dimension " << d << " differs between tensors: " << shape_values[0][d]
+                         << " vs " << shape_values[i][d]);
+      } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) {
+        shape_unknown = true;
+      }
+    }
+  }
+
+  if (shape_unknown) {
+    return NullOpt;
+  }
+
+  // Insert new dimension at axis position
+  Array<PrimExpr> output_shape;
+  for (int i = 0; i < axis; ++i) {
+    output_shape.push_back(shape_values[0][i]);
+  }
+  output_shape.push_back(IntImm(DataType::Int(64), shape_values.size()));  // Stack dimension
+  for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+    output_shape.push_back(shape_values[0][i]);
+  }
+  return output_shape;
+}
+
+StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument");
+  }
+
+  Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
+  if (tensor_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Stack op expects at least one tensor in the input Tuple. "
+                     << "However, the given input Tuple is empty.");
+  }
+
+  const auto* attrs = call->attrs.as<StackAttrs>();
+  ICHECK(attrs != nullptr) << "Stack must have StackAttrs";
+
+  // Accumulate output dtype, rank, virtual device and shape information from the inputs
+  int output_ndim = tensor_sinfo[0]->ndim + 1;  // Stack adds one dimension
+  DataType output_dtype = DataType::Void();
+  Optional<VDevice> vdev = NullOpt;
+  bool shape_unknown = false;
+  bool is_void_dtype = false;
+  bool vdevice_unknown = false;
+  std::vector<Array<PrimExpr>> shape_values;
+  shape_values.reserve(tensor_sinfo.size());
+
+  for (TensorStructInfo sinfo : tensor_sinfo) {
+    // Check dtype consistency
+    if (sinfo->dtype.is_void()) {
+      is_void_dtype = true;
+    } else if (output_dtype.is_void()) {
+      output_dtype = sinfo->dtype;
+    } else if (sinfo->dtype != output_dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Stack expects all input tensors to have the same dtype. "
+                       << "Found " << output_dtype << " and " << sinfo->dtype);
+    }
+
+    // Check ndim consistency
+    if (sinfo->ndim != kUnknownNDim && sinfo->ndim != tensor_sinfo[0]->ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Stack expects all input tensors to have same ndim. "
+                       << "Found " << tensor_sinfo[0]->ndim << " and " << sinfo->ndim);
+    }
+
+    // Check virtual device consistency
+    if (!vdevice_unknown) {
+      if (sinfo->vdevice.defined()) {
+        if (!vdev.defined()) {
+          vdev = sinfo->vdevice.value();
+        } else if (sinfo->vdevice.value() != vdev) {
+          vdevice_unknown = true;
+        }
+      }
+    }
+
+    // Collect shape information
+    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+    if (shape_expr != nullptr) {
+      shape_values.push_back(shape_expr->values);
+      continue;
+    }
+    shape_unknown = true;
+
+    if (!sinfo->shape.defined()) continue;
+    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
+    if (shape_sinfo->values.defined()) {
+      shape_values.push_back(shape_sinfo->values.value());
+    }
+  }
+
+  if (is_void_dtype) output_dtype = DataType::Void();
+  if (vdevice_unknown) vdev = NullOpt;
+
+  // Normalize axis (default to 0 if not specified)
+  int axis =
+      attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0;
+
+  // Single tensor case
+  if (tensor_sinfo.size() == 1) {
+    if (shape_values.empty()) {
+      if (!vdevice_unknown) {
+        return TensorStructInfo(output_dtype, output_ndim, vdev);
+      }
+      return TensorStructInfo(output_dtype, output_ndim);
+    }
+    Array<PrimExpr> output_shape;
+    for (int i = 0; i < axis; ++i) {
+      output_shape.push_back(shape_values[0][i]);
+    }
+    output_shape.push_back(1);  // Stack size 1
+    for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+      output_shape.push_back(shape_values[0][i]);
+    }
+    if (!vdevice_unknown) {
+      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
+    }
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+  }
+
+  // Multiple tensors case
+  if (shape_values.empty()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
+    return TensorStructInfo(output_dtype, output_ndim);
+  }
+
+  Optional<Array<PrimExpr>> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis);
+  if (shape_unknown || !output_shape.defined()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
+    return TensorStructInfo(output_dtype, output_ndim);
+  } else {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
+    }
+    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
+  }
+}
+
+InferLayoutOutput InferLayoutStack(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<StackAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  NLayout nlayout = GetNLayout(var_layout_map, call->args[0]);
+  ICHECK(nlayout.IsNested());
+  ICHECK(nlayout.NestedArray()[0].IsLeaf());
+
+  int n_tensor = nlayout.NestedArray().size();
+  LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+  Array<NLayout> input_layouts, output_layouts;
+  for (int i = 0; i < n_tensor; ++i) {
+    input_layouts.push_back(layout);
+  }
+
+  // For stack, we need to adjust the output layout by inserting a new axis
+  std::string layout_str = layout->layout.name();
+  int axis = attrs->axis.defined() ? attrs->axis.value()->value : 0;
+  layout_str.insert(static_cast<size_t>(axis), "S");  // Add stack dimension
+  Layout output_layout = Layout(layout_str);
+  output_layouts.push_back(LayoutDecision(output_layout));
+
+  ObjectPtr<StackAttrs> new_attrs = make_object<StackAttrs>(*attrs);
+  new_attrs->axis = Integer(FindAxis(layout->layout, axis));
+  return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.stack")
+    .set_attrs_type<StackAttrs>()
+    .set_num_inputs(1)
+    .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to stack")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStack)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStack)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.collapse_sum_like */
 Expr collapse_sum_like(Expr data, Expr collapse_target) {
   static const Op& op = Op::Get("relax.collapse_sum_like");
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1a0c7ddbc7..7e5de217bc 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -117,7 +117,13 @@ Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis)
  * \return The squeezed result.
  */
 Expr squeeze(Expr x, Optional<Array<Integer>> axis);
-
+/*!
+ * \brief Stack tensors along the specified axis.
+ * \param tensors The input tensors to be stacked.
+ * \param axis The axis along which the tensors will be stacked.
+ * \return The stacked result.
+ */
+Expr stack(Expr tensors, Optional<Integer> axis);
 /*!
  * \brief Return a summation of data to the shape of collapse_target.
  * For details, please see the operator `relax.collapse_sum_to`.
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 5396b5e106..328fbf456e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -37,7 +37,7 @@ def verify_model(torch_model, input_info, expected):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_conv1d(dynamic):
+def test_conv1d(dynamic: bool):
     """test graph builder for conv1d"""
 
     class Conv1D1(Module):
@@ -77,7 +77,7 @@ def test_conv1d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_conv2d(dynamic):
+def test_conv2d(dynamic: bool):
     """test graph builder for conv2d"""
 
     class Conv2D1(Module):
@@ -130,7 +130,7 @@ def test_conv2d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_linear(dynamic):
+def test_linear(dynamic: bool):
     """test graph builder for linear"""
 
     class Dense1(Module):
@@ -201,7 +201,7 @@ def test_linear(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_bmm(dynamic):
+def test_bmm(dynamic: bool):
     """test graph builder for bmm"""
 
     class BMM(Module):
@@ -227,7 +227,7 @@ def test_bmm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_baddbmm(dynamic):
+def test_baddbmm(dynamic: bool):
     """test graph builder for baddbmm"""
 
     class BAddBMM1(Module):
@@ -273,7 +273,7 @@ def test_baddbmm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_relu(dynamic):
+def test_relu(dynamic: bool):
     """test graph builder for relu"""
 
     class ReLU(Module):
@@ -303,7 +303,7 @@ def test_relu(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_relu6(dynamic):
+def test_relu6(dynamic: bool):
     """test graph builder for relu6"""
 
     class ReLU6(Module):
@@ -328,7 +328,7 @@ def test_relu6(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_maxpool2d(dynamic):
+def test_maxpool2d(dynamic: bool):
     """test graph builder for maxpool2d"""
 
     class MaxPool2d(Module):
@@ -395,7 +395,7 @@ def test_maxpool2d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_avgpool2d(dynamic):
+def test_avgpool2d(dynamic: bool):
     """test graph builder for avgpool2d"""
 
     class AvgPool2d(Module):
@@ -443,7 +443,7 @@ def test_avgpool2d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_adaptive_avgpool2d(dynamic):
+def test_adaptive_avgpool2d(dynamic: bool):
     """test graph builder for adaptive_avgpool2d"""
 
     class AdaptiveAvgPool2d0(Module):
@@ -477,7 +477,7 @@ def test_adaptive_avgpool2d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_flatten(dynamic):
+def test_flatten(dynamic: bool):
     """test graph builder for flatten"""
 
     class Flatten(Module):
@@ -507,7 +507,7 @@ def test_flatten(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_batchnorm2d(dynamic):
+def test_batchnorm2d(dynamic: bool):
     """test graph builder for batchnorm2d"""
 
     class BatchNorm2d(Module):
@@ -541,7 +541,7 @@ def test_batchnorm2d(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_embedding(dynamic):
+def test_embedding(dynamic: bool):
     """test graph builder for embedding"""
 
     class Embedding(Module):
@@ -579,7 +579,7 @@ def test_embedding(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_dropout(dynamic):
+def test_dropout(dynamic: bool):
     """test graph builder for dropout"""
 
     class Dropout1(Module):
@@ -609,7 +609,7 @@ def test_dropout(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_layernorm(dynamic):
+def test_layernorm(dynamic: bool):
     """test graph builder for layernorm"""
 
     class LayerNorm(Module):
@@ -638,7 +638,7 @@ def test_layernorm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_functional_layernorm(dynamic):
+def test_functional_layernorm(dynamic: bool):
     """test graph builder for functional_layernorm"""
 
     class LayerNorm(Module):
@@ -670,7 +670,7 @@ def test_functional_layernorm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_cross_entropy(dynamic):
+def test_cross_entropy(dynamic: bool):
     """test graph builder for cross_entropy"""
 
     class CrossEntropy1(Module):
@@ -735,7 +735,7 @@ def test_cross_entropy(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_functional_cross_entropy(dynamic):
+def test_functional_cross_entropy(dynamic: bool):
     """test graph builder for functional_cross_entropy"""
 
     class CrossEntropy(Module):
@@ -759,7 +759,7 @@ def test_functional_cross_entropy(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_silu(dynamic):
+def test_silu(dynamic: bool):
     """test graph builder for silu"""
 
     class SiLU(Module):
@@ -793,7 +793,7 @@ def test_silu(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_groupnorm(dynamic):
+def test_groupnorm(dynamic: bool):
     """test graph builder for groupnorm"""
 
     class GroupNorm(Module):
@@ -822,7 +822,7 @@ def test_groupnorm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_softmax(dynamic):
+def test_softmax(dynamic: bool):
     """test graph builder for softmax"""
 
     class Softmax(Module):
@@ -851,7 +851,7 @@ def test_softmax(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_binary(dynamic):
+def test_binary(dynamic: bool):
     """test graph builder for binary"""
 
     bz = "bz" if dynamic else 1
@@ -1111,7 +1111,7 @@ def test_binary(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_size(dynamic):
+def test_size(dynamic: bool):
     """test graph builder for size"""
 
     class Size(Module):
@@ -1132,7 +1132,7 @@ def test_size(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_squeeze(dynamic):
+def test_squeeze(dynamic: bool):
     """test graph builder for squeeze"""
 
     class Squeeze1(Module):
@@ -1173,7 +1173,7 @@ def test_squeeze(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_unsqueeze(dynamic):
+def test_unsqueeze(dynamic: bool):
     """test graph builder for unsqueeze"""
 
     class Unsqueeze1(Module):
@@ -1223,7 +1223,7 @@ def test_unsqueeze(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_getattr(dynamic):
+def test_getattr(dynamic: bool):
     """test graph builder for getattr"""
 
     class GetAttr1(Module):
@@ -1244,7 +1244,7 @@ def test_getattr(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_getitem(dynamic):
+def test_getitem(dynamic: bool):
     """test graph builder for getitem"""
 
     class Slice1(Module):
@@ -1286,7 +1286,7 @@ def test_getitem(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_unary(dynamic):
+def test_unary(dynamic: bool):
     """test graph builder for unary"""
 
     bz = "bz" if dynamic else 1
@@ -1408,7 +1408,7 @@ def test_unary(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_gelu(dynamic):
+def test_gelu(dynamic: bool):
     """test graph builder for gelu"""
 
     class Gelu(Module):
@@ -1433,7 +1433,7 @@ def test_gelu(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_tanh(dynamic):
+def test_tanh(dynamic: bool):
     """test graph builder for tanh"""
 
     class Tanh(Module):
@@ -1458,7 +1458,7 @@ def test_tanh(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_clamp(dynamic):
+def test_clamp(dynamic: bool):
     """test graph builder for clamp"""
 
     class Clamp(Module):
@@ -1479,7 +1479,7 @@ def test_clamp(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_interpolate(dynamic):
+def test_interpolate(dynamic: bool):
     """test graph builder for interpolate"""
 
     class Interpolate(Module):
@@ -1504,7 +1504,7 @@ def test_interpolate(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_addmm(dynamic):
+def test_addmm(dynamic: bool):
     """test graph builder for addmm"""
 
     class Addmm(Module):
@@ -1531,7 +1531,7 @@ def test_addmm(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_split(dynamic):
+def test_split(dynamic: bool):
     """test graph builder for split"""
 
     class Split1(Module):
@@ -1574,7 +1574,7 @@ def test_split(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_unbind(dynamic):
+def test_unbind(dynamic: bool):
     """test graph builder for unbind"""
 
     class Unbind(Module):
@@ -1601,7 +1601,7 @@ def test_unbind(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_cumsum(dynamic):
+def test_cumsum(dynamic: bool):
     """test graph builder for cumsum"""
 
     class Cumsum(Module):
@@ -1622,7 +1622,7 @@ def test_cumsum(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_chunk(dynamic):
+def test_chunk(dynamic: bool):
     """test graph builder for chunk"""
 
     class Chunk(Module):
@@ -1649,7 +1649,7 @@ def test_chunk(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_inplace_fill(dynamic):
+def test_inplace_fill(dynamic: bool):
     """test graph builder for inplace_fill"""
 
     class InplaceFill(Module):
@@ -1734,7 +1734,7 @@ def test_tensor():
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_tril(dynamic):
+def test_tril(dynamic: bool):
     """test graph builder for tril"""
 
     class Tril(Module):
@@ -1762,7 +1762,7 @@ def test_tril(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_triu(dynamic):
+def test_triu(dynamic: bool):
     """test graph builder for triu"""
 
     class Triu(Module):
@@ -1807,7 +1807,7 @@ def test_new_ones():
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_expand(dynamic):
+def test_expand(dynamic: bool):
     """test graph builder for expand"""
 
     class Expand1(Module):
@@ -1835,7 +1835,7 @@ def test_expand(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_reduce(dynamic):
+def test_reduce(dynamic: bool):
     """test graph builder for reduce"""
 
     # sum
@@ -1857,7 +1857,7 @@ def test_reduce(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_datatype(dynamic):
+def test_datatype(dynamic: bool):
     """test graph builder for datatype"""
 
     bz = "bz" if dynamic else 1
@@ -1948,7 +1948,7 @@ def test_datatype(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_permute(dynamic):
+def test_permute(dynamic: bool):
     """test graph builder for permute"""
 
     class Permute(Module):
@@ -1979,7 +1979,7 @@ def test_permute(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_reshape(dynamic):
+def test_reshape(dynamic: bool):
     """test graph builder for reshape"""
 
     class Reshape(Module):
@@ -2007,7 +2007,7 @@ def test_reshape(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_transpose(dynamic):
+def test_transpose(dynamic: bool):
     """test graph builder for transpose"""
 
     class Transpose(Module):
@@ -2038,7 +2038,7 @@ def test_transpose(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_view(dynamic):
+def test_view(dynamic: bool):
     """test graph builder for view"""
 
     class View(Module):
@@ -2066,7 +2066,7 @@ def test_view(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_keep_params(dynamic):
+def test_keep_params(dynamic: bool):
     """test graph builder for keep_params"""
 
     class Conv2D1(Module):
@@ -2099,7 +2099,7 @@ def test_keep_params(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_unwrap_unit_return_tuple(dynamic):
+def test_unwrap_unit_return_tuple(dynamic: bool):
     """test graph builder for unwrap_unit_return_tuple"""
 
     class Identity(Module):
@@ -2119,7 +2119,7 @@ def test_unwrap_unit_return_tuple(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_no_bind_return_tuple(dynamic):
+def test_no_bind_return_tuple(dynamic: bool):
     """test graph builder for no_bind_return_tuple"""
 
     class Identity(Module):
@@ -2147,7 +2147,7 @@ def test_no_bind_return_tuple(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_argmax(dynamic):
+def test_argmax(dynamic: bool):
     """test graph builder for argmax"""
 
     class Argmax1(Module):
@@ -2178,7 +2178,7 @@ def test_argmax(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_argmin(dynamic):
+def test_argmin(dynamic: bool):
     """test graph builder for argmin"""
 
     class Argmin1(Module):
@@ -2209,7 +2209,7 @@ def test_argmin(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_to(dynamic):
+def test_to(dynamic: bool):
     """test graph builder for to"""
 
     class To1(Module):
@@ -2240,7 +2240,7 @@ def test_to(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_mean(dynamic):
+def test_mean(dynamic: bool):
     """test graph builder for mean"""
 
     class Mean(Module):
@@ -2271,7 +2271,7 @@ def test_mean(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_rsqrt(dynamic):
+def test_rsqrt(dynamic: bool):
     """test graph builder for rsqrt"""
 
     class Rsqrt(Module):
@@ -2291,7 +2291,7 @@ def test_rsqrt(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_neg(dynamic):
+def test_neg(dynamic: bool):
     """test graph builder for neg"""
 
     class Neg(Module):
@@ -2311,7 +2311,7 @@ def test_neg(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_max(dynamic):
+def test_max(dynamic: bool):
     """test graph builder for max"""
 
     class Max(Module):
@@ -2334,7 +2334,7 @@ def test_max(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_cat(dynamic):
+def test_cat(dynamic: bool):
     """test graph builder for cat"""
 
     class Cat1(Module):
@@ -2385,8 +2385,8 @@ def test_cat(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_stack(dynamic):
-    """test graph builder for stack"""
+def test_stack(dynamic: bool):
+    """Test graph builder for stack."""
 
     bz = "bz" if dynamic else 1
 
@@ -2408,23 +2408,23 @@ def test_stack(dynamic):
         ],
         "outputs": [
             {
-                "name": "reshape",
+                "name": "stack",
                 "shape": [3, bz, 3, 10, 10],
                 "dtype": "float32",
-                "layout": "" if dynamic else "EABCD",
+                "layout": "SABCD",
             }
         ],
-        "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1},
+        "nodes": {"total": 4, "input": 3, "stack": 1},
     }
 
     if dynamic:
-        expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
+        expected["prims"] = {"total": 1, "shape": 1}
 
     verify_model(Stack(), input_info, expected)
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_scatter(dynamic):
+def test_scatter(dynamic: bool):
     """test graph builder for scatter"""
 
     bz = "bz" if dynamic else 20
@@ -2473,7 +2473,7 @@ def test_scatter(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_masked_scatter(dynamic):
+def test_masked_scatter(dynamic: bool):
     """test graph builder for masked_scatter"""
 
     dim = "dim" if dynamic else 5
@@ -2558,7 +2558,7 @@ def test_masked_scatter(dynamic):
 
 
 @pytest.mark.parametrize("dynamic", [True, False])
-def test_attention(dynamic):
+def test_attention(dynamic: bool):
     """test graph builder for attention"""
 
     # pylint: disable=import-outside-toplevel
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7c47832ea9..8db9684999 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3060,6 +3060,70 @@ def test_squeeze():
     verify_model(Squeeze2(), example_args, {}, Expected2)
 
 
+def test_stack():
+    class Stack0(Module):
+        def forward(self, x, y):
+            return torch.stack((x, y))  # default dim=0
+
+    class Stack1(Module):
+        def forward(self, x, y):
+            return torch.stack((x, y), dim=1)
+
+    class Stack2(Module):
+        def forward(self, x, y):
+            return torch.stack((x, y), 1)  # positional dim
+
+    class Stack3(Module):
+        def forward(self, x, y):
+            return torch.stack((x, y), dim=-1)  # negative dim
+
+    @I.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1)
+                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1)
+                gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+
+    verify_model(Stack0(), example_args, {}, Expected0)
+    verify_model(Stack1(), example_args, {}, Expected1)
+    verify_model(Stack2(), example_args, {}, Expected1)
+    verify_model(Stack3(), example_args, {}, Expected3)
+
+
 def test_tile():
     class Tile1(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a2169afd0f..4a2ca336e1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4004,13 +4004,10 @@ def test_stack():
             inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"):
             with R.dataflow():
-                lv: R.Tensor((3, 3, 10, 10), dtype="float32") = R.concat(
+                lv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.stack(
                     (inp_0, inp_1, inp_2), axis=0
                 )
-                lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape(
-                    lv, R.shape([3, 1, 3, 10, 10])
-                )
-                gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1
+                gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv
                 R.output(gv)
             return gv
 


Reply via email to