This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 262c6d2e04 [Relax] Migrate NN conv/pooling/grad attrs from Array<IntImm> to Array<int64_t> (#18733)
262c6d2e04 is described below

commit 262c6d2e041e6a5d5775049b32b166b526b49fa9
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Feb 11 22:38:33 2026 +0800

    [Relax] Migrate NN conv/pooling/grad attrs from Array<IntImm> to Array<int64_t> (#18733)
    
    ## Why

    has_attr pattern matching failed for integer attributes because
    StructuralEqual requires an exact dtype match for IntImm, causing
    mismatches between Python-specified pattern values and the actual
    operator attributes.
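
    For illustration, a minimal sketch of the failure mode, using the
    `tvm.relax.dpl` pattern helpers (`is_op`, `wildcard`, `has_attr`);
    the concrete dtypes below are an assumption for illustration:

    ```python
    # Hypothetical reproduction sketch; not part of this commit's tests.
    from tvm.relax.dpl import is_op, wildcard

    # Require strides == [1, 1] on a conv2d call. The plain Python ints in
    # the pattern become IntImm values whose dtype (e.g. int64) could
    # differ from the dtype stored in Conv2DAttrs (e.g. int32).
    pattern = is_op("relax.nn.conv2d")(wildcard(), wildcard()).has_attr(
        {"strides": [1, 1]}
    )
    # StructuralEqual treats IntImm("int64", 1) and IntImm("int32", 1) as
    # distinct, so the pattern silently failed to match.
    ```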
    ## How
    
    - Migrate NN conv/pooling/grad attrs from Array<IntImm> to Array<int64_t>; the intended effect is sketched below
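
    A hedged sketch of the intended post-migration behavior (the variable
    and struct-info setup below is illustrative, not taken from this
    commit):

    ```python
    # Sketch only: integer attrs now round-trip as plain Python ints.
    from tvm import relax
    from tvm.relax.op.nn import conv2d

    data = relax.Var("data", relax.TensorStructInfo((1, 3, 8, 8), "float32"))
    weight = relax.Var("weight", relax.TensorStructInfo((4, 3, 3, 3), "float32"))
    call = conv2d(data, weight, strides=(1, 1))

    # No IntImm dtype is involved anymore, so a has_attr pattern built from
    # plain ints can match these attrs structurally.
    assert list(call.attrs.strides) == [1, 1]
    ```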
    
    ---------
    
    Signed-off-by: Guan-Ming Chiu <[email protected]>
---
 include/tvm/relax/attrs/nn.h                       |  64 ++++----
 python/tvm/relax/dpl/pattern.py                    |   4 +-
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |   2 +-
 python/tvm/relax/op/_op_gradient.py                |   2 +-
 src/contrib/msc/core/utils.cc                      |   4 +-
 .../msc/framework/tensorrt/transform_tensorrt.cc   |  13 +-
 .../backend/contrib/codegen_json/codegen_json.h    |   6 +-
 src/relax/backend/contrib/nnapi/codegen.cc         |  22 +--
 src/relax/op/nn/convolution.cc                     | 109 ++++++-------
 src/relax/op/nn/convolution.h                      |  38 ++---
 src/relax/op/nn/pooling.cc                         | 174 +++++++++++----------
 src/relax/op/nn/pooling.h                          |  10 +-
 src/relax/op/op_common.h                           |   6 +-
 src/relax/op/tensor/grad.cc                        |  24 +--
 src/relax/op/tensor/grad.h                         |  12 +-
 tests/python/relax/test_dataflow_pattern.py        |   4 +-
 tests/python/relax/test_op_nn_convolution.py       |  50 +++---
 tests/python/relax/test_op_nn_pooling.py           |  88 +++++------
 .../relax/test_transform_legalize_ops_grad.py      |   6 +-
 19 files changed, 319 insertions(+), 319 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 13a54a16b3..2a2ac5fe07 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -31,9 +31,9 @@ namespace relax {
 
 /*! \brief Attributes used in Conv1d operator */
 struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   int groups;
   ffi::String data_layout;
   ffi::String kernel_layout;
@@ -75,9 +75,9 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
 
 /*! \brief Attributes used in Conv2d operator */
 struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   int groups;
   ffi::String data_layout;
   ffi::String kernel_layout;
@@ -121,9 +121,9 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
 
 /*! \brief Attributes used in Conv3d operator */
 struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   int groups;
   ffi::String data_layout;
   ffi::String kernel_layout;
@@ -169,10 +169,10 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
 
 /*! \brief Attributes used in Conv1DTranspose operator */
 struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs> {
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> output_padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> output_padding;
+  ffi::Array<int64_t> dilation;
   int groups;
   ffi::String data_layout;
   ffi::String kernel_layout;
@@ -218,10 +218,10 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs>
 
 /*! \brief Attributes used in Conv2d operator */
 struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs> {
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> output_padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> output_padding;
+  ffi::Array<int64_t> dilation;
   int groups;
   ffi::String data_layout;
   ffi::String kernel_layout;
@@ -269,10 +269,10 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs>
 
 /*! \brief Attributes used in max_pool1d and avg_pool1d operator */
 struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
-  ffi::Array<IntImm> pool_size;
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> pool_size;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   bool ceil_mode;
   bool count_include_pad;
   ffi::String layout;
@@ -310,10 +310,10 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
 
 /*! \brief Attributes used in max_pool2d and avg_pool2d operator */
 struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
-  ffi::Array<IntImm> pool_size;
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> pool_size;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   bool ceil_mode;
   bool count_include_pad;
   ffi::String layout;
@@ -353,10 +353,10 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
 
 /*! \brief Attributes used in max_pool3d and avg_pool3d operator */
 struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
-  ffi::Array<IntImm> pool_size;
-  ffi::Array<IntImm> strides;
-  ffi::Array<IntImm> padding;
-  ffi::Array<IntImm> dilation;
+  ffi::Array<int64_t> pool_size;
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> dilation;
   bool ceil_mode;
   bool count_include_pad;
   ffi::String layout;
@@ -396,7 +396,7 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
 
 /*! \brief Attributes for 1d adaptive pool operator */
 struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
-  ffi::Optional<ffi::Array<IntImm>> output_size;
+  ffi::Optional<ffi::Array<int64_t>> output_size;
   ffi::String layout;
   ffi::String out_layout;
 
@@ -421,7 +421,7 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
 
 /*! \brief Attributes for 2d adaptive pool operator */
 struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
-  ffi::Optional<ffi::Array<IntImm>> output_size;
+  ffi::Optional<ffi::Array<int64_t>> output_size;
   ffi::String layout;
   ffi::String out_layout;
 
@@ -446,7 +446,7 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
 
 /*! \brief Attributes for 3d adaptive pool operator */
 struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
-  ffi::Optional<ffi::Array<IntImm>> output_size;
+  ffi::Optional<ffi::Array<int64_t>> output_size;
   ffi::String layout;
   ffi::String out_layout;
 
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index ef7516f31f..1a08b66f2a 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -656,8 +656,8 @@ class PrimArrPattern(DFPattern):
 
 @register_df_node
 class AttrPattern(DFPattern):
-    """Get match an expression with a certain attributes.
-    Currently only supports Op Attributes, not call Attributes.
+    """Match an expression with certain attributes.
+    Supports Op attributes, Call attributes, and Function attributes.
 
     Parameters
     ----------
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 61ab45d308..5f7c2ab752 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2567,7 +2567,7 @@ class Pool(OnnxOpConverter):
             pads = []
             if cls.name == "avg_pool":
                 for axis in range(len(input_shape) - 2):
-                    axis_shape = input_shape[2 + axis]
+                    axis_shape = int(input_shape[2 + axis])
                     stride = strides[axis]
                     kernel = kernel_shape[axis]
                     pad = cls.get_pad_pair(axis_shape, kernel, stride, auto_pad)
diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index fd80f1e313..bf267473f8 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -1202,7 +1202,7 @@ kernel_layout, out_layout, out_dtype)`
     out_h = (grad_h - 1) * stride_h - pad_top - pad_bottom + filter_h
     out_w = (grad_w - 1) * stride_w - pad_left - pad_right + filter_w
 
-    output_padding = (in_h - out_h, in_w - out_w)
+    output_padding = (int(in_h - out_h), int(in_w - out_w))
 
     data_grad = conv2d_transpose(  # type: ignore
         output_grad,
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index bc70c809af..b917bc47a2 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -275,11 +275,13 @@ const ffi::String StringUtils::ToString(const ffi::Any& obj) {
     obj_string = *opt_str;
   } else if (const auto* n = obj.as<IntImmNode>()) {
     obj_string = std::to_string(n->value);
+  } else if (obj.type_index() == kTVMFFIInt) {
+    obj_string = std::to_string(obj.cast<int64_t>());
   } else if (const auto* n = obj.as<FloatImmNode>()) {
     obj_string = std::to_string(n->value);
   } else if (const auto* n = obj.as<ffi::ArrayObj>()) {
     for (size_t i = 0; i < n->size(); i++) {
-      obj_string = obj_string + ToString((*n)[i].cast<ObjectRef>());
+      obj_string = obj_string + ToString((*n)[i]);
       if (n->size() == 1 || i < n->size() - 1) {
         obj_string = obj_string + ",";
       }
diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
index e3579ec7ef..8b20d59574 100644
--- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
+++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc
@@ -429,10 +429,9 @@ Expr RewriteConv1d(BlockBuilder builder, const Var& var, const Call& src_call,
     // change to conv2d
     static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
     auto conv_attrs = ffi::make_object<Conv2DAttrs>();
-    conv_attrs->strides = ffi::Array<IntImm>{src_attrs->strides[0], Integer(1)};
-    conv_attrs->padding =
-        ffi::Array<IntImm>{Integer(0), src_attrs->padding[0], Integer(0), src_attrs->padding[1]};
-    conv_attrs->dilation = ffi::Array<IntImm>{src_attrs->dilation[0], Integer(1)};
+    conv_attrs->strides = ffi::Array<int64_t>{src_attrs->strides[0], 1};
+    conv_attrs->padding = ffi::Array<int64_t>{0, src_attrs->padding[0], 0, src_attrs->padding[1]};
+    conv_attrs->dilation = ffi::Array<int64_t>{src_attrs->dilation[0], 1};
     conv_attrs->groups = src_attrs->groups;
     conv_attrs->data_layout = "NCHW";
     conv_attrs->kernel_layout = "OIHW";
@@ -706,9 +705,9 @@ Expr RewriteMatmul(BlockBuilder builder, const Var& var, const Call& src_call,
     // to conv2d
     static const Op& conv2d_op = Op::Get("relax.nn.conv2d");
     auto conv_attrs = ffi::make_object<Conv2DAttrs>();
-    conv_attrs->strides = ffi::Array<IntImm>{Integer(1), Integer(1)};
-    conv_attrs->padding = ffi::Array<IntImm>{Integer(0), Integer(0), Integer(0), Integer(0)};
-    conv_attrs->dilation = ffi::Array<IntImm>{Integer(1), Integer(1)};
+    conv_attrs->strides = ffi::Array<int64_t>{1, 1};
+    conv_attrs->padding = ffi::Array<int64_t>{0, 0, 0, 0};
+    conv_attrs->dilation = ffi::Array<int64_t>{1, 1};
     conv_attrs->groups = 1;
     conv_attrs->data_layout = "NCHW";
     conv_attrs->kernel_layout = "OIHW";
diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h
index 5056962542..6076de75bc 100644
--- a/src/relax/backend/contrib/codegen_json/codegen_json.h
+++ b/src/relax/backend/contrib/codegen_json/codegen_json.h
@@ -115,8 +115,12 @@ class OpAttrExtractor {
     if (const auto* an = (*value).as<ffi::ArrayObj>()) {
       std::vector<std::string> attr;
       for (size_t i = 0; i < an->size(); ++i) {
-        if (const auto* im = (*an)[i].as<IntImmNode>()) {
+        if (auto opt_int = (*an)[i].try_cast<int64_t>()) {
+          attr.push_back(std::to_string(opt_int.value()));
+        } else if (const auto* im = (*an)[i].as<IntImmNode>()) {
           attr.push_back(std::to_string(im->value));
+        } else if (auto opt_float = (*an)[i].try_cast<double>()) {
+          attr.push_back(Fp2String(opt_float.value()));
         } else if (const auto* fm = (*an)[i].as<FloatImmNode>()) {
           attr.push_back(Fp2String(fm->value));
         } else if (auto opt_str = (*an)[i].as<ffi::String>()) {
diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc
index 92933ba070..0ea05b9863 100644
--- a/src/relax/backend/contrib/nnapi/codegen.cc
+++ b/src/relax/backend/contrib/nnapi/codegen.cc
@@ -107,10 +107,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
     std::vector<std::string> strides;
     if (!conv2d_attr->strides.empty()) {
       for (auto stride : conv2d_attr->strides) {
-        const auto* stride_val = stride.as<IntImmNode>();
-        ICHECK(stride_val) << "convertion failed";
-
-        strides.push_back(std::to_string(stride_val->value));
+        strides.push_back(std::to_string(stride));
       }
     } else {
       strides = {"1", "1"};
@@ -118,9 +115,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
 
     std::vector<std::string> padding;
     for (auto pad : conv2d_attr->padding) {
-      const auto* padding_val = pad.as<IntImmNode>();
-
-      padding.push_back(std::to_string(padding_val->value));
+      padding.push_back(std::to_string(pad));
     }
 
     std::vector<std::string> groups;
@@ -147,10 +142,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
     std::vector<std::string> strides;
     if (!max_pool_2d_attr->strides.empty()) {
       for (auto stride : max_pool_2d_attr->strides) {
-        const auto* stride_val = stride.as<IntImmNode>();
-        ICHECK(stride_val) << "convertion failed";
-
-        strides.push_back(std::to_string(stride_val->value));
+        strides.push_back(std::to_string(stride));
       }
     } else {
       strides.push_back("1");
@@ -159,16 +151,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor {
 
     std::vector<std::string> padding;
     for (auto pad : max_pool_2d_attr->padding) {
-      const auto* padding_val = pad.as<IntImmNode>();
-
-      padding.push_back(std::to_string(padding_val->value));
+      padding.push_back(std::to_string(pad));
     }
 
     std::vector<std::string> pool_size;
     for (auto size : max_pool_2d_attr->pool_size) {
-      const auto* pooling_val = size.as<IntImmNode>();
-
-      pool_size.push_back(std::to_string(pooling_val->value));
+      pool_size.push_back(std::to_string(size));
     }
 
     std::vector<dmlc::any> strides_attr;
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 3fba58ede2..5368db79d2 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -41,8 +41,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 /* relax.nn.conv1d */
 
-Expr conv1d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding1D(std::move(padding));
@@ -125,15 +125,15 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) {
 
   PrimExpr input_w = data_NCW_shape[2];
   PrimExpr kernel_w = weight_OIW_shape[2];
-  PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+  PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
 
   std::vector<PrimExpr> out_NCW_shape;
   out_NCW_shape.resize(3);
   out_NCW_shape[0] = data_NCW_shape[0];
   out_NCW_shape[1] = weight_OIW_shape[0];
 
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
-  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1;
+  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[0])) + 1);
 
   ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -202,8 +202,8 @@ TVM_REGISTER_OP("relax.nn.conv1d")
 
 /* relax.nn.conv2d */
 
-Expr conv2d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding2D(std::move(padding));
@@ -294,18 +294,18 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
   PrimExpr input_w = data_NCHW_shape[3];
   PrimExpr kernel_h = weight_OIHW_shape[2];
   PrimExpr kernel_w = weight_OIHW_shape[3];
-  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
-  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+  PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+  PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
 
   std::vector<PrimExpr> out_NCHW_shape;
   out_NCHW_shape.resize(4);
   out_NCHW_shape[0] = data_NCHW_shape[0];
   out_NCHW_shape[1] = weight_OIHW_shape[0];
 
-  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
-  out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1);
-  out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
+  PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1;
+  out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[0])) + 1);
+  out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[1])) + 1);
 
   ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -409,8 +409,8 @@ TVM_REGISTER_OP("relax.nn.conv2d")
 
 /* relax.nn.conv3d */
 
-Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv3d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding3D(std::move(padding));
@@ -506,21 +506,21 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
   PrimExpr kernel_d = weight_OIDHW_shape[2];
   PrimExpr kernel_h = weight_OIDHW_shape[3];
   PrimExpr kernel_w = weight_OIDHW_shape[4];
-  PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
-  PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
-  PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+  PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+  PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+  PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
 
   std::vector<PrimExpr> out_NCDHW_shape;
   out_NCDHW_shape.resize(5);
   out_NCDHW_shape[0] = data_NCDHW_shape[0];
   out_NCDHW_shape[1] = weight_OIDHW_shape[0];
 
-  PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d - 1) - 1;
-  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h - 1) - 1;
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w - 1) - 1;
-  out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, attrs->strides[0]) + 1);
-  out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[1]) + 1);
-  out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[2]) + 1);
+  PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1;
+  PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1;
+  out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, Integer(attrs->strides[0])) + 1);
+  out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, Integer(attrs->strides[1])) + 1);
+  out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, Integer(attrs->strides[2])) + 1);
 
   ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
@@ -587,9 +587,9 @@ TVM_REGISTER_OP("relax.nn.conv3d")
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv3d)
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
-                      ffi::Array<IntImm> padding, ffi::Array<IntImm> output_padding,
-                      ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
                       ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
                       ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding1D(std::move(padding));
@@ -607,10 +607,10 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
       << dilation;
 
   auto attrs = ffi::make_object<Conv1DTransposeAttrs>();
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->output_padding = ConvertIntImmToInt64(output_padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
   attrs->groups = groups;
   attrs->data_layout = data_layout;
   attrs->kernel_layout = std::move(kernel_layout);
@@ -680,27 +680,28 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
     // Todo(relax-team): Trust the input shape at this moment, and revisit
     // this condition with runtime shape check
   }
-  if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value)) {
+  if (attrs->output_padding[0] >= attrs->strides[0]) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Conv1dTranspose expects the output padding less than the strides, but the "
                         "output padding is"
                      << attrs->output_padding << " while the strides are" << attrs->strides);
-  } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value)) {
+  } else if (!(attrs->output_padding[0] < attrs->strides[0])) {
     // Todo(relax-team): Trust the input padding at this moment, and revisit
     // this condition with runtime shape check
   }
 
   PrimExpr input_w = data_NCW_shape[2];
   PrimExpr kernel_w = weight_IOW_shape[2];
-  PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+  PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
 
   std::vector<PrimExpr> out_NCW_shape;
   out_NCW_shape.resize(3);
   out_NCW_shape[0] = data_NCW_shape[0];
   out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups;
 
-  PrimExpr out_w = (input_w - 1) * attrs->strides[0] - padding_w +
-                   attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1;
+  PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[0]) - padding_w +
+                   Integer(attrs->dilation[0]) * (kernel_w - 1) +
+                   Integer(attrs->output_padding[0]) + 1;
   out_NCW_shape[2] = analyzer->Simplify(out_w);
 
   ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
@@ -767,9 +768,9 @@ TVM_REGISTER_OP("relax.nn.conv1d_transpose")
 
 /* relax.nn.conv2d_transpose */
 
-Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
-                      ffi::Array<IntImm> padding, ffi::Array<IntImm> output_padding,
-                      ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
                       ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
                       ffi::Optional<DataType> out_dtype) {
   padding = GetCompletePadding2D(std::move(padding));
@@ -796,10 +797,10 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
       << dilation;
 
   auto attrs = ffi::make_object<Conv2DTransposeAttrs>();
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->output_padding = ConvertIntImmToInt64(output_padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
   attrs->groups = groups;
   attrs->data_layout = data_layout;
   attrs->kernel_layout = std::move(kernel_layout);
@@ -870,14 +871,14 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
     // Todo(relax-team): Trust the input shape at this moment, and revisit
     // this condition with runtime shape check
   }
-  if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value ||
-                         attrs->output_padding[1]->value >= attrs->strides[1]->value)) {
+  if (attrs->output_padding[0] >= attrs->strides[0] ||
+      attrs->output_padding[1] >= attrs->strides[1]) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Conv2dTranspose expects the output padding less than the strides, but the "
                         "output padding is"
                      << attrs->output_padding << " while the strides are" << attrs->strides);
-  } else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value &&
-                                 attrs->output_padding[1]->value < attrs->strides[1]->value)) {
+  } else if (!(attrs->output_padding[0] < attrs->strides[0] &&
+               attrs->output_padding[1] < attrs->strides[1])) {
     // Todo(relax-team): Trust the input padding at this moment, and revisit
     // this condition with runtime shape check
   }
@@ -886,18 +887,20 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
   PrimExpr input_w = data_NCHW_shape[3];
   PrimExpr kernel_h = weight_IOHW_shape[2];
   PrimExpr kernel_w = weight_IOHW_shape[3];
-  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
-  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+  PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+  PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
 
   std::vector<PrimExpr> out_NCHW_shape;
   out_NCHW_shape.resize(4);
   out_NCHW_shape[0] = data_NCHW_shape[0];
   out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups;
 
-  PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h +
-                   attrs->dilation[0] * (kernel_h - 1) + attrs->output_padding[0] + 1;
-  PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w +
-                   attrs->dilation[1] * (kernel_w - 1) + attrs->output_padding[1] + 1;
+  PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[0]) - padding_h +
+                   Integer(attrs->dilation[0]) * (kernel_h - 1) +
+                   Integer(attrs->output_padding[0]) + 1;
+  PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[1]) - padding_w +
+                   Integer(attrs->dilation[1]) * (kernel_w - 1) +
+                   Integer(attrs->output_padding[1]) + 1;
   out_NCHW_shape[2] = analyzer->Simplify(out_h);
   out_NCHW_shape[3] = analyzer->Simplify(out_w);
 
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 4fc175b5aa..a5704d3f70 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -36,14 +36,14 @@ namespace tvm {
 namespace relax {
 
 template <typename T>
-inline Expr MakeConv(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-                     ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
-                     ffi::String kernel_layout, ffi::String out_layout, DataType out_dtype,
-                     std::string op_name) {
+inline Expr MakeConv(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                     ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, int groups,
+                     ffi::String data_layout, ffi::String kernel_layout, ffi::String out_layout,
+                     DataType out_dtype, std::string op_name) {
   auto attrs = ffi::make_object<T>();
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->groups = groups;
   attrs->data_layout = std::move(data_layout);
   attrs->kernel_layout = std::move(kernel_layout);
@@ -54,20 +54,20 @@ inline Expr MakeConv(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Ar
 }
 
 /*! \brief 1D convolution */
-Expr conv1d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype);
 
 /*! \brief 2D convolution */
-Expr conv2d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype);
 
 /*! \brief 3D convolution */
-Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-            ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv3d(Expr data, Expr weight, ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+            ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
             ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
             ffi::Optional<DataType> out_dtype);
 
@@ -77,9 +77,9 @@ Expr conv3d(Expr data, Expr weight, ffi::Array<IntImm> strides, ffi::Array<IntIm
  * This operator is intended to be the backward operator of conv1d. It can be used to calculate the
  * gradient of the result of conv1d w.r.t. the input of conv1d.
  */
-Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
-                      ffi::Array<IntImm> padding, ffi::Array<IntImm> output_padding,
-                      ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
                       ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
                       ffi::Optional<DataType> out_dtype);
 
@@ -89,9 +89,9 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
  * This operator is intended to be the backward operator of conv2d. It can be used to calculate the
  * gradient of the result of conv2d w.r.t. the input of conv2d.
  */
-Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<IntImm> strides,
-                      ffi::Array<IntImm> padding, ffi::Array<IntImm> output_padding,
-                      ffi::Array<IntImm> dilation, int groups, ffi::String data_layout,
+Expr conv2d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String data_layout,
                       ffi::String kernel_layout, ffi::Optional<ffi::String> out_layout,
                       ffi::Optional<DataType> out_dtype);
 
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 1a19872c27..2397bf0098 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -38,10 +38,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 /* relax.nn.max_pool1d */
 
-Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
-                ffi::Array<IntImm> strides, ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation,
-                bool ceil_mode, bool count_include_pad, ffi::String layout,
-                ffi::Optional<ffi::String> out_layout) {
+Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+                ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
+                ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   padding = GetCompletePadding1D(std::move(padding));
 
   CHECK_EQ(pool_size.size(), 1)
@@ -54,10 +54,10 @@ Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
       << dilation;
 
   auto attrs = ffi::make_object<Pool1DAttrs>();
-  attrs->pool_size = ConvertIntImmToInt64(pool_size);
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
@@ -66,8 +66,8 @@ Expr MakePool1d(ffi::String op_name, Expr data, 
ffi::Array<IntImm> pool_size,
   return Call(op, {std::move(data)}, Attrs(attrs), {});
 }
 
-Expr max_pool1d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr max_pool1d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -98,8 +98,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
   ffi::Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
 
   PrimExpr input_w = data_NCW_shape[2];
-  PrimExpr kernel_w = attrs->pool_size[0];
-  PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+  PrimExpr kernel_w = Integer(attrs->pool_size[0]);
+  PrimExpr padding_w = Integer(attrs->padding[0]) + Integer(attrs->padding[1]);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   std::vector<PrimExpr> out_NCW_shape;
@@ -107,13 +107,14 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
   out_NCW_shape[0] = data_NCW_shape[0];
   out_NCW_shape[1] = data_NCW_shape[1];
 
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[0]) * (kernel_w - 1) - 1;
   if (attrs->ceil_mode) {
-    numerator_w += attrs->strides[0] - 1;
+    numerator_w += Integer(attrs->strides[0]) - 1;
   }
-  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1;
+  PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[0])) + 1;
   if (attrs->ceil_mode) {
-    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0];
+    PrimExpr invalid_last_w =
+        (raw_out_w - 1) * Integer(attrs->strides[0]) >= input_w + Integer(attrs->padding[0]);
     out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
   } else {
     out_NCW_shape[2] = analyzer->Simplify(raw_out_w);
@@ -151,10 +152,10 @@ TVM_REGISTER_OP("relax.nn.max_pool1d")
 
 /* relax.nn.max_pool2d */
 
-Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
-                ffi::Array<IntImm> strides, ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation,
-                bool ceil_mode, bool count_include_pad, ffi::String layout,
-                ffi::Optional<ffi::String> out_layout) {
+Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+                ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
+                ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   padding = GetCompletePadding2D(std::move(padding));
   if (pool_size.size() == 1) {
     pool_size.push_back(pool_size[0]);
@@ -176,10 +177,10 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
       << dilation;
 
   auto attrs = ffi::make_object<Pool2DAttrs>();
-  attrs->pool_size = ConvertIntImmToInt64(pool_size);
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
@@ -188,8 +189,8 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
   return Call(op, {std::move(data)}, Attrs(attrs), {});
 }
 
-Expr max_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr max_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -221,10 +222,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
 
   PrimExpr input_h = data_NCHW_shape[2];
   PrimExpr input_w = data_NCHW_shape[3];
-  PrimExpr kernel_h = attrs->pool_size[0];
-  PrimExpr kernel_w = attrs->pool_size[1];
-  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
-  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+  PrimExpr kernel_h = Integer(attrs->pool_size[0]);
+  PrimExpr kernel_w = Integer(attrs->pool_size[1]);
+  PrimExpr padding_h = Integer(attrs->padding[0]) + Integer(attrs->padding[2]);
+  PrimExpr padding_w = Integer(attrs->padding[1]) + Integer(attrs->padding[3]);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   std::vector<PrimExpr> out_NCHW_shape;
@@ -232,17 +233,19 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape[0] = data_NCHW_shape[0];
   out_NCHW_shape[1] = data_NCHW_shape[1];
 
-  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
+  PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[0]) * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[1]) * (kernel_w - 1) - 1;
   if (attrs->ceil_mode) {
-    numerator_h += attrs->strides[0] - 1;
-    numerator_w += attrs->strides[1] - 1;
+    numerator_h += Integer(attrs->strides[0]) - 1;
+    numerator_w += Integer(attrs->strides[1]) - 1;
   }
-  PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[0]) + 1;
-  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[1]) + 1;
+  PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[0])) + 1;
+  PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[1])) + 1;
   if (attrs->ceil_mode) {
-    PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[0] >= input_h + attrs->padding[0];
-    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[1] >= input_w + attrs->padding[1];
+    PrimExpr invalid_last_h =
+        (raw_out_h - 1) * Integer(attrs->strides[0]) >= input_h + Integer(attrs->padding[0]);
+    PrimExpr invalid_last_w =
+        (raw_out_w - 1) * Integer(attrs->strides[1]) >= input_w + Integer(attrs->padding[1]);
     out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h));
     out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
   } else {
@@ -300,10 +303,10 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
 
 /* relax.nn.max_pool3d */
 
-Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
-                ffi::Array<IntImm> strides, ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation,
-                bool ceil_mode, bool count_include_pad, ffi::String layout,
-                ffi::Optional<ffi::String> out_layout) {
+Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<int64_t> pool_size,
+                ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
+                ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   padding = GetCompletePadding3D(std::move(padding));
   if (pool_size.size() == 1) {
     pool_size.push_back(pool_size[0]);
@@ -328,10 +331,10 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
       << dilation;
 
   auto attrs = ffi::make_object<Pool3DAttrs>();
-  attrs->pool_size = ConvertIntImmToInt64(pool_size);
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
@@ -340,8 +343,8 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array<IntImm> pool_size,
   return Call(op, {std::move(data)}, Attrs(attrs), {});
 }
 
-Expr max_pool3d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr max_pool3d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -374,12 +377,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) {
   PrimExpr input_d = data_NCDHW_shape[2];
   PrimExpr input_h = data_NCDHW_shape[3];
   PrimExpr input_w = data_NCDHW_shape[4];
-  PrimExpr kernel_d = attrs->pool_size[0];
-  PrimExpr kernel_h = attrs->pool_size[1];
-  PrimExpr kernel_w = attrs->pool_size[2];
-  PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
-  PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
-  PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+  PrimExpr kernel_d = Integer(attrs->pool_size[0]);
+  PrimExpr kernel_h = Integer(attrs->pool_size[1]);
+  PrimExpr kernel_w = Integer(attrs->pool_size[2]);
+  PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+  PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+  PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
 
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   std::vector<PrimExpr> out_NCDHW_shape;
@@ -387,21 +390,24 @@ StructInfo InferStructInfoPool3D(const Call& call, const 
BlockBuilder& ctx) {
   out_NCDHW_shape[0] = data_NCDHW_shape[0];
   out_NCDHW_shape[1] = data_NCDHW_shape[1];
 
-  PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d - 1) - 1;
-  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h - 1) - 1;
-  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w - 1) - 1;
+  PrimExpr numerator_d = input_d + padding_d - Integer(attrs->dilation[0]) * (kernel_d - 1) - 1;
+  PrimExpr numerator_h = input_h + padding_h - Integer(attrs->dilation[1]) * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - Integer(attrs->dilation[2]) * (kernel_w - 1) - 1;
   if (attrs->ceil_mode) {
-    numerator_d += attrs->strides[0] - 1;
-    numerator_h += attrs->strides[1] - 1;
-    numerator_w += attrs->strides[2] - 1;
+    numerator_d += Integer(attrs->strides[0]) - 1;
+    numerator_h += Integer(attrs->strides[1]) - 1;
+    numerator_w += Integer(attrs->strides[2]) - 1;
   }
-  PrimExpr raw_out_d = floordiv(numerator_d, attrs->strides[0]) + 1;
-  PrimExpr raw_out_h = floordiv(numerator_h, attrs->strides[1]) + 1;
-  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[2]) + 1;
+  PrimExpr raw_out_d = floordiv(numerator_d, Integer(attrs->strides[0])) + 1;
+  PrimExpr raw_out_h = floordiv(numerator_h, Integer(attrs->strides[1])) + 1;
+  PrimExpr raw_out_w = floordiv(numerator_w, Integer(attrs->strides[2])) + 1;
   if (attrs->ceil_mode) {
-    PrimExpr invalid_last_d = (raw_out_d - 1) * attrs->strides[0] >= input_d + attrs->padding[0];
-    PrimExpr invalid_last_h = (raw_out_h - 1) * attrs->strides[1] >= input_h + attrs->padding[1];
-    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[2] >= input_w + attrs->padding[2];
+    PrimExpr invalid_last_d =
+        (raw_out_d - 1) * Integer(attrs->strides[0]) >= input_d + Integer(attrs->padding[0]);
+    PrimExpr invalid_last_h =
+        (raw_out_h - 1) * Integer(attrs->strides[1]) >= input_h + Integer(attrs->padding[1]);
+    PrimExpr invalid_last_w =
+        (raw_out_w - 1) * Integer(attrs->strides[2]) >= input_w + Integer(attrs->padding[2]);
     out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d));
     out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h));
     out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
@@ -442,8 +448,8 @@ TVM_REGISTER_OP("relax.nn.max_pool3d")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.avg_pool1d */
-Expr avg_pool1d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr avg_pool1d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -464,8 +470,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool1d")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.avg_pool2d */
-Expr avg_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr avg_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -486,8 +492,8 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.avg_pool3d */
-Expr avg_pool3d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr avg_pool3d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding, dilation, ceil_mode,
                     count_include_pad, layout, out_layout);
@@ -509,13 +515,13 @@ TVM_REGISTER_OP("relax.nn.avg_pool3d")
 
 /* relax.nn.adaptive_avg_pool1d */
 
-Expr adaptive_avg_pool1d(Expr data, ffi::Optional<ffi::Array<IntImm>> output_size,
+Expr adaptive_avg_pool1d(Expr data, ffi::Optional<ffi::Array<int64_t>> output_size,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   ObjectPtr<AdaptivePool1DAttrs> attrs = ffi::make_object<AdaptivePool1DAttrs>();
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   if (output_size.defined()) {
-    ffi::Array<IntImm> _output_size = output_size.value();
+    ffi::Array<int64_t> _output_size = output_size.value();
     CHECK_EQ(_output_size.size(), 1)
         << "The output_size length is expected to be 1. However, the given 
output_size is "
         << _output_size;
@@ -556,7 +562,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder
   ffi::Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
   ffi::Array<PrimExpr> out_NCW_shape(data_NCW_shape);
   if (attrs->output_size.defined()) {
-    out_NCW_shape.Set(2, attrs->output_size.value()[0]);
+    out_NCW_shape.Set(2, Integer(attrs->output_size.value()[0]));
   }
 
   ffi::Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
@@ -591,13 +597,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool1d")
 
 /* relax.nn.adaptive_avg_pool2d */
 
-Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<IntImm>> output_size,
+Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<int64_t>> output_size,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   ObjectPtr<AdaptivePool2DAttrs> attrs = ffi::make_object<AdaptivePool2DAttrs>();
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   if (output_size.defined()) {
-    ffi::Array<IntImm> _output_size = output_size.value();
+    ffi::Array<int64_t> _output_size = output_size.value();
     if (_output_size.size() == 1) {
       _output_size.push_back(_output_size[0]);
     }
@@ -641,8 +647,8 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder
   ffi::Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
   ffi::Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
   if (attrs->output_size.defined()) {
-    out_NCHW_shape.Set(2, attrs->output_size.value()[0]);
-    out_NCHW_shape.Set(3, attrs->output_size.value()[1]);
+    out_NCHW_shape.Set(2, Integer(attrs->output_size.value()[0]));
+    out_NCHW_shape.Set(3, Integer(attrs->output_size.value()[1]));
   }
 
   ffi::Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
@@ -693,13 +699,13 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
 
 /* relax.nn.adaptive_avg_pool3d */
 
-Expr adaptive_avg_pool3d(Expr data, ffi::Optional<ffi::Array<IntImm>> output_size,
+Expr adaptive_avg_pool3d(Expr data, ffi::Optional<ffi::Array<int64_t>> output_size,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   ObjectPtr<AdaptivePool3DAttrs> attrs = ffi::make_object<AdaptivePool3DAttrs>();
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   if (output_size.defined()) {
-    ffi::Array<IntImm> _output_size = output_size.value();
+    ffi::Array<int64_t> _output_size = output_size.value();
     if (_output_size.size() == 1) {
       _output_size.push_back(_output_size[0]);
     }
@@ -743,9 +749,9 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder
   ffi::Array<PrimExpr> data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values);
   ffi::Array<PrimExpr> out_NCDHW_shape(data_NCDHW_shape);
   if (attrs->output_size.defined()) {
-    out_NCDHW_shape.Set(2, attrs->output_size.value()[0]);
-    out_NCDHW_shape.Set(3, attrs->output_size.value()[1]);
-    out_NCDHW_shape.Set(4, attrs->output_size.value()[2]);
+    out_NCDHW_shape.Set(2, Integer(attrs->output_size.value()[0]));
+    out_NCDHW_shape.Set(3, Integer(attrs->output_size.value()[1]));
+    out_NCDHW_shape.Set(4, Integer(attrs->output_size.value()[2]));
   }
 
   ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index c5435303e8..d1fbc834ee 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -33,17 +33,17 @@ namespace tvm {
 namespace relax {
 
 /*! \brief 2D maximum pooling operator. */
-Expr max_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr max_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout);
 
 /*! \brief 2D average pooling operator. */
-Expr avg_pool2d(Expr data, ffi::Array<IntImm> pool_size, ffi::Array<IntImm> strides,
-                ffi::Array<IntImm> padding, ffi::Array<IntImm> dilation, bool ceil_mode,
+Expr avg_pool2d(Expr data, ffi::Array<int64_t> pool_size, ffi::Array<int64_t> strides,
+                ffi::Array<int64_t> padding, ffi::Array<int64_t> dilation, bool ceil_mode,
                 bool count_include_pad, ffi::String layout, ffi::Optional<ffi::String> out_layout);
 
 /*! \brief 2D adaptive average pooling operator. */
-Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<IntImm>> output_size,
+Expr adaptive_avg_pool2d(Expr data, ffi::Optional<ffi::Array<int64_t>> output_size,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout);
 
 }  // namespace relax
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index ee82f3eebc..cd5406e614 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -465,7 +465,7 @@ inline ffi::Array<IntImm> ConvertIntImmToInt64(const ffi::Array<IntImm>& int_imm
  * \return The completed padding.
  * \throws Throws error if the input padding length is neither 1 or 2.
  */
-inline ffi::Array<IntImm> GetCompletePadding1D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding1D(ffi::Array<int64_t> padding) {
   if (padding.size() == 1) {
     return {padding[0], padding[0]};
   } else if (padding.size() == 2) {
@@ -486,7 +486,7 @@ inline ffi::Array<IntImm> GetCompletePadding1D(ffi::Array<IntImm> padding) {
  * \return The completed padding.
  * \throws Throws error if the input padding length is neither 1, 2 or 4.
  */
-inline ffi::Array<IntImm> GetCompletePadding2D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding2D(ffi::Array<int64_t> padding) {
   if (padding.size() == 1) {
     return {padding[0], padding[0], padding[0], padding[0]};
   } else if (padding.size() == 2) {
@@ -511,7 +511,7 @@ inline ffi::Array<IntImm> GetCompletePadding2D(ffi::Array<IntImm> padding) {
  * \return The completed padding.
  * \throws Throws error if the input padding length is neither 1, 3 or 6.
  */
-inline ffi::Array<IntImm> GetCompletePadding3D(ffi::Array<IntImm> padding) {
+inline ffi::Array<int64_t> GetCompletePadding3D(ffi::Array<int64_t> padding) {
   if (padding.size() == 1) {
     return {padding[0], padding[0], padding[0], padding[0], padding[0], padding[0]};
   } else if (padding.size() == 3) {
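
The GetCompletePadding helpers above change only their element type; the expansion rule itself is untouched. A Python mirror of the 2D rule for reference (illustrative, not TVM API; the 2-element branch is assumed to follow the usual symmetric (top, left) -> (top, left, bottom, right) convention):

    # Hypothetical mirror of GetCompletePadding2D, not part of the patch.
    def complete_padding_2d(padding):
        if len(padding) == 1:
            return [padding[0]] * 4  # same pad on all four sides
        if len(padding) == 2:
            return [padding[0], padding[1], padding[0], padding[1]]
        if len(padding) == 4:
            return list(padding)  # already (top, left, bottom, right)
        raise ValueError("padding length must be 1, 2, or 4")

    assert complete_padding_2d([1]) == [1, 1, 1, 1]
    assert complete_padding_2d([1, 2]) == [1, 2, 1, 2]
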
diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc
index 52a218b730..0594ef75bd 100644
--- a/src/relax/op/tensor/grad.cc
+++ b/src/relax/op/tensor/grad.cc
@@ -141,15 +141,15 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.grad.max_pool2d_backward */
-Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm> pool_size,
-                         ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-                         ffi::Array<IntImm> dilation, bool ceil_mode, bool count_include_pad,
+Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t> pool_size,
+                         ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                         ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   auto attrs = ffi::make_object<Pool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
@@ -176,15 +176,15 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward")
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.grad.avg_pool2d_backward */
-Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm> pool_size,
-                         ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-                         ffi::Array<IntImm> dilation, bool ceil_mode, bool count_include_pad,
+Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t> pool_size,
+                         ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                         ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout) {
   auto attrs = ffi::make_object<Pool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
-  attrs->strides = ConvertIntImmToInt64(strides);
-  attrs->padding = ConvertIntImmToInt64(padding);
-  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
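
With the parameters already arriving as Array<int64_t>, the backward constructors above move them straight into the attrs, and the intermediate ConvertIntImmToInt64 calls disappear. A sketch of the Python-side effect (illustrative only; assumes relax.op.grad is exposed as in a standard build):

    # Hypothetical usage sketch, not part of the patch.
    from tvm import relax
    from tvm.script import relax as R

    og = relax.Var("og", R.Tensor((3, 2, 14, 14), "float32"))
    data = relax.Var("data", R.Tensor((3, 2, 28, 28), "float32"))
    bwd = relax.op.grad.max_pool2d_backward(
        og, data, pool_size=(3, 3), strides=(2, 2), padding=(1, 1)
    )
    # Attribute elements are plain Python ints rather than IntImm nodes.
    assert isinstance(bwd.attrs.strides[0], int)
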
diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h
index 406d7a2f77..911049475d 100644
--- a/src/relax/op/tensor/grad.h
+++ b/src/relax/op/tensor/grad.h
@@ -46,16 +46,16 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets,
 
 /*! \brief Backward operator of relax.max_pool2d. All parameters except output_grad are the same as
  * relax.max_pool2d. Returns the gradient w.r.t. data. */
-Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm> pool_size,
-                         ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-                         ffi::Array<IntImm> dilation, bool ceil_mode, bool count_include_pad,
+Expr max_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t> pool_size,
+                         ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                         ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout);
 
 /*! \brief Backward operator of relax.avg_pool2d. All parameters except output_grad are the same as
  * relax.avg_pool2d. Returns the gradient w.r.t. data. */
-Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<IntImm> pool_size,
-                         ffi::Array<IntImm> strides, ffi::Array<IntImm> padding,
-                         ffi::Array<IntImm> dilation, bool ceil_mode, bool count_include_pad,
+Expr avg_pool2d_backward(Expr output_grad, Expr data, ffi::Array<int64_t> pool_size,
+                         ffi::Array<int64_t> strides, ffi::Array<int64_t> padding,
+                         ffi::Array<int64_t> dilation, bool ceil_mode, bool count_include_pad,
                          ffi::String layout, ffi::Optional<ffi::String> out_layout);
 
 /*! \brief Backward operator of relax.take. All parameters except output_grad are the same as
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index ee21c14c6f..1d2a6e4676 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -277,10 +277,8 @@ def test_op_attr():
     conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3))
     xp = is_var("x")
     yp = is_var("y")
-    # TODO(@yuchen): reenable the assert after figuring out why it fails
-    # assert is_op("nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
+    assert is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
     assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [4, 3]}).match(conv2d)
-    assert not is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
 
 
 def test_match_call_attr():
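
The re-enabled assertion above works because the attrs no longer hold IntImm nodes. The old failure mode, for reference (illustrative sketch; StructuralEqual treats IntImms of different dtypes as distinct, so an int32 pattern value never matched an int64 attribute value):

    # Sketch of the root cause, not part of the patch.
    import tvm
    from tvm import tir

    assert not tvm.ir.structural_equal(tir.IntImm("int32", 3), tir.IntImm("int64", 3))
    assert tvm.ir.structural_equal(tir.IntImm("int64", 3), tir.IntImm("int64", 3))
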
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 9b913138df..4ee8226bc3 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -352,10 +352,10 @@ def test_conv1d_stride_padding_dilation_int64():
     w = relax.Var("w", R.Tensor((4, 3, 3), "float32"))
     conv1d = relax.op.nn.conv1d(x, w, strides=(1,), padding=(1, 1), dilation=(1,))
 
-    assert conv1d.attrs.strides[0].dtype == "int64"
-    assert conv1d.attrs.padding[0].dtype == "int64"
-    assert conv1d.attrs.padding[1].dtype == "int64"
-    assert conv1d.attrs.dilation[0].dtype == "int64"
+    assert isinstance(conv1d.attrs.strides[0], int)
+    assert isinstance(conv1d.attrs.padding[0], int)
+    assert isinstance(conv1d.attrs.padding[1], int)
+    assert isinstance(conv1d.attrs.dilation[0], int)
 
 
 def test_conv1d_wrong_strides_padding_dilation_length():
@@ -711,9 +711,9 @@ def test_conv1d_transpose_stride_padding_dilation_int64():
     w = relax.Var("w", R.Tensor((3, 4, 3), "float32"))
     conv1d = relax.op.nn.conv1d_transpose(x, w, strides=1, padding=1, dilation=1)
 
-    assert conv1d.attrs.strides[0].dtype == "int64"
-    assert conv1d.attrs.padding[0].dtype == "int64"
-    assert conv1d.attrs.dilation[0].dtype == "int64"
+    assert isinstance(conv1d.attrs.strides[0], int)
+    assert isinstance(conv1d.attrs.padding[0], int)
+    assert isinstance(conv1d.attrs.dilation[0], int)
 
 
 def test_conv1d_transpose_wrong_strides_padding_dilation_length():
@@ -1122,14 +1122,14 @@ def test_conv2d_stride_padding_dilation_int64():
     w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
     conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
 
-    assert conv2d.attrs.strides[0].dtype == "int64"
-    assert conv2d.attrs.strides[1].dtype == "int64"
-    assert conv2d.attrs.padding[0].dtype == "int64"
-    assert conv2d.attrs.padding[1].dtype == "int64"
-    assert conv2d.attrs.padding[2].dtype == "int64"
-    assert conv2d.attrs.padding[3].dtype == "int64"
-    assert conv2d.attrs.dilation[0].dtype == "int64"
-    assert conv2d.attrs.dilation[1].dtype == "int64"
+    assert isinstance(conv2d.attrs.strides[0], int)
+    assert isinstance(conv2d.attrs.strides[1], int)
+    assert isinstance(conv2d.attrs.padding[0], int)
+    assert isinstance(conv2d.attrs.padding[1], int)
+    assert isinstance(conv2d.attrs.padding[2], int)
+    assert isinstance(conv2d.attrs.padding[3], int)
+    assert isinstance(conv2d.attrs.dilation[0], int)
+    assert isinstance(conv2d.attrs.dilation[1], int)
 
 
 def test_conv2d_wrong_strides_padding_dilation_length():
@@ -1510,16 +1510,16 @@ def test_conv2d_transpose_stride_padding_dilation_int64():
         x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2), dilation=(1, 1)
     )
 
-    assert conv2d_transpose.attrs.strides[0].dtype == "int64"
-    assert conv2d_transpose.attrs.strides[1].dtype == "int64"
-    assert conv2d_transpose.attrs.padding[0].dtype == "int64"
-    assert conv2d_transpose.attrs.padding[1].dtype == "int64"
-    assert conv2d_transpose.attrs.padding[2].dtype == "int64"
-    assert conv2d_transpose.attrs.padding[3].dtype == "int64"
-    assert conv2d_transpose.attrs.output_padding[0].dtype == "int64"
-    assert conv2d_transpose.attrs.output_padding[1].dtype == "int64"
-    assert conv2d_transpose.attrs.dilation[0].dtype == "int64"
-    assert conv2d_transpose.attrs.dilation[1].dtype == "int64"
+    assert isinstance(conv2d_transpose.attrs.strides[0], int)
+    assert isinstance(conv2d_transpose.attrs.strides[1], int)
+    assert isinstance(conv2d_transpose.attrs.padding[0], int)
+    assert isinstance(conv2d_transpose.attrs.padding[1], int)
+    assert isinstance(conv2d_transpose.attrs.padding[2], int)
+    assert isinstance(conv2d_transpose.attrs.padding[3], int)
+    assert isinstance(conv2d_transpose.attrs.output_padding[0], int)
+    assert isinstance(conv2d_transpose.attrs.output_padding[1], int)
+    assert isinstance(conv2d_transpose.attrs.dilation[0], int)
+    assert isinstance(conv2d_transpose.attrs.dilation[1], int)
 
 
 def test_conv2d_transpose_wrong_strides_padding_dilation_length():
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index d4461a122d..12d099bb33 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -183,10 +183,10 @@ def test_max_pool1d_stride_padding_dilation_int64():
     x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, dilation=1)
 
-    assert max_pool1d.attrs.strides[0].dtype == "int64"
-    assert max_pool1d.attrs.padding[0].dtype == "int64"
-    assert max_pool1d.attrs.padding[1].dtype == "int64"
-    assert max_pool1d.attrs.dilation[0].dtype == "int64"
+    assert isinstance(max_pool1d.attrs.strides[0], int)
+    assert isinstance(max_pool1d.attrs.padding[0], int)
+    assert isinstance(max_pool1d.attrs.padding[1], int)
+    assert isinstance(max_pool1d.attrs.dilation[0], int)
 
 
 def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length():
@@ -412,14 +412,14 @@ def test_max_pool2d_stride_padding_dilation_int64():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1))
 
-    assert max_pool2d.attrs.strides[0].dtype == "int64"
-    assert max_pool2d.attrs.strides[1].dtype == "int64"
-    assert max_pool2d.attrs.padding[0].dtype == "int64"
-    assert max_pool2d.attrs.padding[1].dtype == "int64"
-    assert max_pool2d.attrs.padding[2].dtype == "int64"
-    assert max_pool2d.attrs.padding[3].dtype == "int64"
-    assert max_pool2d.attrs.dilation[0].dtype == "int64"
-    assert max_pool2d.attrs.dilation[1].dtype == "int64"
+    assert isinstance(max_pool2d.attrs.strides[0], int)
+    assert isinstance(max_pool2d.attrs.strides[1], int)
+    assert isinstance(max_pool2d.attrs.padding[0], int)
+    assert isinstance(max_pool2d.attrs.padding[1], int)
+    assert isinstance(max_pool2d.attrs.padding[2], int)
+    assert isinstance(max_pool2d.attrs.padding[3], int)
+    assert isinstance(max_pool2d.attrs.dilation[0], int)
+    assert isinstance(max_pool2d.attrs.dilation[1], int)
 
 
 def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length():
@@ -660,17 +660,17 @@ def test_max_pool3d_stride_padding_dilation_int64():
         x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
     )
 
-    assert max_pool3d.attrs.strides[0].dtype == "int64"
-    assert max_pool3d.attrs.strides[1].dtype == "int64"
-    assert max_pool3d.attrs.strides[2].dtype == "int64"
-    assert max_pool3d.attrs.padding[0].dtype == "int64"
-    assert max_pool3d.attrs.padding[1].dtype == "int64"
-    assert max_pool3d.attrs.padding[2].dtype == "int64"
-    assert max_pool3d.attrs.padding[3].dtype == "int64"
-    assert max_pool3d.attrs.padding[4].dtype == "int64"
-    assert max_pool3d.attrs.dilation[0].dtype == "int64"
-    assert max_pool3d.attrs.dilation[1].dtype == "int64"
-    assert max_pool3d.attrs.dilation[2].dtype == "int64"
+    assert isinstance(max_pool3d.attrs.strides[0], int)
+    assert isinstance(max_pool3d.attrs.strides[1], int)
+    assert isinstance(max_pool3d.attrs.strides[2], int)
+    assert isinstance(max_pool3d.attrs.padding[0], int)
+    assert isinstance(max_pool3d.attrs.padding[1], int)
+    assert isinstance(max_pool3d.attrs.padding[2], int)
+    assert isinstance(max_pool3d.attrs.padding[3], int)
+    assert isinstance(max_pool3d.attrs.padding[4], int)
+    assert isinstance(max_pool3d.attrs.dilation[0], int)
+    assert isinstance(max_pool3d.attrs.dilation[1], int)
+    assert isinstance(max_pool3d.attrs.dilation[2], int)
 
 
 def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length():
@@ -875,10 +875,10 @@ def test_avg_pool1d_stride_padding_dilation_int64():
     x = relax.Var("x", R.Tensor((2, 3, 28), "float32"))
     avg_pool1d = relax.op.nn.avg_pool1d(x, 3, strides=1, padding=1, dilation=1)
 
-    assert avg_pool1d.attrs.strides[0].dtype == "int64"
-    assert avg_pool1d.attrs.padding[0].dtype == "int64"
-    assert avg_pool1d.attrs.padding[1].dtype == "int64"
-    assert avg_pool1d.attrs.dilation[0].dtype == "int64"
+    assert isinstance(avg_pool1d.attrs.strides[0], int)
+    assert isinstance(avg_pool1d.attrs.padding[0], int)
+    assert isinstance(avg_pool1d.attrs.padding[1], int)
+    assert isinstance(avg_pool1d.attrs.dilation[0], int)
 
 
 def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length():
@@ -1101,14 +1101,14 @@ def test_avg_pool2d_stride_padding_dilation_int64():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1))
 
-    assert avg_pool2d.attrs.strides[0].dtype == "int64"
-    assert avg_pool2d.attrs.strides[1].dtype == "int64"
-    assert avg_pool2d.attrs.padding[0].dtype == "int64"
-    assert avg_pool2d.attrs.padding[1].dtype == "int64"
-    assert avg_pool2d.attrs.padding[2].dtype == "int64"
-    assert avg_pool2d.attrs.padding[3].dtype == "int64"
-    assert avg_pool2d.attrs.dilation[0].dtype == "int64"
-    assert avg_pool2d.attrs.dilation[1].dtype == "int64"
+    assert isinstance(avg_pool2d.attrs.strides[0], int)
+    assert isinstance(avg_pool2d.attrs.strides[1], int)
+    assert isinstance(avg_pool2d.attrs.padding[0], int)
+    assert isinstance(avg_pool2d.attrs.padding[1], int)
+    assert isinstance(avg_pool2d.attrs.padding[2], int)
+    assert isinstance(avg_pool2d.attrs.padding[3], int)
+    assert isinstance(avg_pool2d.attrs.dilation[0], int)
+    assert isinstance(avg_pool2d.attrs.dilation[1], int)
 
 
 def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length():
@@ -1356,15 +1356,15 @@ def test_avg_pool3d_stride_padding_dilation_int64():
         x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1)
     )
 
-    assert avg_pool3d.attrs.strides[0].dtype == "int64"
-    assert avg_pool3d.attrs.strides[1].dtype == "int64"
-    assert avg_pool3d.attrs.strides[2].dtype == "int64"
-    assert avg_pool3d.attrs.padding[0].dtype == "int64"
-    assert avg_pool3d.attrs.padding[1].dtype == "int64"
-    assert avg_pool3d.attrs.padding[2].dtype == "int64"
-    assert avg_pool3d.attrs.dilation[0].dtype == "int64"
-    assert avg_pool3d.attrs.dilation[1].dtype == "int64"
-    assert avg_pool3d.attrs.dilation[2].dtype == "int64"
+    assert isinstance(avg_pool3d.attrs.strides[0], int)
+    assert isinstance(avg_pool3d.attrs.strides[1], int)
+    assert isinstance(avg_pool3d.attrs.strides[2], int)
+    assert isinstance(avg_pool3d.attrs.padding[0], int)
+    assert isinstance(avg_pool3d.attrs.padding[1], int)
+    assert isinstance(avg_pool3d.attrs.padding[2], int)
+    assert isinstance(avg_pool3d.attrs.dilation[0], int)
+    assert isinstance(avg_pool3d.attrs.dilation[1], int)
+    assert isinstance(avg_pool3d.attrs.dilation[2], int)
 
 
 def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length():
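
The test updates above spell out the consumer-facing contract: attribute elements now arrive as plain Python ints, so downstream readers no longer need an explicit IntImm unwrap. A before/after sketch (hypothetical consumer code, not part of the patch):

    # Before: elements were IntImm nodes and needed unwrapping.
    #   strides = [int(s) for s in call.attrs.strides]
    # After: elements are already plain ints.
    #   strides = list(call.attrs.strides)
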
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py
index cf9361d7c7..294cea71de 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -272,17 +272,17 @@ def test_avg_pool2d_backward():
     @I.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")):
+        def avg_pool2d_backward(output_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), data: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")):
             T.func_attr({"tir.noalias": True})
             # with T.sblock("root"):
             for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
                 with T.sblock("T_pool_grad"):
                     v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
-                    T.reads(rxplaceholder[v_ax0, v_ax1, T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh, T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww])
+                    T.reads(output_grad[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww])
                     T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
                     with T.init():
                         T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
-                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
+                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
 
         @R.function
         def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"):
