This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4de1f11344 [Relax] Add conv3d_transpose and ONNX ConvTranspose 3D support (#18948)
4de1f11344 is described below

commit 4de1f11344608d2305891c7fc585bc4f089158eb
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Mar 29 17:42:35 2026 +0800

    [Relax] Add conv3d_transpose and ONNX ConvTranspose 3D support (#18948)
    
    Introduce relax.nn.conv3d_transpose (attrs, C++ inference/layout, Python
    API) and lower it to TOPI group_conv3d_transpose_ncdhw when using
    NCDHW/IODHW with dilation 1, matching the conv2d_transpose legalization
    policy.
    
    Wire the Relax ONNX frontend to emit conv3d_transpose for 5D inputs.
    Extend tests for ONNX, struct info, LegalizeOps, and TVMScript
    round-trip; fix ConvTranspose test output spatial size to include
    output_padding. See https://github.com/apache/tvm/issues/18945
---
 include/tvm/relax/attrs/nn.h                       |  52 +++++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |   4 +-
 python/tvm/relax/op/nn/__init__.py                 |   1 +
 python/tvm/relax/op/nn/nn.py                       | 106 +++++++++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/transform/legalize_ops/nn.py      |  38 +++-
 src/relax/op/nn/convolution.cc                     | 247 ++++++++++++++++++++-
 src/relax/op/nn/convolution.h                      |  12 +
 tests/python/relax/test_frontend_onnx.py           |   9 +-
 tests/python/relax/test_op_nn_convolution.py       |  74 ++++++
 .../python/relax/test_transform_legalize_ops_nn.py | 108 +++++++++
 tests/python/relax/test_tvmscript_parser_op_nn.py  |  38 ++++
 12 files changed, 687 insertions(+), 7 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 2a2ac5fe07..5c1931c3ee 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -267,6 +267,58 @@ struct Conv2DTransposeAttrs : public 
AttrsNodeReflAdapter<Conv2DTransposeAttrs>
                                     BaseAttrsNode);
 };  // struct Conv2DTransposeAttrs
 
+/*! \brief Attributes used in the relax.nn.conv3d_transpose operator */
+struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter<Conv3DTransposeAttrs> {
+  ffi::Array<int64_t> strides;
+  ffi::Array<int64_t> padding;
+  ffi::Array<int64_t> output_padding;
+  ffi::Array<int64_t> dilation;
+  int groups;
+  ffi::String data_layout;
+  ffi::String kernel_layout;
+  ffi::String out_layout;
+  DataType out_dtype;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<Conv3DTransposeAttrs>()
+        .def_ro("strides", &Conv3DTransposeAttrs::strides,
+                "Specifies the strides of the convolution.")
+        .def_ro("padding", &Conv3DTransposeAttrs::padding,
+                "If padding is non-zero, then the input is implicitly zero-padded. "
+                "Padding supports both symmetric and asymmetric forms: "
+                "one int : same padding used on all sides; "
+                "three int : back/bottom/right will use same padding as front/top/left; "
+                "six int : padding width in the order of (front, top, left, back, bottom, right)")
+        .def_ro("output_padding", &Conv3DTransposeAttrs::output_padding,
+                "Used to disambiguate the output shape.")
+        .def_ro("dilation", &Conv3DTransposeAttrs::dilation,
+                "Specifies the dilation rate to use for dilated convolution.")
+        .def_ro("groups", &Conv3DTransposeAttrs::groups,
+                "Number of groups to split the input into for grouped convolution. The number of "
+                "input and "
+                "output channels should be divisible by the number of groups.")
+        .def_ro("data_layout", &Conv3DTransposeAttrs::data_layout,
+                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
+                "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
+                "dimensions respectively. Convolution is applied on the 'D', 'H', and "
+                "'W' dimensions.")
+        .def_ro("kernel_layout", &Conv3DTransposeAttrs::kernel_layout,
+                "Dimension ordering of weight. Can be 'IODHW', etc. "
+                "'I', 'O', 'D', 'H', 'W' stand for input_channel, output_channel, depth, "
+                "height, and width "
+                "dimensions respectively.")
+        .def_ro("out_layout", &Conv3DTransposeAttrs::out_layout,
+                "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc. "
+                "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
+                "dimensions respectively. Defaults to the same as the input layout.")
+        .def_ro("out_dtype", &Conv3DTransposeAttrs::out_dtype,
+                "Output data type, set to explicit type under mixed precision setting");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DTransposeAttrs", Conv3DTransposeAttrs,
+                                    BaseAttrsNode);
+};  // struct Conv3DTransposeAttrs
+
 /*! \brief Attributes used in max_pool1d and avg_pool1d operator */
 struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
   ffi::Array<int64_t> pool_size;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e56f975c62..74c8bfe690 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1364,7 +1364,9 @@ class ConvTranspose(OnnxOpConverter):
             data_layout = "NCHW"
             kernel_layout = "IOHW"
         elif ndim == 5:
