This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c00f52a70d [Relax][PyTorch] Add Stack Op Support for Exported Program (#17819)
c00f52a70d is described below
commit c00f52a70d92f1d0b61a819efa3c906d00b17f7e
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Fri Apr 18 07:33:12 2025 +0530
[Relax][PyTorch] Add Stack Op Support for Exported Program (#17819)
* add op support for stack
* trailing whitespace issue fixed
* fixed lint issues
* fixed whitespace issue
* fixed lint error
* fixing lint issues
* fixed whitespace issue
* add test script for fx_graph
* fix lint issues
* fixed unity check issues
* unity check
* fixed unity check issues
* lint issues
---
include/tvm/relax/attrs/manipulate.h | 13 ++
.../frontend/torch/base_fx_graph_translator.py | 16 +-
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 24 +++
python/tvm/relax/op/op_attrs.py | 5 +
.../tvm/relax/transform/legalize_ops/manipulate.py | 22 +++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/transform.py | 16 +-
src/contrib/msc/framework/torch/torch_opcode.cc | 8 +
src/relax/op/tensor/manipulate.cc | 209 +++++++++++++++++++++
src/relax/op/tensor/manipulate.h | 8 +-
tests/python/contrib/test_msc/test_graph_build.py | 134 ++++++-------
.../relax/test_frontend_from_exported_program.py | 64 +++++++
tests/python/relax/test_frontend_from_fx.py | 7 +-
14 files changed, 435 insertions(+), 94 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index e6c16d233a..67f99d9b41 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -119,6 +119,19 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}
}; // struct SqueezeAttrs
+/*! \brief Attributes used in stack operators */
+struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
+ Optional<Integer> axis;
+
+ TVM_DECLARE_ATTRS(StackAttrs, "relax.attrs.StackAttrs") {
+ TVM_ATTR_FIELD(axis).describe(
+ "The axis along which to stack the input tensors. "
+ "The axis will be inserted at this position in the output, "
+ "so it must be in range [-ndim-1, ndim] where ndim is the "
+ "number of dimensions of the input tensors.");
+ }
+}; // struct StackAttrs
+
/*! \brief Attributes used in repeat operators */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
int repeats;
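
The axis range documented above mirrors torch.stack, where the new axis is inserted into the output and may be negative. A quick illustration in plain PyTorch (not part of this change):

    import torch

    x = torch.randn(3, 4)              # ndim = 2, so dim must lie in [-3, 2]
    torch.stack((x, x), dim=0).shape   # torch.Size([2, 3, 4])
    torch.stack((x, x), dim=2).shape   # torch.Size([3, 4, 2])
    torch.stack((x, x), dim=-3).shape  # torch.Size([2, 3, 4]); -3 wraps to 0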
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a9bee11fc8..7a971c00cd 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1194,21 +1194,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def _stack(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
+ tensor_list = args[0]
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
- in_args = args[0]
- assert all(
- a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
- ), "Expect all dim at {} to be the same, get {}".format(
- axis, [a.struct_info.shape for a in args]
- )
- cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
- s_shape = []
- for idx, s in enumerate(cat.struct_info.shape):
- if idx == axis:
- s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
- else:
- s_shape.append(s)
- return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+ return self.block_builder.emit(relax.op.stack(tensor_list, axis=axis))
def _take(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
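
With this change the importer emits a single relax.op.stack call instead of lowering torch.stack to concat + reshape; shape agreement between the inputs is now enforced by the operator's struct-info inference rather than the removed assert. A minimal end-to-end sketch, assuming a recent TVM build with the exported-program frontend:

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.stack((x, y), dim=1)

    args = (torch.randn(2, 3), torch.randn(2, 3))
    mod = from_exported_program(export(M(), args))
    mod.show()  # main binds R.stack((x, y), axis=1) in place of concat + reshape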
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index ddfdfc2b05..3145a7c292 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -104,6 +104,7 @@ from .manipulate import (
scatter_nd,
split,
squeeze,
+ stack,
tile,
)
from .mask import masked_fill
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0f6e537ab3..725e58bd01 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -279,6 +279,30 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
return _ffi_api.squeeze(x, axis) # type: ignore
+def stack(tensors: Union[Expr, List[Expr]], axis: int = 0) -> Expr:
+ """Stack the input tensors along a new axis.
+
+ Parameters
+ ----------
+ tensors : Union[relax.Expr, List[relax.Expr]]
+ An Expr in Tuple type, containing the tensors to be stacked,
+ or a list of Tensors. All input tensors must have the same shape.
+
+ axis : int
+ The axis in the resulting tensor along which the input tensors will be stacked.
+ Negative values wrap around. Default is 0.
+
+ Returns
+ -------
+ result: relax.Expr
+ The stacked tensor with an additional dimension compared to the input tensors.
+
+ """
+ if isinstance(tensors, (list, tuple)):
+ tensors = RxTuple(tensors)
+ return _ffi_api.stack(tensors, axis) # type: ignore
+
+
def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
"""Return a summation of data to the shape of collapse_target.
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 4658950f51..fda4258a09 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class SqueezeAttrs(Attrs):
"""Attributes for squeeze operator"""
+@tvm._ffi.register_object("relax.attrs.StackAttrs")
+class StackAttrs(Attrs):
+ """Attributes for concat operator"""
+
+
@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 662d4e946b..a481d7af95 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -118,6 +118,28 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
+@register_legalize("relax.stack")
+def _stack(bb: BlockBuilder, call: Call) -> Expr:
+ t = call.args[0]
+ n_field = len(t.struct_info.fields)
+
+ # Follow bindings to find the actual tuple
+ while isinstance(t, Var):
+ binding = bb.lookup_binding(t)
+ if not isinstance(binding, (Tuple, Var)):
+ break
+ t = binding
+
+ assert isinstance(t, (Tuple, Var))
+
+ # Extract fields from either Tuple or bound Var
+ fields = (
+ t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+ )
+
+ return bb.call_te(topi.stack, fields, 0 if call.attrs.axis is None else call.attrs.axis.value)
+
+
@register_legalize("relax.repeat")
def _repeat(bb: BlockBuilder, call: Call) -> Expr:
def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
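
Legalization can be exercised directly; a sketch under the same assumptions as above:

    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo((4, 5), "float32"))
    y = relax.Var("y", relax.TensorStructInfo((4, 5), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x, y]):
        bb.emit_func_output(bb.emit(relax.op.stack([x, y], axis=1)))
    mod = relax.transform.LegalizeOps()(bb.get())
    # relax.stack is now a call_tir into a TE kernel generated from topi.stack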
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 6fa3cc61cb..79b1884aac 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -158,6 +158,7 @@ from tvm.relax.op import (
sqrt,
square,
squeeze,
+ stack,
std,
strided_slice,
subtract,
@@ -851,6 +852,7 @@ __all__ = [
"square",
"squeeze",
"sqrt",
+ "stack",
"stop_lift_params",
"str",
"strided_slice",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index b8605aa58a..37743e97a3 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -403,23 +403,25 @@ def concatenate(a_tuple, axis=0):
return cpp.concatenate(a_tuple, axis)
-def stack(a, axis):
- """Repeats the whole array multiple times.
+def stack(tensors, axis=0):
+ """Join a sequence of tensors along a new axis.
Parameters
----------
- a : tvm.te.Tensor
- The tensor to be stacked.
+ tensors : tuple or list of tvm.te.Tensor
+ The tensors to be stacked. All tensors must have the same shape.
axis : int, optional
- The axis in the result array along which the input arrays are stacked.
-
+ The axis in the resulting tensor along which the input tensors will be stacked.
+ Negative values wrap around. Default is 0.
Returns
-------
ret : tvm.te.Tensor
+ The stacked tensor with an additional dimension compared to the input tensors.
+
"""
- return cpp.stack(a, axis)
+ return cpp.stack(tensors, axis)
def split(ary, indices_or_sections, axis=0):
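
The TE-level entry point keeps the same cpp.stack implementation; only the signature and docstring change. A small sketch:

    from tvm import te, topi

    a = te.placeholder((2, 3), name="a")
    b = te.placeholder((2, 3), name="b")
    out = topi.stack([a, b], axis=1)
    print(out.shape)  # [2, 2, 3]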
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index f5784efe3d..9e3652f041 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -209,6 +209,13 @@ class TorchConcatCodeGen : public TorchOpCode {
void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
};
+class TorchStackCodeGen : public TorchOpCode {
+ TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen);
+
+ protected:
+ void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
+};
+
class TorchConstantCodeGen : public TorchOpCode {
TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen);
@@ -789,6 +796,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
map->emplace("split", std::make_shared<TorchSplitCodeGen>("",
"torch.split"));
+ map->emplace("stack", std::make_shared<TorchStackCodeGen>("",
"torch.stack"));
map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("",
""));
map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index cb738db363..4abfe01387 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1193,6 +1193,215 @@ void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
}
}
+/* relax.stack */
+TVM_REGISTER_NODE_TYPE(StackAttrs);
+
+Expr stack(Expr tensors, Optional<Integer> axis) {
+ ObjectPtr<StackAttrs> attrs = make_object<StackAttrs>();
+ attrs->axis = std::move(axis);
+
+ static const Op& op = Op::Get("relax.stack");
+ return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack);
+
+Optional<Array<PrimExpr>> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx,
+ const std::vector<Array<PrimExpr>>& shape_values,
+ int axis) {
+ bool shape_unknown = false;
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+
+ // Stack requires all input tensors to have identical shapes
+ for (int d = 0; d < static_cast<int>(shape_values[0].size()); ++d) {
+ for (int i = 1; i < static_cast<int>(shape_values.size()); ++i) {
+ if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Stack expects all input tensors to have identical
shapes. "
+ << "Dimension " << d << " differs between tensors: "
<< shape_values[0][d]
+ << " vs " << shape_values[i][d]);
+ } else if (!analyzer->CanProveEqual(shape_values[i][d],
shape_values[0][d])) {
+ shape_unknown = true;
+ }
+ }
+ }
+
+ if (shape_unknown) {
+ return NullOpt;
+ }
+
+ // Insert new dimension at axis position
+ Array<PrimExpr> output_shape;
+ for (int i = 0; i < axis; ++i) {
+ output_shape.push_back(shape_values[0][i]);
+ }
+ output_shape.push_back(IntImm(DataType::Int(64), shape_values.size()));  // Stack dimension
+ for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+ output_shape.push_back(shape_values[0][i]);
+ }
+ return output_shape;
+}
+
+StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument");
+ }
+
+ Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
+ if (tensor_sinfo.empty()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Stack op expects at least one tensor in the input
Tuple. "
+ << "However, the given input Tuple is empty.");
+ }
+
+ const auto* attrs = call->attrs.as<StackAttrs>();
+ ICHECK(attrs != nullptr) << "Stack must have StackAttrs";
+
+ // Default axis is 0 if not specified
+ int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension
+ DataType output_dtype = DataType::Void();
+ Optional<VDevice> vdev = NullOpt;
+ bool shape_unknown = false;
+ bool is_void_dtype = false;
+ bool vdevice_unknown = false;
+ std::vector<Array<PrimExpr>> shape_values;
+ shape_values.reserve(tensor_sinfo.size());
+
+ for (TensorStructInfo sinfo : tensor_sinfo) {
+ // Check dtype consistency
+ if (sinfo->dtype.is_void()) {
+ is_void_dtype = true;
+ } else if (output_dtype.is_void()) {
+ output_dtype = sinfo->dtype;
+ } else if (sinfo->dtype != output_dtype) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Stack expects all input tensors to have the same
dtype. "
+ << "Found " << output_dtype << " and " << sinfo->dtype);
+ }
+
+ // Check ndim consistency
+ if (sinfo->ndim != kUnknownNDim && sinfo->ndim != tensor_sinfo[0]->ndim) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Stack expects all input tensors to have same ndim. "
+ << "Found " << tensor_sinfo[0]->ndim << " and " <<
sinfo->ndim);
+ }
+
+ // Check virtual device consistency
+ if (!vdevice_unknown) {
+ if (sinfo->vdevice.defined()) {
+ if (!vdev.defined()) {
+ vdev = sinfo->vdevice.value();
+ } else if (sinfo->vdevice.value() != vdev) {
+ vdevice_unknown = true;
+ }
+ }
+ }
+
+ // Collect shape information
+ const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+ if (shape_expr != nullptr) {
+ shape_values.push_back(shape_expr->values);
+ continue;
+ }
+ shape_unknown = true;
+
+ if (!sinfo->shape.defined()) continue;
+ ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
+ if (shape_sinfo->values.defined()) {
+ shape_values.push_back(shape_sinfo->values.value());
+ }
+ }
+
+ if (is_void_dtype) output_dtype = DataType::Void();
+ if (vdevice_unknown) vdev = NullOpt;
+
+ // Normalize axis (default to 0 if not specified)
+ int axis =
+ attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0;
+
+ // Single tensor case
+ if (tensor_sinfo.size() == 1) {
+ if (shape_values.empty()) {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
+ return TensorStructInfo(output_dtype, output_ndim);
+ }
+ Array<PrimExpr> output_shape;
+ for (int i = 0; i < axis; ++i) {
+ output_shape.push_back(shape_values[0][i]);
+ }
+ output_shape.push_back(1); // Stack size 1
+ for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+ output_shape.push_back(shape_values[0][i]);
+ }
+ if (!vdevice_unknown) {
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
+ }
+ return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+ }
+
+ // Multiple tensors case
+ if (shape_values.empty()) {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
+ return TensorStructInfo(output_dtype, output_ndim);
+ }
+
+ Optional<Array<PrimExpr>> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis);
+ if (shape_unknown || !output_shape.defined()) {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(output_dtype, output_ndim, vdev);
+ }
+ return TensorStructInfo(output_dtype, output_ndim);
+ } else {
+ if (!vdevice_unknown) {
+ return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
+ }
+ return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
+ }
+}
+
+InferLayoutOutput InferLayoutStack(const Call& call,
+ const Map<String, Array<String>>& desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ const auto* attrs = call->attrs.as<StackAttrs>();
+ ICHECK(attrs != nullptr) << "Invalid Call";
+ NLayout nlayout = GetNLayout(var_layout_map, call->args[0]);
+ ICHECK(nlayout.IsNested());
+ ICHECK(nlayout.NestedArray()[0].IsLeaf());
+
+ int n_tensor = nlayout.NestedArray().size();
+ LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+ Array<NLayout> input_layouts, output_layouts;
+ for (int i = 0; i < n_tensor; ++i) {
+ input_layouts.push_back(layout);
+ }
+
+ // For stack, we need to adjust the output layout by inserting a new axis
+ std::string layout_str = layout->layout.name();
+ int axis = attrs->axis.defined() ? attrs->axis.value()->value : 0;
+ layout_str.insert(static_cast<size_t>(axis), "S"); // Add stack dimension
+ Layout output_layout = Layout(layout_str);
+ output_layouts.push_back(LayoutDecision(output_layout));
+
+ ObjectPtr<StackAttrs> new_attrs = make_object<StackAttrs>(*attrs);
+ new_attrs->axis = Integer(FindAxis(layout->layout, axis));
+ return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.stack")
+ .set_attrs_type<StackAttrs>()
+ .set_num_inputs(1)
+ .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to
stack")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStack)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStack)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.collapse_sum_like */
Expr collapse_sum_like(Expr data, Expr collapse_target) {
static const Op& op = Op::Get("relax.collapse_sum_like");
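
Note the single-tensor fast path in InferStructInfoStack above: stacking one tensor still inserts a new axis of extent 1. A TVMScript sketch of the inferred struct info, assuming a build containing this commit:

    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Example:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 1, 3), "float32"):
            # one input, axis=1 -> output shape (2, 1, 3)
            return R.stack((x,), axis=1)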
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1a0c7ddbc7..7e5de217bc 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -117,7 +117,13 @@ Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis)
* \return The squeezed result.
*/
Expr squeeze(Expr x, Optional<Array<Integer>> axis);
-
+/*!
+ * \brief Stack tensors along the specified axis.
+ * \param tensors The input tensors to be stacked.
+ * \param axis The axis along which the tensors will be stacked.
+ * \return The stacked result.
+ */
+Expr stack(Expr tensors, Optional<Integer> axis);
/*!
* \brief Return a summation of data to the shape of collapse_target.
* For details, please see the operator `relax.collapse_sum_to`.
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 5396b5e106..328fbf456e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -37,7 +37,7 @@ def verify_model(torch_model, input_info, expected):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_conv1d(dynamic):
+def test_conv1d(dynamic: bool):
"""test graph builder for conv1d"""
class Conv1D1(Module):
@@ -77,7 +77,7 @@ def test_conv1d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_conv2d(dynamic):
+def test_conv2d(dynamic: bool):
"""test graph builder for conv2d"""
class Conv2D1(Module):
@@ -130,7 +130,7 @@ def test_conv2d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_linear(dynamic):
+def test_linear(dynamic: bool):
"""test graph builder for linear"""
class Dense1(Module):
@@ -201,7 +201,7 @@ def test_linear(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_bmm(dynamic):
+def test_bmm(dynamic: bool):
"""test graph builder for bmm"""
class BMM(Module):
@@ -227,7 +227,7 @@ def test_bmm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_baddbmm(dynamic):
+def test_baddbmm(dynamic: bool):
"""test graph builder for baddbmm"""
class BAddBMM1(Module):
@@ -273,7 +273,7 @@ def test_baddbmm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_relu(dynamic):
+def test_relu(dynamic: bool):
"""test graph builder for relu"""
class ReLU(Module):
@@ -303,7 +303,7 @@ def test_relu(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_relu6(dynamic):
+def test_relu6(dynamic: bool):
"""test graph builder for relu6"""
class ReLU6(Module):
@@ -328,7 +328,7 @@ def test_relu6(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_maxpool2d(dynamic):
+def test_maxpool2d(dynamic: bool):
"""test graph builder for maxpool2d"""
class MaxPool2d(Module):
@@ -395,7 +395,7 @@ def test_maxpool2d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_avgpool2d(dynamic):
+def test_avgpool2d(dynamic: bool):
"""test graph builder for avgpool2d"""
class AvgPool2d(Module):
@@ -443,7 +443,7 @@ def test_avgpool2d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_adaptive_avgpool2d(dynamic):
+def test_adaptive_avgpool2d(dynamic: bool):
"""test graph builder for adaptive_avgpool2d"""
class AdaptiveAvgPool2d0(Module):
@@ -477,7 +477,7 @@ def test_adaptive_avgpool2d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_flatten(dynamic):
+def test_flatten(dynamic: bool):
"""test graph builder for flatten"""
class Flatten(Module):
@@ -507,7 +507,7 @@ def test_flatten(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_batchnorm2d(dynamic):
+def test_batchnorm2d(dynamic: bool):
"""test graph builder for batchnorm2d"""
class BatchNorm2d(Module):
@@ -541,7 +541,7 @@ def test_batchnorm2d(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_embedding(dynamic):
+def test_embedding(dynamic: bool):
"""test graph builder for embedding"""
class Embedding(Module):
@@ -579,7 +579,7 @@ def test_embedding(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_dropout(dynamic):
+def test_dropout(dynamic: bool):
"""test graph builder for dropout"""
class Dropout1(Module):
@@ -609,7 +609,7 @@ def test_dropout(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_layernorm(dynamic):
+def test_layernorm(dynamic: bool):
"""test graph builder for layernorm"""
class LayerNorm(Module):
@@ -638,7 +638,7 @@ def test_layernorm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_functional_layernorm(dynamic):
+def test_functional_layernorm(dynamic: bool):
"""test graph builder for functional_layernorm"""
class LayerNorm(Module):
@@ -670,7 +670,7 @@ def test_functional_layernorm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_cross_entropy(dynamic):
+def test_cross_entropy(dynamic: bool):
"""test graph builder for cross_entropy"""
class CrossEntropy1(Module):
@@ -735,7 +735,7 @@ def test_cross_entropy(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_functional_cross_entropy(dynamic):
+def test_functional_cross_entropy(dynamic: bool):
"""test graph builder for functional_cross_entropy"""
class CrossEntropy(Module):
@@ -759,7 +759,7 @@ def test_functional_cross_entropy(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_silu(dynamic):
+def test_silu(dynamic: bool):
"""test graph builder for silu"""
class SiLU(Module):
@@ -793,7 +793,7 @@ def test_silu(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_groupnorm(dynamic):
+def test_groupnorm(dynamic: bool):
"""test graph builder for groupnorm"""
class GroupNorm(Module):
@@ -822,7 +822,7 @@ def test_groupnorm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_softmax(dynamic):
+def test_softmax(dynamic: bool):
"""test graph builder for softmax"""
class Softmax(Module):
@@ -851,7 +851,7 @@ def test_softmax(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_binary(dynamic):
+def test_binary(dynamic: bool):
"""test graph builder for binary"""
bz = "bz" if dynamic else 1
@@ -1111,7 +1111,7 @@ def test_binary(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_size(dynamic):
+def test_size(dynamic: bool):
"""test graph builder for size"""
class Size(Module):
@@ -1132,7 +1132,7 @@ def test_size(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_squeeze(dynamic):
+def test_squeeze(dynamic: bool):
"""test graph builder for squeeze"""
class Squeeze1(Module):
@@ -1173,7 +1173,7 @@ def test_squeeze(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_unsqueeze(dynamic):
+def test_unsqueeze(dynamic: bool):
"""test graph builder for unsqueeze"""
class Unsqueeze1(Module):
@@ -1223,7 +1223,7 @@ def test_unsqueeze(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_getattr(dynamic):
+def test_getattr(dynamic: bool):
"""test graph builder for getattr"""
class GetAttr1(Module):
@@ -1244,7 +1244,7 @@ def test_getattr(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_getitem(dynamic):
+def test_getitem(dynamic: bool):
"""test graph builder for getitem"""
class Slice1(Module):
@@ -1286,7 +1286,7 @@ def test_getitem(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_unary(dynamic):
+def test_unary(dynamic: bool):
"""test graph builder for unary"""
bz = "bz" if dynamic else 1
@@ -1408,7 +1408,7 @@ def test_unary(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_gelu(dynamic):
+def test_gelu(dynamic: bool):
"""test graph builder for gelu"""
class Gelu(Module):
@@ -1433,7 +1433,7 @@ def test_gelu(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_tanh(dynamic):
+def test_tanh(dynamic: bool):
"""test graph builder for tanh"""
class Tanh(Module):
@@ -1458,7 +1458,7 @@ def test_tanh(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_clamp(dynamic):
+def test_clamp(dynamic: bool):
"""test graph builder for clamp"""
class Clamp(Module):
@@ -1479,7 +1479,7 @@ def test_clamp(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_interpolate(dynamic):
+def test_interpolate(dynamic: bool):
"""test graph builder for interpolate"""
class Interpolate(Module):
@@ -1504,7 +1504,7 @@ def test_interpolate(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_addmm(dynamic):
+def test_addmm(dynamic: bool):
"""test graph builder for addmm"""
class Addmm(Module):
@@ -1531,7 +1531,7 @@ def test_addmm(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_split(dynamic):
+def test_split(dynamic: bool):
"""test graph builder for split"""
class Split1(Module):
@@ -1574,7 +1574,7 @@ def test_split(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_unbind(dynamic):
+def test_unbind(dynamic: bool):
"""test graph builder for unbind"""
class Unbind(Module):
@@ -1601,7 +1601,7 @@ def test_unbind(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_cumsum(dynamic):
+def test_cumsum(dynamic: bool):
"""test graph builder for cumsum"""
class Cumsum(Module):
@@ -1622,7 +1622,7 @@ def test_cumsum(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_chunk(dynamic):
+def test_chunk(dynamic: bool):
"""test graph builder for chunk"""
class Chunk(Module):
@@ -1649,7 +1649,7 @@ def test_chunk(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_inplace_fill(dynamic):
+def test_inplace_fill(dynamic: bool):
"""test graph builder for inplace_fill"""
class InplaceFill(Module):
@@ -1734,7 +1734,7 @@ def test_tensor():
@pytest.mark.parametrize("dynamic", [True, False])
-def test_tril(dynamic):
+def test_tril(dynamic: bool):
"""test graph builder for tril"""
class Tril(Module):
@@ -1762,7 +1762,7 @@ def test_tril(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_triu(dynamic):
+def test_triu(dynamic: bool):
"""test graph builder for triu"""
class Triu(Module):
@@ -1807,7 +1807,7 @@ def test_new_ones():
@pytest.mark.parametrize("dynamic", [True, False])
-def test_expand(dynamic):
+def test_expand(dynamic: bool):
"""test graph builder for expand"""
class Expand1(Module):
@@ -1835,7 +1835,7 @@ def test_expand(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_reduce(dynamic):
+def test_reduce(dynamic: bool):
"""test graph builder for reduce"""
# sum
@@ -1857,7 +1857,7 @@ def test_reduce(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_datatype(dynamic):
+def test_datatype(dynamic: bool):
"""test graph builder for datatype"""
bz = "bz" if dynamic else 1
@@ -1948,7 +1948,7 @@ def test_datatype(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_permute(dynamic):
+def test_permute(dynamic: bool):
"""test graph builder for permute"""
class Permute(Module):
@@ -1979,7 +1979,7 @@ def test_permute(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_reshape(dynamic):
+def test_reshape(dynamic: bool):
"""test graph builder for reshape"""
class Reshape(Module):
@@ -2007,7 +2007,7 @@ def test_reshape(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_transpose(dynamic):
+def test_transpose(dynamic: bool):
"""test graph builder for transpose"""
class Transpose(Module):
@@ -2038,7 +2038,7 @@ def test_transpose(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_view(dynamic):
+def test_view(dynamic: bool):
"""test graph builder for view"""
class View(Module):
@@ -2066,7 +2066,7 @@ def test_view(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_keep_params(dynamic):
+def test_keep_params(dynamic: bool):
"""test graph builder for keep_params"""
class Conv2D1(Module):
@@ -2099,7 +2099,7 @@ def test_keep_params(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_unwrap_unit_return_tuple(dynamic):
+def test_unwrap_unit_return_tuple(dynamic: bool):
"""test graph builder for unwrap_unit_return_tuple"""
class Identity(Module):
@@ -2119,7 +2119,7 @@ def test_unwrap_unit_return_tuple(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_no_bind_return_tuple(dynamic):
+def test_no_bind_return_tuple(dynamic: bool):
"""test graph builder for no_bind_return_tuple"""
class Identity(Module):
@@ -2147,7 +2147,7 @@ def test_no_bind_return_tuple(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_argmax(dynamic):
+def test_argmax(dynamic: bool):
"""test graph builder for argmax"""
class Argmax1(Module):
@@ -2178,7 +2178,7 @@ def test_argmax(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_argmin(dynamic):
+def test_argmin(dynamic: bool):
"""test graph builder for argmin"""
class Argmin1(Module):
@@ -2209,7 +2209,7 @@ def test_argmin(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_to(dynamic):
+def test_to(dynamic: bool):
"""test graph builder for to"""
class To1(Module):
@@ -2240,7 +2240,7 @@ def test_to(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_mean(dynamic):
+def test_mean(dynamic: bool):
"""test graph builder for mean"""
class Mean(Module):
@@ -2271,7 +2271,7 @@ def test_mean(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_rsqrt(dynamic):
+def test_rsqrt(dynamic: bool):
"""test graph builder for rsqrt"""
class Rsqrt(Module):
@@ -2291,7 +2291,7 @@ def test_rsqrt(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_neg(dynamic):
+def test_neg(dynamic: bool):
"""test graph builder for neg"""
class Neg(Module):
@@ -2311,7 +2311,7 @@ def test_neg(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_max(dynamic):
+def test_max(dynamic: bool):
"""test graph builder for max"""
class Max(Module):
@@ -2334,7 +2334,7 @@ def test_max(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_cat(dynamic):
+def test_cat(dynamic: bool):
"""test graph builder for cat"""
class Cat1(Module):
@@ -2385,8 +2385,8 @@ def test_cat(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_stack(dynamic):
- """test graph builder for stack"""
+def test_stack(dynamic: bool):
+ """Test graph builder for stack."""
bz = "bz" if dynamic else 1
@@ -2408,23 +2408,23 @@ def test_stack(dynamic):
],
"outputs": [
{
- "name": "reshape",
+ "name": "stack",
"shape": [3, bz, 3, 10, 10],
"dtype": "float32",
- "layout": "" if dynamic else "EABCD",
+ "layout": "SABCD",
}
],
- "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1},
+ "nodes": {"total": 4, "input": 3, "stack": 1},
}
if dynamic:
- expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1}
+ expected["prims"] = {"total": 1, "shape": 1}
verify_model(Stack(), input_info, expected)
@pytest.mark.parametrize("dynamic", [True, False])
-def test_scatter(dynamic):
+def test_scatter(dynamic: bool):
"""test graph builder for scatter"""
bz = "bz" if dynamic else 20
@@ -2473,7 +2473,7 @@ def test_scatter(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_masked_scatter(dynamic):
+def test_masked_scatter(dynamic: bool):
"""test graph builder for masked_scatter"""
dim = "dim" if dynamic else 5
@@ -2558,7 +2558,7 @@ def test_masked_scatter(dynamic):
@pytest.mark.parametrize("dynamic", [True, False])
-def test_attention(dynamic):
+def test_attention(dynamic: bool):
"""test graph builder for attention"""
# pylint: disable=import-outside-toplevel
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7c47832ea9..8db9684999 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3060,6 +3060,70 @@ def test_squeeze():
verify_model(Squeeze2(), example_args, {}, Expected2)
+def test_stack():
+ class Stack0(Module):
+ def forward(self, x, y):
+ return torch.stack((x, y)) # default dim=0
+
+ class Stack1(Module):
+ def forward(self, x, y):
+ return torch.stack((x, y), dim=1)
+
+ class Stack2(Module):
+ def forward(self, x, y):
+ return torch.stack((x, y), 1) # positional dim
+
+ class Stack3(Module):
+ def forward(self, x, y):
+ return torch.stack((x, y), dim=-1) # negative dim
+
+ @I.ir_module
+ class Expected0:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1)
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected3:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1)
+ gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+
+ verify_model(Stack0(), example_args, {}, Expected0)
+ verify_model(Stack1(), example_args, {}, Expected1)
+ verify_model(Stack2(), example_args, {}, Expected1)
+ verify_model(Stack3(), example_args, {}, Expected3)
+
+
def test_tile():
class Tile1(Module):
def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a2169afd0f..4a2ca336e1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4004,13 +4004,10 @@ def test_stack():
inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"):
with R.dataflow():
- lv: R.Tensor((3, 3, 10, 10), dtype="float32") = R.concat(
+ lv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.stack(
(inp_0, inp_1, inp_2), axis=0
)
- lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape(
- lv, R.shape([3, 1, 3, 10, 10])
- )
- gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1
+ gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv