This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 262c6d2e04 [Relax] Migrate NN conv/pooling/grad attrs from
Array<IntImm> to Array<int64_t> (#18733)
262c6d2e04 is described below
commit 262c6d2e041e6a5d5775049b32b166b526b49fa9
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Feb 11 22:38:33 2026 +0800
[Relax] Migrate NN conv/pooling/grad attrs from Array<IntImm> to
Array<int64_t> (#18733)
## Why
`has_attr` pattern matching failed for integer attributes because
StructuralEqual requires an exact dtype match for IntImm values, so
integer values given in a Python-specified pattern could fail to match
the operator's actual attributes whenever their dtypes differed.
## How
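- Migrate NN conv/pooling/grad attrs from `Array<IntImm>` to
`Array<int64_t>`, and update the op constructors, struct-info inference,
backend codegen/attribute serializers, frontends, and tests accordingly.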
---------
Signed-off-by: Guan-Ming Chiu <[email protected]>
---
include/tvm/relax/attrs/nn.h | 64 ++++----
python/tvm/relax/dpl/pattern.py | 4 +-
python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +-
python/tvm/relax/op/_op_gradient.py | 2 +-
src/contrib/msc/core/utils.cc | 4 +-
.../msc/framework/tensorrt/transform_tensorrt.cc | 13 +-
.../backend/contrib/codegen_json/codegen_json.h | 6 +-
src/relax/backend/contrib/nnapi/codegen.cc | 22 +--
src/relax/op/nn/convolution.cc | 109 ++++++-------
src/relax/op/nn/convolution.h | 38 ++---
src/relax/op/nn/pooling.cc | 174 +++++++++++----------
src/relax/op/nn/pooling.h | 10 +-
src/relax/op/op_common.h | 6 +-
src/relax/op/tensor/grad.cc | 24 +--
src/relax/op/tensor/grad.h | 12 +-
tests/python/relax/test_dataflow_pattern.py | 4 +-
tests/python/relax/test_op_nn_convolution.py | 50 +++---
tests/python/relax/test_op_nn_pooling.py | 88 +++++------
.../relax/test_transform_legalize_ops_grad.py | 6 +-
19 files changed, 319 insertions(+), 319 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 13a54a16b3..2a2ac5fe07 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -31,9 +31,9 @@ namespace relax {
/*! \brief Attributes used in Conv1d operator */
struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
int groups;
ffi::String data_layout;
ffi::String kernel_layout;
@@ -75,9 +75,9 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs>
{
/*! \brief Attributes used in Conv2d operator */
struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
int groups;
ffi::String data_layout;
ffi::String kernel_layout;
@@ -121,9 +121,9 @@ struct Conv2DAttrs : public
AttrsNodeReflAdapter<Conv2DAttrs> {
/*! \brief Attributes used in Conv3d operator */
struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
int groups;
ffi::String data_layout;
ffi::String kernel_layout;
@@ -169,10 +169,10 @@ struct Conv3DAttrs : public
AttrsNodeReflAdapter<Conv3DAttrs> {
/*! \brief Attributes used in Conv1DTranspose operator */
struct Conv1DTransposeAttrs : public
AttrsNodeReflAdapter<Conv1DTransposeAttrs> {
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> output_padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> output_padding;
+ ffi::Array<int64_t> dilation;
int groups;
ffi::String data_layout;
ffi::String kernel_layout;
@@ -218,10 +218,10 @@ struct Conv1DTransposeAttrs : public
AttrsNodeReflAdapter<Conv1DTransposeAttrs>
/*! \brief Attributes used in Conv2d operator */
struct Conv2DTransposeAttrs : public
AttrsNodeReflAdapter<Conv2DTransposeAttrs> {
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> output_padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> output_padding;
+ ffi::Array<int64_t> dilation;
int groups;
ffi::String data_layout;
ffi::String kernel_layout;
@@ -269,10 +269,10 @@ struct Conv2DTransposeAttrs : public
AttrsNodeReflAdapter<Conv2DTransposeAttrs>
/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
- ffi::Array<IntImm> pool_size;
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> pool_size;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
bool ceil_mode;
bool count_include_pad;
ffi::String layout;
@@ -310,10 +310,10 @@ struct Pool1DAttrs : public
AttrsNodeReflAdapter<Pool1DAttrs> {
/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
- ffi::Array<IntImm> pool_size;
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> pool_size;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
bool ceil_mode;
bool count_include_pad;
ffi::String layout;
@@ -353,10 +353,10 @@ struct Pool2DAttrs : public
AttrsNodeReflAdapter<Pool2DAttrs> {
/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
- ffi::Array<IntImm> pool_size;
- ffi::Array<IntImm> strides;
- ffi::Array<IntImm> padding;
- ffi::Array<IntImm> dilation;
+ ffi::Array<int64_t> pool_size;
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> dilation;
bool ceil_mode;
bool count_include_pad;
ffi::String layout;
@@ -396,7 +396,7 @@ struct Pool3DAttrs : public
AttrsNodeReflAdapter<Pool3DAttrs> {
/*! \brief Attributes for 1d adaptive pool operator */
struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
- ffi::Optional<ffi::Array<IntImm>> output_size;
+ ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -421,7 +421,7 @@ struct AdaptivePool1DAttrs : public
AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
- ffi::Optional<ffi::Array<IntImm>> output_size;
+ ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -446,7 +446,7 @@ struct AdaptivePool2DAttrs : public
AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
/*! \brief Attributes for 3d adaptive pool operator */
struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
- ffi::Optional<ffi::Array<IntImm>> output_size;
+ ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index ef7516f31f..1a08b66f2a 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -656,8 +656,8 @@ class PrimArrPattern(DFPattern):
@register_df_node
class AttrPattern(DFPattern):
- """Get match an expression with a certain attributes.
- Currently only supports Op Attributes, not call Attributes.
+ """Match an expression with certain attributes.
+ Supports Op attributes, Call attributes, and Function attributes.
Parameters
----------
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 61ab45d308..5f7c2ab752 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2567,7 +2567,7 @@ class Pool(OnnxOpConverter):
pads = []
if cls.name == "avg_pool":
for axis in range(len(input_shape) - 2):
- axis_shape = input_shape[2 + axis]
+ axis_shape = int(input_shape[2 + axis])
stride = strides[axis]
kernel = kernel_shape[axis]
pad = cls.get_pad_pair(axis_shape, kernel, stride,
auto_pad)
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index fd80f1e313..bf267473f8 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -1202,7 +1202,7 @@ kernel_layout, out_layout, out_dtype)`
out_h = (grad_h - 1) * stride_h - pad_top - pad_bottom + filter_h
out_w = (grad_w - 1) * stride_w - pad_left - pad_right + filter_w
- output_padding = (in_h - out_h, in_w - out_w)
+ output_padding = (int(in_h - out_h), int(in_w - out_w))
data_grad = conv2d_transpose( # type: ignore
output_grad,
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index bc70c809af..b917bc47a2 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -275,11 +275,13 @@ const ffi::String StringUtils::ToString(const ffi::Any&
obj) {
obj_string = *opt_str;
} else if (const auto* n = obj.as<IntImmNode>()) {
obj_string = std::to_string(n->value);
+ } else if (obj.type_index() == kTVMFFIInt) {
+ obj_string = std::to_string(obj.cast<int64_t>());
} else if (const auto* n = obj.as<FloatImmNode>()) {
obj_string = std::to_string(n->value);
} else if (const auto* n = obj.as<ffi::ArrayObj>()) {
for (size_t i = 0; i < n->size(); i++) {
- obj_string = obj_string + ToString((*n)[i].cast<ObjectRef>());
+ obj_string = obj_string + ToString((*n)[i]);
if (n->size() == 1 || i < n->size() - 1) {
obj_string = obj_string + ",";
}
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index e3579ec7ef..8b20d59574 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -429,10 +429,9 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var,
const Call& src_call,
// change to conv2d
static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
auto conv_attrs = ffi::make_object<Conv2DAttrs>();
- conv_attrs->strides = ffi::Array<IntImm>{src_attrs->strides[0],
Integer(1)};
- conv_attrs->padding =
- ffi::Array<IntImm>{Integer(0), src_attrs->padding[0], Integer(0),
src_attrs->padding[1]};
- conv_attrs->dilation = ffi::Array<IntImm>{src_attrs->dilation[0],
Integer(1)};
+ conv_attrs->strides = ffi::Array<int64_t>{src_attrs->strides[0], 1};
+ conv_attrs->padding = ffi::Array<int64_t>{0, src_attrs->padding[0], 0,
src_attrs->padding[1]};
+ conv_attrs->dilation = ffi::Array<int64_t>{src_attrs->dilation[0], 1};
conv_attrs->groups = src_attrs->groups;
conv_attrs->data_layout = "NCHW";
conv_attrs->kernel_layout = "OIHW";
@@ -706,9 +705,9 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var,
const Call& src_call,
// to conv2d
static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
auto conv_attrs = ffi::make_object<Conv2DAttrs>();
- conv_attrs->strides = ffi::Array<IntImm>{Integer(1), Integer(1)};
- conv_attrs->padding = ffi::Array<IntImm>{Integer(0), Integer(0),
Integer(0), Integer(0)};
- conv_attrs->dilation = ffi::Array<IntImm>{Integer(1), Integer(1)};
+ conv_attrs->strides = ffi::Array<int64_t>{1, 1};
+ conv_attrs->padding = ffi::Array<int64_t>{0, 0, 0, 0};
+ conv_attrs->dilation = ffi::Array<int64_t>{1, 1};
conv_attrs->groups = 1;
conv_attrs->data_layout = "NCHW";
conv_attrs->kernel_layout = "OIHW";
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h
b/src/relax/backend/contrib/codegen_json/codegen_json.h
index 5056962542..6076de75bc 100644
--- a/src/relax/backend/contrib/codegen_json/codegen_json.h
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -115,8 +115,12 @@ class OpAttrExtractor {
if (const auto* an = (*value).as<ffi::ArrayObj>()) {
std::vector<std::string> attr;
for (size_t i = 0; i < an->size(); ++i) {
- if (const auto* im = (*an)[i].as<IntImmNode>()) {
+ if (auto opt_int = (*an)[i].try_cast<int64_t>()) {
+ attr.push_back(std::to_string(opt_int.value()));
+ } else if (const auto* im = (*an)[i].as<IntImmNode>()) {
attr.push_back(std::to_string(im->value));
+ } else if (auto opt_float = (*an)[i].try_cast<double>()) {
+ attr.push_back(Fp2String(opt_float.value()));
} else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
attr.push_back(Fp2String(fm->value));
} else if (auto opt_str = (*an)[i].as<ffi::String>()) {
diff --git a/src/relax/backend/contrib/nnapi/codegen.cc
b/src/relax/backend/contrib/nnapi/codegen.cc
index 92933ba070..0ea05b9863 100644
--- a/src/relax/backend/contrib/nnapi/codegen.cc
+++ b/src/relax/backend/contrib/nnapi/codegen.cc
@@ -107,10 +107,7 @@ class CollectFromCompositeFunctionBody : public
ExprVisitor {
std::vector<std::string> strides;
if (!conv2d_attr->strides.empty()) {
for (auto stride : conv2d_attr->strides) {
- const auto* stride_val = stride.as<IntImmNode>();
- ICHECK(stride_val) << "convertion failed";
-
- strides.push_back(std::to_string(stride_val->value));
+ strides.push_back(std::to_string(stride));
}
} else {
strides = {"1", "1"};
@@ -118,9 +115,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor
{
std::vector<std::string> padding;
for (auto pad : conv2d_attr->padding) {
- const auto* padding_val = pad.as<IntImmNode>();
-
- padding.push_back(std::to_string(padding_val->value));
+ padding.push_back(std::to_string(pad));
}
std::vector<std::string> groups;
@@ -147,10 +142,7 @@ class CollectFromCompositeFunctionBody : public
ExprVisitor {
std::vector<std::string> strides;
if (!max_pool_2d_attr->strides.empty()) {
for (auto stride : max_pool_2d_attr->strides) {
- const auto* stride_val = stride.as<IntImmNode>();
- ICHECK(stride_val) << "convertion failed";
-
- strides.push_back(std::to_string(stride_val->value));
+ strides.push_back(std::to_string(stride));
}
} else {
strides.push_back("1");
@@ -159,16 +151,12 @@ class CollectFromCompositeFunctionBody : public
ExprVisitor {
std::vector<std::string> padding;
for (auto pad : max_pool_2d_attr->padding) {
- const auto* padding_val = pad.as<IntImmNode>();
-
- padding.push_back(std::to_string(padding_val->value));
+ padding.push_back(std::to_string(pad));
}
std::vector<std::string> pool_size;
for (auto size : max_pool_2d_attr->pool_size) {
- const auto* pooling_val = size.as<IntImmNode>();
-
- pool_size.push_back(std::to_string(pooling_val->value));
+ pool_size.push_back(std::to_string(size));
}
std::vector<dmlc::any> strides_attr;
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 3fba58ede2..5368db79d2 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -41,8 +41,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
/* relax.nn.conv1d */
-Expr conv1d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype) {
padding = GetCompletePadding1D(std::move(padding));
@@ -125,15 +125,15 @@ StructInfo InferStructInfoConv1d(const Call& call, const
BlockBuilder& ctx) {
PrimExpr input_w = data_NCW_shape[2];
PrimExpr kernel_w = weight_OIW_shape[2];
- PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+ PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
std::vector<PrimExpr> out_NCW_shape;
out_NCW_shape.resize(3);
out_NCW_shape[0] = data_NCW_shape[0];
out_NCW_shape[1] = weight_OIW_shape[0];
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w
- 1) - 1;
- out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[0]) + 1);
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) *
(kernel_w - 1) - 1;
+ out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w,
Integer(attrs->strides[0])) + 1);
ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -202,8 +202,8 @@ TVM_REGISTER_OP("relax.nn.conv1d")
/* relax.nn.conv2d */
-Expr conv2d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype) {
padding = GetCompletePadding2D(std::move(padding));
@@ -294,18 +294,18 @@ StructInfo InferStructInfoConv2d(const Call& call, const
BlockBuilder& ctx) {
PrimExpr input_w = data_NCHW_shape[3];
PrimExpr kernel_h = weight_OIHW_shape[2];
PrimExpr kernel_w = weight_OIHW_shape[3];
- PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
- PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+ PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+ PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
std::vector<PrimExpr> out_NCHW_shape;
out_NCHW_shape.resize(4);
out_NCHW_shape[0] = data_NCHW_shape[0];
out_NCHW_shape[1] = weight_OIHW_shape[0];
- PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h
- 1) - 1;
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w
- 1) - 1;
- out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[0]) + 1);
- out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[1]) + 1);
+ PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) *
(kernel_h - 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) *
(kernel_w - 1) - 1;
+ out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h,
Integer(attrs->strides[0])) + 1);
+ out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w,
Integer(attrs->strides[1])) + 1);
ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -409,8 +409,8 @@ TVM_REGISTER_OP("relax.nn.conv2d")
/* relax.nn.conv3d */
-Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv3d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype) {
padding = GetCompletePadding3D(std::move(padding));
@@ -506,21 +506,21 @@ StructInfo InferStructInfoConv3d(const Call& call, const
BlockBuilder& ctx) {
PrimExpr kernel_d = weight_OIDHW_shape[2];
PrimExpr kernel_h = weight_OIDHW_shape[3];
PrimExpr kernel_w = weight_OIDHW_shape[4];
- PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
- PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
- PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+ PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+ PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+ PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
std::vector<PrimExpr> out_NCDHW_shape;
out_NCDHW_shape.resize(5);
out_NCDHW_shape[0] = data_NCDHW_shape[0];
out_NCDHW_shape[1] = weight_OIDHW_shape[0];
- PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d
- 1) - 1;
- PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h
- 1) - 1;
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w
- 1) - 1;
- out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d,
attrs->strides[0]) + 1);
- out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[1]) + 1);
- out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[2]) + 1);
+ PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) *
(kernel_d - 1) - 1;
+ PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) *
(kernel_h - 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) *
(kernel_w - 1) - 1;
+ out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d,
Integer(attrs->strides[0])) + 1);
+ out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h,
Integer(attrs->strides[1])) + 1);
+ out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w,
Integer(attrs->strides[2])) + 1);
ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -587,9 +587,9 @@ TVM_REGISTER_OP("relax.nn.conv3d")
.set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv3d)
.set_attr<Bool>("FPurity", Bool(true));
-Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm>
output_padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String
data_layout,
+Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
ffi::Optional<DataType> out_dtype) {
padding = GetCompletePadding1D(std::move(padding));
@@ -607,10 +607,10 @@ Expr conv1d_transpose(Expr data, Expr weight,
ffi::Array<IntImm> strides,
<< dilation;
auto attrs = ffi::make_object<Conv1DTransposeAttrs>();
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->output_padding = ConvertIntImmToInt64(output_padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->output_padding = std::move(output_padding);
+ attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = data_layout;
attrs->kernel_layout = std::move(kernel_layout);
@@ -680,27 +680,28 @@ StructInfo InferStructInfoConv1dTranspose(const Call&
call, const BlockBuilder&
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}
- if (analyzer->CanProve(attrs->output_padding[0]->value >=
attrs->strides[0]->value)) {
+ if (attrs->output_padding[0] >= attrs->strides[0]) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Conv1dTranspose expects the output padding less than
the strides, but the "
"output padding is"
<< attrs->output_padding << " while the strides are" <<
attrs->strides);
- } else if (!analyzer->CanProve(attrs->output_padding[0]->value <
attrs->strides[0]->value)) {
+ } else if (!(attrs->output_padding[0] < attrs->strides[0])) {
// Todo(relax-team): Trust the input padding at this moment, and revisit
// this condition with runtime shape check
}
PrimExpr input_w = data_NCW_shape[2];
PrimExpr kernel_w = weight_IOW_shape[2];
- PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+ PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
std::vector<PrimExpr> out_NCW_shape;
out_NCW_shape.resize(3);
out_NCW_shape[0] = data_NCW_shape[0];
out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups;
- PrimExpr out_w = (input_w - 1) * attrs->strides[0] - padding_w +
- attrs->dilation[0] * (kernel_w - 1) +
attrs->output_padding[0] + 1;
+ PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[0]) - padding_w +
+ Integer(attrs->dilation[0]) * (kernel_w - 1) +
+ Integer(attrs->output_padding[0]) + 1;
out_NCW_shape[2] = analyzer->Simplify(out_w);
ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
@@ -767,9 +768,9 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose")
/* relax.nn.conv2d_transpose */
-Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm>
output_padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String
data_layout,
+Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
ffi::Optional<DataType> out_dtype) {
padding = GetCompletePadding2D(std::move(padding));
@@ -796,10 +797,10 @@ Expr conv2d_transpose(Expr data, Expr weight,
ffi::Array<IntImm> strides,
<< dilation;
auto attrs = ffi::make_object<Conv2DTransposeAttrs>();
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->output_padding = ConvertIntImmToInt64(output_padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->output_padding = std::move(output_padding);
+ attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = data_layout;
attrs->kernel_layout = std::move(kernel_layout);
@@ -870,14 +871,14 @@ StructInfo InferStructInfoConv2dTranspose(const Call&
call, const BlockBuilder&
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}
- if (analyzer->CanProve(attrs->output_padding[0]->value >=
attrs->strides[0]->value ||
- attrs->output_padding[1]->value >=
attrs->strides[1]->value)) {
+ if (attrs->output_padding[0] >= attrs->strides[0] ||
+ attrs->output_padding[1] >= attrs->strides[1]) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Conv2dTranspose expects the output padding less than
the strides, but the "
"output padding is"
<< attrs->output_padding << " while the strides are" <<
attrs->strides);
- } else if (!analyzer->CanProve(attrs->output_padding[0]->value <
attrs->strides[0]->value &&
- attrs->output_padding[1]->value <
attrs->strides[1]->value)) {
+ } else if (!(attrs->output_padding[0] < attrs->strides[0] &&
+ attrs->output_padding[1] < attrs->strides[1])) {
// Todo(relax-team): Trust the input padding at this moment, and revisit
// this condition with runtime shape check
}
@@ -886,18 +887,20 @@ StructInfo InferStructInfoConv2dTranspose(const Call&
call, const BlockBuilder&
PrimExpr input_w = data_NCHW_shape[3];
PrimExpr kernel_h = weight_IOHW_shape[2];
PrimExpr kernel_w = weight_IOHW_shape[3];
- PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
- PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+ PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+ PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
std::vector<PrimExpr> out_NCHW_shape;
out_NCHW_shape.resize(4);
out_NCHW_shape[0] = data_NCHW_shape[0];
out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups;
- PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h +
- attrs->dilation[0] * (kernel_h - 1) +
attrs->output_padding[0] + 1;
- PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w +
- attrs->dilation[1] * (kernel_w - 1) +
attrs->output_padding[1] + 1;
+ PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[0]) - padding_h +
+ Integer(attrs->dilation[0]) * (kernel_h - 1) +
+ Integer(attrs->output_padding[0]) + 1;
+ PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[1]) - padding_w +
+ Integer(attrs->dilation[1]) * (kernel_w - 1) +
+ Integer(attrs->output_padding[1]) + 1;
out_NCHW_shape[2] = analyzer->Simplify(out_h);
out_NCHW_shape[3] = analyzer->Simplify(out_w);
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 4fc175b5aa..a5704d3f70 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -36,14 +36,14 @@ namespace tvm {
namespace relax {
template <typename T>
-inline Expr MakeConv(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String
data_layout,
- ffi::String kernel_layout, ffi::String out_layout,
DataType out_dtype,
- std::string op_name) {
+inline Expr MakeConv(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
dilation, int groups,
+ ffi::String data_layout, ffi::String kernel_layout,
ffi::String out_layout,
+ DataType out_dtype, std::string op_name) {
auto attrs = ffi::make_object<T>();
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
@@ -54,20 +54,20 @@ inline Expr MakeConv(Expr data, Expr weight,
ffi::Array<IntImm> strides, ffi::Ar
}
/*! \brief 1D convolution */
-Expr conv1d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype);
/*! \brief 2D convolution */
-Expr conv2d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype);
/*! \brief 3D convolution */
-Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides,
ffi::Array<IntImm> padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv3d(Expr data, Expr weight, ffi::Array<int64_t> strides,
ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
ffi::Optional<DataType> out_dtype);
@@ -77,9 +77,9 @@ Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm>
strides, ffi::Array<IntIm
* This operator is intended to be the backward operator of conv1d. It can be
used to calculate the
* gradient of the result of conv1d w.r.t. the input of conv1d.
*/
-Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm>
output_padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String
data_layout,
+Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
ffi::Optional<DataType> out_dtype);
@@ -89,9 +89,9 @@ Expr conv1d_transpose(Expr data, Expr weight,
ffi::Array<IntImm> strides,
* This operator is intended to be the backward operator of conv2d. It can be
used to calculate the
* gradient of the result of conv2d w.r.t. the input of conv2d.
*/
-Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm>
output_padding,
- ffi::Array<IntImm> dilation, int groups, ffi::String
data_layout,
+Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
ffi::Optional<DataType> out_dtype);
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 1a19872c27..2397bf0098 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -38,10 +38,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
/* relax.nn.max_pool1d */
-Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
ffi::Array<IntImm> dilation,
- bool ceil_mode, bool count_include_pad, ffi::String layout,
- ffi::Optional<ffi::String> out_layout) {
+Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
+ ffi::String layout, ffi::Optional<ffi::String> out_layout) {
padding = GetCompletePadding1D(std::move(padding));
CHECK_EQ(pool_size.size(), 1)
@@ -54,10 +54,10 @@ Expr MakePool1d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
<< dilation;
auto attrs = ffi::make_object<Pool1DAttrs>();
- attrs->pool_size = ConvertIntImmToInt64(pool_size);
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->pool_size = std::move(pool_size);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
@@ -66,8 +66,8 @@ Expr MakePool1d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
return Call(op, {std::move(data)}, Attrs(attrs), {});
}
-Expr max_pool1d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr max_pool1d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -98,8 +98,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const
BlockBuilder& ctx) {
ffi::Array<PrimExpr> data_NCW_shape =
data2NCW.ForwardShape(data_shape.value()->values);
PrimExpr input_w = data_NCW_shape[2];
- PrimExpr kernel_w = attrs->pool_size[0];
- PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+ PrimExpr kernel_w = Integer(attrs->pool_size[0]);
+ PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
arith::Analyzer* analyzer = ctx->GetAnalyzer();
std::vector<PrimExpr> out_NCW_shape;
@@ -107,13 +107,14 @@ StructInfo InferStructInfoPool1D(const Call& call, const
BlockBuilder& ctx) {
out_NCW_shape[0] = data_NCW_shape[0];
out_NCW_shape[1] = data_NCW_shape[1];
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w
- 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) *
(kernel_w - 1) - 1;
if (attrs->ceil_mode) {
- numerator_w += attrs->strides[0] - 1;
+ numerator_w += Integer(attrs->strides[0]) - 1;
}
- PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1;
+ PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[0])) + 1;
if (attrs->ceil_mode) {
- PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w +
attrs->padding[0];
+ PrimExpr invalid_last_w =
+ (raw_out_w - 1) * Integer(attrs->strides[0]) >= input_w +
Integer(attrs->padding[0]);
out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
} else {
out_NCW_shape[2] = analyzer->Simplify(raw_out_w);
@@ -151,10 +152,10 @@ TVM_REGISTER_OP("relax.nn.max_pool1d")
/* relax.nn.max_pool2d */
-Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
ffi::Array<IntImm> dilation,
- bool ceil_mode, bool count_include_pad, ffi::String layout,
- ffi::Optional<ffi::String> out_layout) {
+Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
+ ffi::String layout, ffi::Optional<ffi::String> out_layout) {
padding = GetCompletePadding2D(std::move(padding));
if (pool_size.size() == 1) {
pool_size.push_back(pool_size[0]);
@@ -176,10 +177,10 @@ Expr MakePool2d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
<< dilation;
auto attrs = ffi::make_object<Pool2DAttrs>();
- attrs->pool_size = ConvertIntImmToInt64(pool_size);
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->pool_size = std::move(pool_size);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
@@ -188,8 +189,8 @@ Expr MakePool2d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
return Call(op, {std::move(data)}, Attrs(attrs), {});
}
-Expr max_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr max_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -221,10 +222,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const
BlockBuilder& ctx) {
PrimExpr input_h = data_NCHW_shape[2];
PrimExpr input_w = data_NCHW_shape[3];
- PrimExpr kernel_h = attrs->pool_size[0];
- PrimExpr kernel_w = attrs->pool_size[1];
- PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
- PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+ PrimExpr kernel_h = Integer(attrs->pool_size[0]);
+ PrimExpr kernel_w = Integer(attrs->pool_size[1]);
+ PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+ PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
arith::Analyzer* analyzer = ctx->GetAnalyzer();
std::vector<PrimExpr> out_NCHW_shape;
@@ -232,17 +233,19 @@ StructInfo InferStructInfoPool2D(const Call& call, const
BlockBuilder& ctx) {
out_NCHW_shape[0] = data_NCHW_shape[0];
out_NCHW_shape[1] = data_NCHW_shape[1];
- PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h
- 1) - 1;
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w
- 1) - 1;
+ PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) *
(kernel_h - 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) *
(kernel_w - 1) - 1;
if (attrs->ceil_mode) {
- numerator_h += attrs->strides[0] - 1;
- numerator_w += attrs->strides[1] - 1;
+ numerator_h += Integer(attrs->strides[0]) - 1;
+ numerator_w += Integer(attrs->strides[1]) - 1;
}
- PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1;
- PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1;
+ PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[0])) + 1;
+ PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[1])) + 1;
if (attrs->ceil_mode) {
- PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h +
attrs->padding[0];
- PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w +
attrs->padding[1];
+ PrimExpr invalid_last_h =
+ (raw_out_h - 1) * Integer(attrs->strides[0]) >= input_h +
Integer(attrs->padding[0]);
+ PrimExpr invalid_last_w =
+ (raw_out_w - 1) * Integer(attrs->strides[1]) >= input_w +
Integer(attrs->padding[1]);
out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h,
raw_out_h - 1, raw_out_h));
out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
} else {
@@ -300,10 +303,10 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
/* relax.nn.max_pool3d */
-Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
ffi::Array<IntImm> dilation,
- bool ceil_mode, bool count_include_pad, ffi::String layout,
- ffi::Optional<ffi::String> out_layout) {
+Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
+ ffi::String layout, ffi::Optional<ffi::String> out_layout) {
padding = GetCompletePadding3D(std::move(padding));
if (pool_size.size() == 1) {
pool_size.push_back(pool_size[0]);
@@ -328,10 +331,10 @@ Expr MakePool3d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
<< dilation;
auto attrs = ffi::make_object<Pool3DAttrs>();
- attrs->pool_size = ConvertIntImmToInt64(pool_size);
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->pool_size = std::move(pool_size);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
@@ -340,8 +343,8 @@ Expr MakePool3d(ffi::String op_name, Expr data,
ffi::Array<IntImm> pool_size,
return Call(op, {std::move(data)}, Attrs(attrs), {});
}
-Expr max_pool3d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr max_pool3d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -374,12 +377,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const
BlockBuilder& ctx) {
PrimExpr input_d = data_NCDHW_shape[2];
PrimExpr input_h = data_NCDHW_shape[3];
PrimExpr input_w = data_NCDHW_shape[4];
- PrimExpr kernel_d = attrs->pool_size[0];
- PrimExpr kernel_h = attrs->pool_size[1];
- PrimExpr kernel_w = attrs->pool_size[2];
- PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
- PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
- PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+ PrimExpr kernel_d = Integer(attrs->pool_size[0]);
+ PrimExpr kernel_h = Integer(attrs->pool_size[1]);
+ PrimExpr kernel_w = Integer(attrs->pool_size[2]);
+ PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+ PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+ PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
arith::Analyzer* analyzer = ctx->GetAnalyzer();
std::vector<PrimExpr> out_NCDHW_shape;
@@ -387,21 +390,24 @@ StructInfo InferStructInfoPool3D(const Call& call, const
BlockBuilder& ctx) {
out_NCDHW_shape[0] = data_NCDHW_shape[0];
out_NCDHW_shape[1] = data_NCDHW_shape[1];
- PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d
- 1) - 1;
- PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h
- 1) - 1;
- PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w
- 1) - 1;
+ PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) *
(kernel_d - 1) - 1;
+ PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) *
(kernel_h - 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) *
(kernel_w - 1) - 1;
if (attrs->ceil_mode) {
- numerator_d += attrs->strides[0] - 1;
- numerator_h += attrs->strides[1] - 1;
- numerator_w += attrs->strides[2] - 1;
+ numerator_d += Integer(attrs->strides[0]) - 1;
+ numerator_h += Integer(attrs->strides[1]) - 1;
+ numerator_w += Integer(attrs->strides[2]) - 1;
}
- PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1;
- PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1;
- PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1;
+ PrimExpr raw_out_d = floordiv(numerator_d, Integer(attrs->strides[0])) + 1;
+ PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[1])) + 1;
+ PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[2])) + 1;
if (attrs->ceil_mode) {
- PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d +
attrs->padding[0];
- PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h +
attrs->padding[1];
- PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w +
attrs->padding[2];
+ PrimExpr invalid_last_d =
+ (raw_out_d - 1) * Integer(attrs->strides[0]) >= input_d +
Integer(attrs->padding[0]);
+ PrimExpr invalid_last_h =
+ (raw_out_h - 1) * Integer(attrs->strides[1]) >= input_h +
Integer(attrs->padding[1]);
+ PrimExpr invalid_last_w =
+ (raw_out_w - 1) * Integer(attrs->strides[2]) >= input_w +
Integer(attrs->padding[2]);
out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d,
raw_out_d - 1, raw_out_d));
out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h,
raw_out_h - 1, raw_out_h));
out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w,
raw_out_w - 1, raw_out_w));
@@ -442,8 +448,8 @@ TVM_REGISTER_OP("relax.nn.max_pool3d")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.nn.avg_pool1d */
-Expr avg_pool1d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr avg_pool1d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -464,8 +470,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.nn.avg_pool2d */
-Expr avg_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr avg_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -486,8 +492,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.nn.avg_pool3d */
-Expr avg_pool3d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr avg_pool3d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout) {
return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding,
dilation, ceil_mode,
count_include_pad, layout, out_layout);
@@ -509,13 +515,13 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d")
/* relax.nn.adaptive_avg_pool1d */
-Expr adaptive_avg_pool1d(Expr data, ffi::Optional<ffi::Array<IntImm>>
output_size,
+Expr adaptive_avg_pool1d(Expr data, ffi::Optional<ffi::Array<int64_t>>
output_size,
ffi::String layout, ffi::Optional<ffi::String>
out_layout) {
ObjectPtr<AdaptivePool1DAttrs> attrs =
ffi::make_object<AdaptivePool1DAttrs>();
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
if (output_size.defined()) {
- ffi::Array<IntImm> _output_size = output_size.value();
+ ffi::Array<int64_t> _output_size = output_size.value();
CHECK_EQ(_output_size.size(), 1)
<< "The output_size length is expected to be 1. However, the given
output_size is "
<< _output_size;
@@ -556,7 +562,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call&
call, const BlockBuilder
ffi::Array<PrimExpr> data_NCW_shape =
data2NCW.ForwardShape(data_shape.value()->values);
ffi::Array<PrimExpr> out_NCW_shape(data_NCW_shape);
if (attrs->output_size.defined()) {
- out_NCW_shape.Set(2, attrs->output_size.value()[0]);
+ out_NCW_shape.Set(2, Integer(attrs->output_size.value()[0]));
}
ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
@@ -591,13 +597,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d")
/* relax.nn.adaptive_avg_pool2d */
-Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<IntImm>>
output_size,
+Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<int64_t>>
output_size,
ffi::String layout, ffi::Optional<ffi::String>
out_layout) {
ObjectPtr<AdaptivePool2DAttrs> attrs =
ffi::make_object<AdaptivePool2DAttrs>();
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
if (output_size.defined()) {
- ffi::Array<IntImm> _output_size = output_size.value();
+ ffi::Array<int64_t> _output_size = output_size.value();
if (_output_size.size() == 1) {
_output_size.push_back(_output_size[0]);
}
@@ -641,8 +647,8 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call&
call, const BlockBuilder
ffi::Array<PrimExpr> data_NCHW_shape =
data2NCHW.ForwardShape(data_shape.value()->values);
ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
if (attrs->output_size.defined()) {
- out_NCHW_shape.Set(2, attrs->output_size.value()[0]);
- out_NCHW_shape.Set(3, attrs->output_size.value()[1]);
+ out_NCHW_shape.Set(2, Integer(attrs->output_size.value()[0]));
+ out_NCHW_shape.Set(3, Integer(attrs->output_size.value()[1]));
}
ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
@@ -693,13 +699,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
/* relax.nn.adaptive_avg_pool3d */
-Expr adaptive_avg_pool3d(Expr data, ffi::Optional<ffi::Array<IntImm>>
output_size,
+Expr adaptive_avg_pool3d(Expr data, ffi::Optional<ffi::Array<int64_t>>
output_size,
ffi::String layout, ffi::Optional<ffi::String>
out_layout) {
ObjectPtr<AdaptivePool3DAttrs> attrs =
ffi::make_object<AdaptivePool3DAttrs>();
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
if (output_size.defined()) {
- ffi::Array<IntImm> _output_size = output_size.value();
+ ffi::Array<int64_t> _output_size = output_size.value();
if (_output_size.size() == 1) {
_output_size.push_back(_output_size[0]);
}
@@ -743,9 +749,9 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call&
call, const BlockBuilder
ffi::Array<PrimExpr> data_NCDHW_shape =
data2NCDHW.ForwardShape(data_shape.value()->values);
ffi::Array<PrimExpr> out_NCDHW_shape(data_NCDHW_shape);
if (attrs->output_size.defined()) {
- out_NCDHW_shape.Set(2, attrs->output_size.value()[0]);
- out_NCDHW_shape.Set(3, attrs->output_size.value()[1]);
- out_NCDHW_shape.Set(4, attrs->output_size.value()[2]);
+ out_NCDHW_shape.Set(2, Integer(attrs->output_size.value()[0]));
+ out_NCDHW_shape.Set(3, Integer(attrs->output_size.value()[1]));
+ out_NCDHW_shape.Set(4, Integer(attrs->output_size.value()[2]));
}
ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index c5435303e8..d1fbc834ee 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -33,17 +33,17 @@ namespace tvm {
namespace relax {
/*! \brief 2D maximum pooling operator. */
-Expr max_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr max_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout);
/*! \brief 2D average pooling operator. */
-Expr avg_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm>
strides,
- ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool
ceil_mode,
+Expr avg_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t>
strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation,
bool ceil_mode,
bool count_include_pad, ffi::String layout,
ffi::Optional<ffi::String> out_layout);
/*! \brief 2D adaptive average pooling operator. */
-Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<IntImm>>
output_size,
+Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<int64_t>>
output_size,
ffi::String layout, ffi::Optional<ffi::String>
out_layout);
} // namespace relax
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index ee82f3eebc..cd5406e614 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -465,7 +465,7 @@ inline ffi::Array<IntImm> ConvertIntImmToInt64(const
ffi::Array<IntImm>& int_imm
* \return The completed padding.
* \throws Throws error if the input padding length is neither 1 or 2.
*/
-inline ffi::Array<IntImm> GetCompletePadding1D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding1D(ffi::Array<int64_t> padding) {
if (padding.size() == 1) {
return {padding[0], padding[0]};
} else if (padding.size() == 2) {
@@ -486,7 +486,7 @@ inline ffi::Array<IntImm>
GetCompletePadding1D(ffi::Array<IntImm> padding) {
* \return The completed padding.
* \throws Throws error if the input padding length is neither 1, 2 or 4.
*/
-inline ffi::Array<IntImm> GetCompletePadding2D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding2D(ffi::Array<int64_t> padding) {
if (padding.size() == 1) {
return {padding[0], padding[0], padding[0], padding[0]};
} else if (padding.size() == 2) {
@@ -511,7 +511,7 @@ inline ffi::Array<IntImm>
GetCompletePadding2D(ffi::Array<IntImm> padding) {
* \return The completed padding.
* \throws Throws error if the input padding length is neither 1, 3 or 6.
*/
-inline ffi::Array<IntImm> GetCompletePadding3D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding3D(ffi::Array<int64_t> padding) {
if (padding.size() == 1) {
return {padding[0], padding[0], padding[0], padding[0], padding[0],
padding[0]};
} else if (padding.size() == 3) {
diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc
index 52a218b730..0594ef75bd 100644
--- a/src/relax/op/tensor/grad.cc
+++ b/src/relax/op/tensor/grad.cc
@@ -141,15 +141,15 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.grad.max_pool2d_backward */
-Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm>
pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm>
padding,
- ffi::Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad,
+Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t>
pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t>
padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
ffi::String layout, ffi::Optional<ffi::String>
out_layout) {
auto attrs = ffi::make_object<Pool2DAttrs>();
attrs->pool_size = std::move(pool_size);
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
@@ -176,15 +176,15 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward")
.set_attr<Bool>("FPurity", Bool(true));
/* relax.grad.avg_pool2d_backward */
-Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm>
pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm>
padding,
- ffi::Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad,
+Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t>
pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t>
padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
ffi::String layout, ffi::Optional<ffi::String>
out_layout) {
auto attrs = ffi::make_object<Pool2DAttrs>();
attrs->pool_size = std::move(pool_size);
- attrs->strides = ConvertIntImmToInt64(strides);
- attrs->padding = ConvertIntImmToInt64(padding);
- attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h
index 406d7a2f77..911049475d 100644
--- a/src/relax/op/tensor/grad.h
+++ b/src/relax/op/tensor/grad.h
@@ -46,16 +46,16 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions,
Expr targets,
/*! \brief Backward operator of relax.max_pool2d. All parameters except
output_grad is the same as
* relax.max_pool2d. Returns the gradient w.r.t. data. */
-Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm>
pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm>
padding,
- ffi::Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad,
+Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t>
pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t>
padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
ffi::String layout, ffi::Optional<ffi::String>
out_layout);
/*! \brief Backward operator of relax.avg_pool2d. All parameters except
output_grad is the same as
* relax.avg_pool2d. Returns the gradient w.r.t. data. */
-Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm>
pool_size,
- ffi::Array<IntImm> strides, ffi::Array<IntImm>
padding,
- ffi::Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad,
+Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t>
pool_size,
+ ffi::Array<int64_t> strides, ffi::Array<int64_t>
padding,
+ ffi::Array<int64_t> dilation, bool ceil_mode, bool
count_include_pad,
ffi::String layout, ffi::Optional<ffi::String>
out_layout);
/*! \brief Backward operator of relax.take. All parameters except output_grad
is the same as
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index ee21c14c6f..1d2a6e4676 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -277,10 +277,8 @@ def test_op_attr():
conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3))
xp = is_var("x")
yp = is_var("y")
- # TODO(@yuchen): reenable the assert after figuring out why it fails
- # assert is_op("nn.conv2d")(xp, yp).has_attr({"strides": [3,
3]}).match(conv2d)
+ assert is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3,
3]}).match(conv2d)
assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [4,
3]}).match(conv2d)
- assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3,
3]}).match(conv2d)
def test_match_call_attr():
diff --git a/tests/python/relax/test_op_nn_convolution.py
b/tests/python/relax/test_op_nn_convolution.py
index 9b913138df..4ee8226bc3 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -352,10 +352,10 @@ def test_conv1d_stride_padding_dilation_int64():
w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1),
dilation=(1,))
- assert conv1d.attrs.strides[0].dtype == "int64"
- assert conv1d.attrs.padding[0].dtype == "int64"
- assert conv1d.attrs.padding[1].dtype == "int64"
- assert conv1d.attrs.dilation[0].dtype == "int64"
+ assert isinstance(conv1d.attrs.strides[0], int)
+ assert isinstance(conv1d.attrs.padding[0], int)
+ assert isinstance(conv1d.attrs.padding[1], int)
+ assert isinstance(conv1d.attrs.dilation[0], int)
def test_conv1d_wrong_strides_padding_dilation_length():
@@ -711,9 +711,9 @@ def test_conv1d_transpose_stride_padding_dilation_int64():
w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
conv1d = relax.op.nn.conv1d_transpose(x, w, strides=1, padding=1,
dilation=1)
- assert conv1d.attrs.strides[0].dtype == "int64"
- assert conv1d.attrs.padding[0].dtype == "int64"
- assert conv1d.attrs.dilation[0].dtype == "int64"
+ assert isinstance(conv1d.attrs.strides[0], int)
+ assert isinstance(conv1d.attrs.padding[0], int)
+ assert isinstance(conv1d.attrs.dilation[0], int)
def test_conv1d_transpose_wrong_strides_padding_dilation_length():
@@ -1122,14 +1122,14 @@ def test_conv2d_stride_padding_dilation_int64():
w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1),
dilation=(1, 1))
- assert conv2d.attrs.strides[0].dtype == "int64"
- assert conv2d.attrs.strides[1].dtype == "int64"
- assert conv2d.attrs.padding[0].dtype == "int64"
- assert conv2d.attrs.padding[1].dtype == "int64"
- assert conv2d.attrs.padding[2].dtype == "int64"
- assert conv2d.attrs.padding[3].dtype == "int64"
- assert conv2d.attrs.dilation[0].dtype == "int64"
- assert conv2d.attrs.dilation[1].dtype == "int64"
+ assert isinstance(conv2d.attrs.strides[0], int)
+ assert isinstance(conv2d.attrs.strides[1], int)
+ assert isinstance(conv2d.attrs.padding[0], int)
+ assert isinstance(conv2d.attrs.padding[1], int)
+ assert isinstance(conv2d.attrs.padding[2], int)
+ assert isinstance(conv2d.attrs.padding[3], int)
+ assert isinstance(conv2d.attrs.dilation[0], int)
+ assert isinstance(conv2d.attrs.dilation[1], int)
def test_conv2d_wrong_strides_padding_dilation_length():
@@ -1510,16 +1510,16 @@ def
test_conv2d_transpose_stride_padding_dilation_int64():
x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2),
dilation=(1, 1)
)
- assert conv2d_transpose.attrs.strides[0].dtype == "int64"
- assert conv2d_transpose.attrs.strides[1].dtype == "int64"
- assert conv2d_transpose.attrs.padding[0].dtype == "int64"
- assert conv2d_transpose.attrs.padding[1].dtype == "int64"
- assert conv2d_transpose.attrs.padding[2].dtype == "int64"
- assert conv2d_transpose.attrs.padding[3].dtype == "int64"
- assert conv2d_transpose.attrs.output_padding[0].dtype == "int64"
- assert conv2d_transpose.attrs.output_padding[1].dtype == "int64"
- assert conv2d_transpose.attrs.dilation[0].dtype == "int64"
- assert conv2d_transpose.attrs.dilation[1].dtype == "int64"
+ assert isinstance(conv2d_transpose.attrs.strides[0], int)
+ assert isinstance(conv2d_transpose.attrs.strides[1], int)
+ assert isinstance(conv2d_transpose.attrs.padding[0], int)
+ assert isinstance(conv2d_transpose.attrs.padding[1], int)
+ assert isinstance(conv2d_transpose.attrs.padding[2], int)
+ assert isinstance(conv2d_transpose.attrs.padding[3], int)
+ assert isinstance(conv2d_transpose.attrs.output_padding[0], int)
+ assert isinstance(conv2d_transpose.attrs.output_padding[1], int)
+ assert isinstance(conv2d_transpose.attrs.dilation[0], int)
+ assert isinstance(conv2d_transpose.attrs.dilation[1], int)
def test_conv2d_transpose_wrong_strides_padding_dilation_length():
diff --git a/tests/python/relax/test_op_nn_pooling.py
b/tests/python/relax/test_op_nn_pooling.py
index d4461a122d..12d099bb33 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -183,10 +183,10 @@ def test_max_pool1d_stride_padding_dilation_int64():
x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1,
dilation=1)
- assert max_pool1d.attrs.strides[0].dtype == "int64"
- assert max_pool1d.attrs.padding[0].dtype == "int64"
- assert max_pool1d.attrs.padding[1].dtype == "int64"
- assert max_pool1d.attrs.dilation[0].dtype == "int64"
+ assert isinstance(max_pool1d.attrs.strides[0], int)
+ assert isinstance(max_pool1d.attrs.padding[0], int)
+ assert isinstance(max_pool1d.attrs.padding[1], int)
+ assert isinstance(max_pool1d.attrs.dilation[0], int)
def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length():
@@ -412,14 +412,14 @@ def test_max_pool2d_stride_padding_dilation_int64():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1,
1), dilation=(1, 1))
- assert max_pool2d.attrs.strides[0].dtype == "int64"
- assert max_pool2d.attrs.strides[1].dtype == "int64"
- assert max_pool2d.attrs.padding[0].dtype == "int64"
- assert max_pool2d.attrs.padding[1].dtype == "int64"
- assert max_pool2d.attrs.padding[2].dtype == "int64"
- assert max_pool2d.attrs.padding[3].dtype == "int64"
- assert max_pool2d.attrs.dilation[0].dtype == "int64"
- assert max_pool2d.attrs.dilation[1].dtype == "int64"
+ assert isinstance(max_pool2d.attrs.strides[0], int)
+ assert isinstance(max_pool2d.attrs.strides[1], int)
+ assert isinstance(max_pool2d.attrs.padding[0], int)
+ assert isinstance(max_pool2d.attrs.padding[1], int)
+ assert isinstance(max_pool2d.attrs.padding[2], int)
+ assert isinstance(max_pool2d.attrs.padding[3], int)
+ assert isinstance(max_pool2d.attrs.dilation[0], int)
+ assert isinstance(max_pool2d.attrs.dilation[1], int)
def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length():
@@ -660,17 +660,17 @@ def test_max_pool3d_stride_padding_dilation_int64():
x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
)
- assert max_pool3d.attrs.strides[0].dtype == "int64"
- assert max_pool3d.attrs.strides[1].dtype == "int64"
- assert max_pool3d.attrs.strides[2].dtype == "int64"
- assert max_pool3d.attrs.padding[0].dtype == "int64"
- assert max_pool3d.attrs.padding[1].dtype == "int64"
- assert max_pool3d.attrs.padding[2].dtype == "int64"
- assert max_pool3d.attrs.padding[3].dtype == "int64"
- assert max_pool3d.attrs.padding[4].dtype == "int64"
- assert max_pool3d.attrs.dilation[0].dtype == "int64"
- assert max_pool3d.attrs.dilation[1].dtype == "int64"
- assert max_pool3d.attrs.dilation[2].dtype == "int64"
+ assert isinstance(max_pool3d.attrs.strides[0], int)
+ assert isinstance(max_pool3d.attrs.strides[1], int)
+ assert isinstance(max_pool3d.attrs.strides[2], int)
+ assert isinstance(max_pool3d.attrs.padding[0], int)
+ assert isinstance(max_pool3d.attrs.padding[1], int)
+ assert isinstance(max_pool3d.attrs.padding[2], int)
+ assert isinstance(max_pool3d.attrs.padding[3], int)
+ assert isinstance(max_pool3d.attrs.padding[4], int)
+ assert isinstance(max_pool3d.attrs.dilation[0], int)
+ assert isinstance(max_pool3d.attrs.dilation[1], int)
+ assert isinstance(max_pool3d.attrs.dilation[2], int)
def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length():
@@ -875,10 +875,10 @@ def test_avg_pool1d_stride_padding_dilation_int64():
x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
avg_pool1d = relax.op.nn.avg_pool1d(x, 3, strides=1, padding=1, dilation=1)
- assert avg_pool1d.attrs.strides[0].dtype == "int64"
- assert avg_pool1d.attrs.padding[0].dtype == "int64"
- assert avg_pool1d.attrs.padding[1].dtype == "int64"
- assert avg_pool1d.attrs.dilation[0].dtype == "int64"
+ assert isinstance(avg_pool1d.attrs.strides[0], int)
+ assert isinstance(avg_pool1d.attrs.padding[0], int)
+ assert isinstance(avg_pool1d.attrs.padding[1], int)
+ assert isinstance(avg_pool1d.attrs.dilation[0], int)
def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length():
@@ -1101,14 +1101,14 @@ def test_avg_pool2d_stride_padding_dilation_int64():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1,
1), dilation=(1, 1))
- assert avg_pool2d.attrs.strides[0].dtype == "int64"
- assert avg_pool2d.attrs.strides[1].dtype == "int64"
- assert avg_pool2d.attrs.padding[0].dtype == "int64"
- assert avg_pool2d.attrs.padding[1].dtype == "int64"
- assert avg_pool2d.attrs.padding[2].dtype == "int64"
- assert avg_pool2d.attrs.padding[3].dtype == "int64"
- assert avg_pool2d.attrs.dilation[0].dtype == "int64"
- assert avg_pool2d.attrs.dilation[1].dtype == "int64"
+ assert isinstance(avg_pool2d.attrs.strides[0], int)
+ assert isinstance(avg_pool2d.attrs.strides[1], int)
+ assert isinstance(avg_pool2d.attrs.padding[0], int)
+ assert isinstance(avg_pool2d.attrs.padding[1], int)
+ assert isinstance(avg_pool2d.attrs.padding[2], int)
+ assert isinstance(avg_pool2d.attrs.padding[3], int)
+ assert isinstance(avg_pool2d.attrs.dilation[0], int)
+ assert isinstance(avg_pool2d.attrs.dilation[1], int)
def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length():
@@ -1356,15 +1356,15 @@ def test_avg_pool3d_stride_padding_dilation_int64():
x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
)
- assert avg_pool3d.attrs.strides[0].dtype == "int64"
- assert avg_pool3d.attrs.strides[1].dtype == "int64"
- assert avg_pool3d.attrs.strides[2].dtype == "int64"
- assert avg_pool3d.attrs.padding[0].dtype == "int64"
- assert avg_pool3d.attrs.padding[1].dtype == "int64"
- assert avg_pool3d.attrs.padding[2].dtype == "int64"
- assert avg_pool3d.attrs.dilation[0].dtype == "int64"
- assert avg_pool3d.attrs.dilation[1].dtype == "int64"
- assert avg_pool3d.attrs.dilation[2].dtype == "int64"
+ assert isinstance(avg_pool3d.attrs.strides[0], int)
+ assert isinstance(avg_pool3d.attrs.strides[1], int)
+ assert isinstance(avg_pool3d.attrs.strides[2], int)
+ assert isinstance(avg_pool3d.attrs.padding[0], int)
+ assert isinstance(avg_pool3d.attrs.padding[1], int)
+ assert isinstance(avg_pool3d.attrs.padding[2], int)
+ assert isinstance(avg_pool3d.attrs.dilation[0], int)
+ assert isinstance(avg_pool3d.attrs.dilation[1], int)
+ assert isinstance(avg_pool3d.attrs.dilation[2], int)
def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length():
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py
b/tests/python/relax/test_transform_legalize_ops_grad.py
index cf9361d7c7..294cea71de 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -272,17 +272,17 @@ def test_avg_pool2d_backward():
@I.ir_module
class Expected:
@T.prim_func(private=True)
- def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3),
T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1:
T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"),
T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)),
"float32")):
+ def avg_pool2d_backward(output_grad: T.Buffer((T.int64(3), T.int64(2),
T.int64(6), T.int64(5)), "float32"), data: T.Buffer((T.int64(3), T.int64(2),
T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3),
T.int64(2), T.int64(10), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": True})
# with T.sblock("root"):
for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2),
T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
with T.sblock("T_pool_grad"):
v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
- T.reads(rxplaceholder[v_ax0, v_ax1, T.Div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh, T.Div((v_ax3 + T.int64(1)), T.int64(2)) -
v_ww])
+ T.reads(output_grad[v_ax0, v_ax1, T.Div(v_ax2 +
T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww])
T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
- T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <=
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2),
T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0),
T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 +
T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <=
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2),
T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0),
T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 +
T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
@R.function
def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):