-            raise NotImplementedError("Relax ConvTranspose3d not supported 
yet")
+            op = relax.op.nn.conv3d_transpose
+            data_layout = "NCDHW"
+            kernel_layout = "IODHW"
         else:
             raise NotImplementedError("Ndim > 5 not supported for 
convolution.")
 
diff --git a/python/tvm/relax/op/nn/__init__.py 
b/python/tvm/relax/op/nn/__init__.py
index 00481245d0..57128282b9 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -34,6 +34,7 @@ from .nn import (
     conv2d,
     conv2d_transpose,
     conv3d,
+    conv3d_transpose,
     cross_entropy_with_logits,
     dropout,
     gelu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e30ba550c7..c31a974402 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -293,6 +293,10 @@ def conv3d(
     out_dtype : Optional[Union[str, DataType]]
         Specifies the output data type for mixed precision conv2d.
 
+    See Also
+    --------
+    conv3d_transpose : Transposed 3D convolution; paired layouts default to 
``NCDHW`` / ``IODHW``.
+
     Returns
     -------
     result : relax.Expr
@@ -512,6 +516,108 @@ def conv2d_transpose(
     )
 
 
+def conv3d_transpose(
+    data: Expr,
+    weight: Expr,
+    strides: int | tuple[int, int, int] = (1, 1, 1),
+    padding: int | tuple[int, ...] = (0, 0, 0),
+    output_padding: int | tuple[int, int, int] = (0, 0, 0),
+    dilation: int | tuple[int, int, int] = (1, 1, 1),
+    groups: int = 1,
+    data_layout: str = "NCDHW",
+    kernel_layout: str = "IODHW",
+    out_layout: str | None = None,
+    out_dtype: str | DataType | None = None,
+) -> Expr:
+    r"""Three dimensional transposed convolution operator.
+
+    This operator is intended to be the gradient operator of conv3d. That means, if
+
+    `out = conv3d(data, weight, strides, padding, dilation)`,
+
+    the gradient w.r.t. data can be calculated as follows:
+
+    `data_grad = conv3d_transpose(out_grad, weight, strides, padding, output_padding, dilation)`,
+
+    where `output_padding` is a parameter used to determine the output shape.
+
+    In the default case, where `data_layout == "NCDHW"` and `kernel_layout == "IODHW"`, `data` has
+    shape `(N, in_channel, in_d, in_h, in_w)`, `weight` has shape
+    `(in_channel, out_channel, weight_d, weight_h, weight_w)`, with `in_channel % groups == 0`.
+    The output shape is `(N, out_channel * groups, out_d, out_h, out_w)`.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    weight : relax.Expr
+        The weight expression.
+
+    strides : Union[int, Tuple[int, int, int]]
+        The strides of convolution. It is required to have length either 1 or 3.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding of convolution on both sides of inputs before convolution.
+        It is required to have length either 1, 3 or 6.
+
+    output_padding : Union[int, Tuple[int, int, int]], optional
+        Used to disambiguate the output shape. It is required to have length either 1 or 3.
+
+    dilation : Union[int, Tuple[int, int, int]]
+        Specifies the dilation rate to be used for dilated convolution.
+        It is required to have length either 1 or 3.
+
+    groups : int
+        Number of groups to split the input into for grouped convolution.
+        The number of input and output channels should be divisible by the number of groups.
+
+    data_layout : str
+        Layout of the input.
+
+    kernel_layout : str
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout.
+
+    out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed precision conv3d_transpose.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    See Also
+    --------
+    conv3d : Forward 3D convolution (default ``OIDHW`` weights vs. ``IODHW`` here).
+    conv2d_transpose : 2D analogue; legalization supports the same TOPI subset (canonical layout, dilation 1).
+    """
+    if isinstance(strides, int):
+        strides = (strides, strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding, padding, padding)
+    if isinstance(output_padding, int):
+        output_padding = (output_padding, output_padding, output_padding)
+
+    return _ffi_api.conv3d_transpose(  # type: ignore
+        data,
+        weight,
+        strides,
+        padding,
+        output_padding,
+        dilation,
+        groups,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
 def pad(
     data: Expr,
     pad_width: list[int] | tuple[int, ...],
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index d85c439d3a..7602af7e58 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -71,6 +71,11 @@ class Conv2DTransposeAttrs(Attrs):
     """Attributes for nn.conv2d_transpose"""
 
 
+@tvm_ffi.register_object("relax.attrs.Conv3DTransposeAttrs")
+class Conv3DTransposeAttrs(Attrs):
+    """Attributes for nn.conv3d_transpose"""
+
+
 @tvm_ffi.register_object("relax.attrs.Pool2DAttrs")
 class Pool2DAttrs(Attrs):
     """Attributes for nn.max_pool2d"""
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 4234aa831e..157ec8b148 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -200,7 +200,7 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
         )
         return call
     dilation = call.attrs.dilation
-    if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1:
+    if len(dilation) != 2 or any(d != 1 for d in dilation):
         logging.info(
             "TOPI conv2d_transpose does not support dilations other than 1, "
             "and thus cannot be legalized by TOPI"
@@ -220,6 +220,42 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
     )
 
 
+@register_legalize("relax.nn.conv3d_transpose")
+def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr:
+    # Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI 
supports the layout/dilation.
+    if call.attrs.out_layout != call.attrs.data_layout:
+        logging.info(
+            "TOPI conv3d_transpose does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+    if call.attrs.data_layout != "NCDHW" or call.attrs.kernel_layout != 
"IODHW":
+        logging.info(
+            "TOPI conv3d_transpose does not support input layout other than 
NCDHW, "
+            "and kernel layout other than IODHW, so cannot be legalized by 
TOPI"
+        )
+        return call
+    dilation = call.attrs.dilation
+    if len(dilation) != 3 or any(d != 1 for d in dilation):
+        logging.info(
+            "TOPI conv3d_transpose does not support dilations other than 1, "
+            "and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.group_conv3d_transpose_ncdhw,
+        call.args[0],
+        call.args[1],
+        strides=call.attrs.strides,
+        padding=call.attrs.padding,
+        out_dtype=call.struct_info.dtype,
+        output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
+        primfunc_name_hint="conv3d_transpose",
+    )
+
+
 @register_legalize("relax.nn.pad")
 def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
     pad_mode = call.attrs.pad_mode
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 81c9fdf313..dcda15d39d 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -37,6 +37,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   Conv3DAttrs::RegisterReflection();
   Conv1DTransposeAttrs::RegisterReflection();
   Conv2DTransposeAttrs::RegisterReflection();
+  Conv3DTransposeAttrs::RegisterReflection();
 }
 
 /* relax.nn.conv1d */
@@ -887,10 +888,6 @@ StructInfo InferStructInfoConv2dTranspose(const Call& 
call, const BlockBuilder&
                      << "Conv2dTranspose expects the output padding less than 
the strides, but the "
                         "output padding is"
                      << attrs->output_padding << " while the strides are" << 
attrs->strides);
-  } else if (!(attrs->output_padding[0] < attrs->strides[0] &&
-               attrs->output_padding[1] < attrs->strides[1])) {
-    // Todo(relax-team): Trust the input padding at this moment, and revisit
-    // this condition with runtime shape check
   }
 
   PrimExpr input_h = data_NCHW_shape[2];
@@ -1009,5 +1006,247 @@ TVM_REGISTER_OP("relax.nn.conv2d_transpose")
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", 
InferMixedPrecisionConv2dTranspose)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.conv3d_transpose */
+
+Expr conv3d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> 
output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String 
data_layout,
+                      ffi::String kernel_layout, ffi::Optional<ffi::String> 
out_layout,
+                      ffi::Optional<DataType> out_dtype) {
+  padding = GetCompletePadding3D(std::move(padding));
+  if (output_padding.size() == 1) {
+    output_padding.push_back(output_padding[0]);
+    output_padding.push_back(output_padding[0]);
+  }
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+    dilation.push_back(dilation[0]);
+  }
+
+  TVM_FFI_ICHECK_GT(groups, 0)
+      << "The number of groups in convolution is expected to be positive. 
However, "
+         "the given number of groups is "
+      << groups;
+  TVM_FFI_ICHECK_EQ(output_padding.size(), 3)
+      << "The input output_padding length is expected to be 3. "
+         "However, the given output_padding is "
+      << output_padding;
+  TVM_FFI_ICHECK_EQ(strides.size(), 3)
+      << "The input strides length is expected to be 3. However, the given 
strides is " << strides;
+  TVM_FFI_ICHECK_EQ(dilation.size(), 3)
+      << "The input dilation length is expected to be 3. However, the given 
dilation is "
+      << dilation;
+
+  auto attrs = ffi::make_object<Conv3DTransposeAttrs>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = data_layout;
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = out_layout.value_or(data_layout);
+  attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void()));
+  const Op& op = Op::Get("relax.nn.conv3d_transpose");
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.nn.conv3d_transpose", conv3d_transpose);
+}
+
+StructInfo InferStructInfoConv3dTranspose(const Call& call, const 
BlockBuilder& ctx) {
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, 
ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo weight_sinfo = input_sinfo[1];
+
+  const auto* attrs = call->attrs.as<Conv3DTransposeAttrs>();
+  auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, 
attrs->data_layout,  //
+                                                     /*tgt_layout=*/"NCDHW",   
      //
+                                                     /*tensor_name=*/"data");
+  auto [weight_layout, weight2IODHW] = CheckTensorLayout(call, ctx, 
attrs->kernel_layout,  //
+                                                         
/*tgt_layout=*/"IODHW",           //
+                                                         
/*tensor_name=*/"kernel");
+  auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx, 
attrs->out_layout,  //
+                                                  /*tgt_layout=*/"NCDHW",      
  //
+                                                  /*tensor_name=*/"output");
+
+  ffi::Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  ffi::Optional<ShapeExpr> weight_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+  DataType out_dtype = attrs->out_dtype.is_void()
+                           ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, 
weight_sinfo)
+                           : attrs->out_dtype;
+  ffi::Optional<VDevice> vdevice =
+      InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
+  if (!data_shape.defined() || !weight_shape.defined()) {
+    return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
+  }
+
+  ffi::Array<PrimExpr> data_NCDHW_shape = 
data2NCDHW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> weight_IODHW_shape = 
weight2IODHW.ForwardShape(weight_shape.value()->values);
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr input_channel_data = data_NCDHW_shape[1];
+  PrimExpr input_channel_kernel = weight_IODHW_shape[0];
+  if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "Conv3dTranspose expects the channel size of the data should equal 
to the input channel "
+           "size of the weight. However, the data channel size is "
+        << input_channel_data << " while the weight input channel size is "
+        << input_channel_kernel);
+  } else if (!analyzer->CanProveEqual(input_channel_data, 
input_channel_kernel)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv3dTranspose expects the number of input channels 
to be divisible by "
+                        "the number of groups. However, the number of input 
channels is "
+                     << input_channel_kernel << " while the number of groups 
is " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel, 
attrs->groups), 0)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (attrs->output_padding[0] >= attrs->strides[0] ||
+      attrs->output_padding[1] >= attrs->strides[1] ||
+      attrs->output_padding[2] >= attrs->strides[2]) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv3dTranspose expects the output padding less than 
the strides, but the "
+                        "output padding is"
+                     << attrs->output_padding << " while the strides are" << 
attrs->strides);
+  }
+
+  PrimExpr input_d = data_NCDHW_shape[2];
+  PrimExpr input_h = data_NCDHW_shape[3];
+  PrimExpr input_w = data_NCDHW_shape[4];
+  PrimExpr kernel_d = weight_IODHW_shape[2];
+  PrimExpr kernel_h = weight_IODHW_shape[3];
+  PrimExpr kernel_w = weight_IODHW_shape[4];
+  PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+  PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+  PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
+
+  std::vector<PrimExpr> out_NCDHW_shape;
+  out_NCDHW_shape.resize(5);
+  out_NCDHW_shape[0] = data_NCDHW_shape[0];
+  out_NCDHW_shape[1] = weight_IODHW_shape[1] * attrs->groups;
+
+  PrimExpr out_d = (input_d - 1) * Integer(attrs->strides[0]) - padding_d +
+                   Integer(attrs->dilation[0]) * (kernel_d - 1) +
+                   Integer(attrs->output_padding[0]) + 1;
+  PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[1]) - padding_h +
+                   Integer(attrs->dilation[1]) * (kernel_h - 1) +
+                   Integer(attrs->output_padding[1]) + 1;
+  PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[2]) - padding_w +
+                   Integer(attrs->dilation[2]) * (kernel_w - 1) +
+                   Integer(attrs->output_padding[2]) + 1;
+  out_NCDHW_shape[2] = analyzer->Simplify(out_d);
+  out_NCDHW_shape[3] = analyzer->Simplify(out_h);
+  out_NCDHW_shape[4] = analyzer->Simplify(out_w);
+
+  ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
+}
+
+InferLayoutOutput InferLayoutConv3dTranspose(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& 
desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv3DTransposeAttrs>();
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, 
call->args[1]);
+  LayoutDecision output_layout;
+  ObjectPtr<Conv3DTransposeAttrs> new_attrs = 
ffi::make_object<Conv3DTransposeAttrs>(*attrs);
+
+  auto it = desired_layouts.find("relax.nn.conv3d_transpose");
+  if (it != desired_layouts.end()) {
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] 
: (*it).second[0];
+
+    Layout input_layout = Layout(attrs->data_layout);
+    Layout kernel_layout = Layout(attrs->kernel_layout);
+    Layout out_layout = Layout(attrs->out_layout);
+
+    if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
+        desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
+        desired_output_layout.ndim_primal() == out_layout.ndim()) {
+      data_layout = TransposeLike(InitialLayout(5), attrs->data_layout, 
desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(5), attrs->kernel_layout, 
desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(5), attrs->out_layout, 
desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : 
(*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, 
Attrs(new_attrs));
+    } else {
+      auto data_si = GetStructInfo(call->args[0]);
+      auto kernel_si = GetStructInfo(call->args[1]);
+      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+      ffi::Optional<ShapeExpr> data_shape =
+          ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+      ffi::Optional<ShapeExpr> kernel_shape =
+          ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+      bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, 
data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, 
desired_weight_layout,
+                                                     
kernel_shape.value()->values);
+
+      if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(5), input_layout, 
desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(5), kernel_layout, 
desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(5), out_layout, 
desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : 
(*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, 
{output_layout}, Attrs(new_attrs));
+      } else {
+        data_layout = LayoutDecision(InitialLayout(5));
+        weight_layout = LayoutDecision(InitialLayout(5));
+      }
+    }
+  }
+
+  output_layout = data_layout;
+  new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(5), 
data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(5), 
weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(5), 
output_layout->layout).name();
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, 
Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType& 
out_dtype) {
+  const auto* conv3d_transpose_attrs = call->attrs.as<Conv3DTransposeAttrs>();
+  return Downcast<Call>(
+      conv3d_transpose(call->args[0], call->args[1], 
conv3d_transpose_attrs->strides,
+                       conv3d_transpose_attrs->padding, 
conv3d_transpose_attrs->output_padding,
+                       conv3d_transpose_attrs->dilation, 
conv3d_transpose_attrs->groups,
+                       conv3d_transpose_attrs->data_layout, 
conv3d_transpose_attrs->kernel_layout,
+                       conv3d_transpose_attrs->out_layout, out_dtype));
+}
+
+TVM_REGISTER_OP("relax.nn.conv3d_transpose")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_attrs_type<Conv3DTransposeAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoConv3dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", 
InferLayoutConv3dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", 
InferMixedPrecisionConv3dTranspose)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index a5704d3f70..b08eb8a83f 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -95,6 +95,18 @@ Expr conv2d_transpose(Expr data, Expr weight, 
ffi::Array<int64_t> strides,
                       ffi::String kernel_layout, ffi::Optional<ffi::String> 
out_layout,
                       ffi::Optional<DataType> out_dtype);
 
+/*!
+ * \brief Three dimensional transposed convolution operator.
+ *
+ * This operator is intended to be the backward operator of conv3d. It can be 
used to calculate the
+ * gradient of the result of conv3d w.r.t. the input of conv3d.
+ */
+Expr conv3d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+                      ffi::Array<int64_t> padding, ffi::Array<int64_t> 
output_padding,
+                      ffi::Array<int64_t> dilation, int groups, ffi::String 
data_layout,
+                      ffi::String kernel_layout, ffi::Optional<ffi::String> 
out_layout,
+                      ffi::Optional<DataType> out_dtype);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index e2067bad23..c6b4df6aaa 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1223,7 +1223,11 @@ def test_conv_transpose(stride: int, dilation: int, pad: 
int, bias: bool, output
     def _verify_conv_transpose(input_shape, weight_shape):
         nd = len(weight_shape) - 2
         output_shape = [input_shape[0], weight_shape[0]] + [
-            (input_shape[i] - 1) * stride - 2 * pad + dilation * 
(weight_shape[i] - 1) + 1
+            (input_shape[i] - 1) * stride
+            - 2 * pad
+            + dilation * (weight_shape[i] - 1)
+            + output_pad
+            + 1
             for i in range(2, len(input_shape))
         ]
         bias_shape = [output_shape[1]]
@@ -1257,6 +1261,9 @@ def test_conv_transpose(stride: int, dilation: int, pad: 
int, bias: bool, output
     # ConvTranspose2D
     _verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3])
     _verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3])  # group=2
+    # ConvTranspose3D
+    _verify_conv_transpose([3, 4, 12, 12, 12], [4, 4, 3, 3, 3])
+    _verify_conv_transpose([3, 4, 12, 12, 12], [4, 2, 3, 3, 3])  # group=2
 
 
 def test_pow():
diff --git a/tests/python/relax/test_op_nn_convolution.py 
b/tests/python/relax/test_op_nn_convolution.py
index 07b9469abf..bf0abb09b0 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -41,6 +41,8 @@ def test_conv3d_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
     w = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
     assert relax.op.nn.conv3d(x, w).op == Op.get("relax.nn.conv3d")
+    wt = relax.Var("wt", R.Tensor((3, 4, 3, 3, 3), "float32"))
+    assert relax.op.nn.conv3d_transpose(x, wt).op == 
Op.get("relax.nn.conv3d_transpose")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -1609,6 +1611,78 @@ def test_conv2d_transpose_infer_struct_info_mixed_precision():
     )
 
 
+def test_conv3d_transpose_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x0, w0),
+        relax.TensorStructInfo((2, 4, 30, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x0, w0, padding=1),
+        relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x0, w0, strides=2, output_padding=1),
+        relax.TensorStructInfo((2, 4, 58, 58, 58), "float32"),
+    )
+
+
+def test_conv3d_transpose_infer_struct_info_ndhwc_out_layout():
+    bb = relax.BlockBuilder()
+    x_ndhwc = relax.Var("x_nd", R.Tensor((2, 28, 28, 28, 3), "float32"))
+    x_ncdhw = relax.Var("x_nc", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x_ndhwc, w0, data_layout="NDHWC"),
+        relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"),
+    )
+    # Default data_layout is NCDHW; use NCDHW-shaped input when only out_layout is NDHWC.
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x_ncdhw, w0, out_layout="NDHWC"),
+        relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"),
+    )
+
+
+def test_conv3d_transpose_infer_struct_info_groups():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 128, 28, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((128, 16, 3, 3, 3), "float32"))
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d_transpose(x0, w0, groups=8),
+        relax.TensorStructInfo((2, 128, 30, 30, 30), "float32"),
+    )
+
+
+def test_conv3d_transpose_wrong_output_padding():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv3d_transpose(x0, w0, strides=2, output_padding=2))
+    with pytest.raises(TVMError):
+        bb.normalize(
+            relax.op.nn.conv3d_transpose(
+                x0, w0, strides=(2, 2, 2), output_padding=(2, 2, 2)
+            )
+        )
+
+
+def test_conv3d_transpose_unequal_input_channel():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((4, 4, 3, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv3d_transpose(x0, w0))
+
+
 def test_conv3d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 53a4fa7b1c..603da2b48c 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -489,6 +489,114 @@ def test_conv2d_transpose():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_conv3d_transpose():
+    # fmt: off
+    @tvm.script.ir_module
+    class Conv3dTranspose:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4, 3, 3, 3), "float32")):
+            gv = R.nn.conv3d_transpose(x, w)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w: R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6), dtype="float32"):
+            gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def conv3d_transpose(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)), "float32"), w: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(6), T.int64(6), T.int64(6)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            data_dilate = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)))
+            data_pad = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(8), T.int64(8), T.int64(8)))
+            kernel_transform = T.sblock_alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3), T.int64(3)))
+            for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)):
+                with T.sblock("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(x[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4] = x[v_i0, v_i1, v_i2, v_i3, v_i4]
+            for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3), T.int64(8), T.int64(8), T.int64(8)):
+                with T.sblock("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 - T.int64(2)])
+                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(6) and T.int64(2) <= v_i3 and v_i3 < T.int64(6) and T.int64(2) <= v_i4 and v_i4 < T.int64(6), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 - T.int64(2)], T.float32(0.0))
+            for o, i, d, h, w_1 in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3), T.int64(3)):
+                with T.sblock("kernel_transform"):
+                    v_o, v_i, v_d, v_h, v_w = T.axis.remap("SSSSS", [o, i, d, h, w_1])
+                    T.reads(w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w])
+                    T.writes(kernel_transform[v_o, v_i, v_d, v_h, v_w])
+                    kernel_transform[v_o, v_i, v_d, v_h, v_w] = w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w]
+            for b, c, d, h, w_1, dc, dd, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(6), T.int64(6), T.int64(6), T.int64(3), T.int64(3), T.int64(3), T.int64(3)):
+                with T.sblock("compute"):
+                    v_b, v_c, v_d, v_h, v_w, v_dc, v_dd, v_dh, v_dw = T.axis.remap("SSSSSRRRR", [b, c, d, h, w_1, dc, dd, dh, dw])
+                    T.reads(data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+                    T.writes(compute[v_b, v_c, v_d, v_h, v_w])
+                    with T.init():
+                        compute[v_b, v_c, v_d, v_h, v_w] = T.float32(0.0)
+                    compute[v_b, v_c, v_d, v_h, v_w] = compute[v_b, v_c, v_d, v_h, v_w] + data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw]
+    # fmt: on
+
+    mod = LegalizeOps()(Conv3dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv3d_transpose_with_out_dtype():
+    # fmt: off
+    @tvm.script.ir_module
+    class Conv3dTranspose:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4, 3, 3, 3), "float32")):
+            gv = R.nn.conv3d_transpose(x, w, out_dtype="float16")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w: R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6), dtype="float16"):
+            gv = R.call_tir(Expected.conv3d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float16"))
+            return gv
+
+        @T.prim_func(private=True)
+        def conv3d_transpose(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)), "float32"), w: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(6), T.int64(6), T.int64(6)), "float16")):
+            T.func_attr({"tirx.noalias": True})
+            data_dilate = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)))
+            data_pad = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(8), T.int64(8), T.int64(8)))
+            kernel_transform = T.sblock_alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3), T.int64(3)))
+            for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(4), T.int64(4)):
+                with T.sblock("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(x[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4] = x[v_i0, v_i1, v_i2, v_i3, v_i4]
+            for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3), T.int64(8), T.int64(8), T.int64(8)):
+                with T.sblock("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 - T.int64(2)])
+                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
+                    data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(6) and T.int64(2) <= v_i3 and v_i3 < T.int64(6) and T.int64(2) <= v_i4 and v_i4 < T.int64(6), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 - T.int64(2)], T.float32(0.0))
+            for o, i, d, h, w_1 in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3), T.int64(3)):
+                with T.sblock("kernel_transform"):
+                    v_o, v_i, v_d, v_h, v_w = T.axis.remap("SSSSS", [o, i, d, h, w_1])
+                    T.reads(w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w])
+                    T.writes(kernel_transform[v_o, v_i, v_d, v_h, v_w])
+                    kernel_transform[v_o, v_i, v_d, v_h, v_w] = w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w]
+            for b, c, d, h, w_1, dc, dd, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(6), T.int64(6), T.int64(6), T.int64(3), T.int64(3), T.int64(3), T.int64(3)):
+                with T.sblock("compute"):
+                    v_b, v_c, v_d, v_h, v_w, v_dc, v_dd, v_dh, v_dw = T.axis.remap("SSSSSRRRR", [b, c, d, h, w_1, dc, dd, dh, dw])
+                    T.reads(data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+                    T.writes(compute[v_b, v_c, v_d, v_h, v_w])
+                    with T.init():
+                        compute[v_b, v_c, v_d, v_h, v_w] = T.float16(0.0)
+                    compute[v_b, v_c, v_d, v_h, v_w] = compute[v_b, v_c, v_d, v_h, v_w] + T.Cast("float16", data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w + v_dw]) * T.Cast("float16", kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+    # fmt: on
+
+    mod = LegalizeOps()(Conv3dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_conv2d_transpose_with_out_dtype():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index b9aeb9efe5..88707ff98d 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -108,6 +108,44 @@ def test_conv2d_transpose():
     _check(foo, bb.get()["foo"])
 
 
+def test_conv3d():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 8, 8, 8), "float16"), w: R.Tensor((6, 3, 3, 3, 3), "float16")
+    ) -> R.Tensor((2, 6, 6, 6, 6), "float16"):
+        gv: R.Tensor((2, 6, 6, 6, 6), "float16") = R.nn.conv3d(x, w, out_dtype="float16")
+        return gv
+
+    x = relax.Var("x", R.Tensor([2, 3, 8, 8, 8], "float16"))
+    w = relax.Var("w", R.Tensor([6, 3, 3, 3, 3], "float16"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, w]):
+        gv = bb.emit(relax.op.nn.conv3d(x, w, out_dtype="float16"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_conv3d_transpose():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 8, 8, 8), "float16"), w: R.Tensor((3, 6, 3, 3, 3), "float16")
+    ) -> R.Tensor((2, 6, 10, 10, 10), "float16"):
+        gv: R.Tensor((2, 6, 10, 10, 10), "float16") = R.nn.conv3d_transpose(
+            x, w, out_dtype="float16"
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor([2, 3, 8, 8, 8], "float16"))
+    w = relax.Var("w", R.Tensor([3, 6, 3, 3, 3], "float16"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, w]):
+        gv = bb.emit(relax.op.nn.conv3d_transpose(x, w, out_dtype="float16"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_max_pool2d():
     @R.function
     def foo(x: R.Tensor((1, 1, 32, 32), dtype="float32")) -> R.Tensor(


Reply via email to