This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4de1f11344 [Relax] Add conv3d_transpose and ONNX ConvTranspose 3D
support (#18948)
4de1f11344 is described below
commit 4de1f11344608d2305891c7fc585bc4f089158eb
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Mar 29 17:42:35 2026 +0800
[Relax] Add conv3d_transpose and ONNX ConvTranspose 3D support (#18948)
Introduce relax.nn.conv3d_transpose (attrs, C++ inference/layout, Python
API) and lower it to TOPI group_conv3d_transpose_ncdhw when using
NCDHW/IODHW with dilation 1, matching the conv2d_transpose legalization
policy.
Wire the Relax ONNX frontend to emit conv3d_transpose for 5D inputs.
Extend tests for ONNX, struct info, LegalizeOps, and TVMScript
round-trip; fix ConvTranspose test output spatial size to include
output_padding. https://github.com/apache/tvm/issues/18945
---
include/tvm/relax/attrs/nn.h | 52 +++++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 4 +-
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 106 +++++++++
python/tvm/relax/op/op_attrs.py | 5 +
python/tvm/relax/transform/legalize_ops/nn.py | 38 +++-
src/relax/op/nn/convolution.cc | 247 ++++++++++++++++++++-
src/relax/op/nn/convolution.h | 12 +
tests/python/relax/test_frontend_onnx.py | 9 +-
tests/python/relax/test_op_nn_convolution.py | 74 ++++++
.../python/relax/test_transform_legalize_ops_nn.py | 108 +++++++++
tests/python/relax/test_tvmscript_parser_op_nn.py | 38 ++++
12 files changed, 687 insertions(+), 7 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 2a2ac5fe07..5c1931c3ee 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -267,6 +267,58 @@ struct Conv2DTransposeAttrs : public
AttrsNodeReflAdapter<Conv2DTransposeAttrs>
BaseAttrsNode);
}; // struct Conv2DTransposeAttrs
+/*! \brief Attributes used in Conv3dTranspose operator */
+struct Conv3DTransposeAttrs : public
AttrsNodeReflAdapter<Conv3DTransposeAttrs> {
+ ffi::Array<int64_t> strides;
+ ffi::Array<int64_t> padding;
+ ffi::Array<int64_t> output_padding;
+ ffi::Array<int64_t> dilation;
+ int groups;
+ ffi::String data_layout;
+ ffi::String kernel_layout;
+ ffi::String out_layout;
+ DataType out_dtype;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<Conv3DTransposeAttrs>()
+ .def_ro("strides", &Conv3DTransposeAttrs::strides,
+ "Specifies the strides of the convolution.")
+ .def_ro("padding", &Conv3DTransposeAttrs::padding,
+ "If padding is non-zero, then the input is implicitly
zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back/bottom/right will use same padding as
front/top/left"
+ "six int : padding width in the order of (front, top, left,
back, bottom, right)")
+ .def_ro("output_padding", &Conv3DTransposeAttrs::output_padding,
+ "Used to disambiguate the output shape.")
+ .def_ro("dilation", &Conv3DTransposeAttrs::dilation,
+ "Specifies the dilation rate to use for dilated convolution.")
+ .def_ro("groups", &Conv3DTransposeAttrs::groups,
+ "Number of groups to split the input into for grouped
convolution. The number of "
+ "input and "
+ "output channels should be divisible by the number of groups.")
+ .def_ro("data_layout", &Conv3DTransposeAttrs::data_layout,
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC',
etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth,
height, and width"
+ "dimensions respectively. Convolution is applied on the 'D',
'H', and"
+ "'W' dimensions.")
+ .def_ro("kernel_layout", &Conv3DTransposeAttrs::kernel_layout,
+ "Dimension ordering of weight. Can be 'IODHW', etc."
+ "'I', 'O', 'D', 'H', 'W' stands for input_channel,
output_channel, depth, height, and "
+ "width"
+ "dimensions respectively.")
+ .def_ro("out_layout", &Conv3DTransposeAttrs::out_layout,
+ "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth,
height, and width"
+ "dimensions respectively. Default to be same as input layout.")
+ .def_ro("out_dtype", &Conv3DTransposeAttrs::out_dtype,
+ "Output data type, set to explicit type under mixed precision
setting");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DTransposeAttrs",
Conv3DTransposeAttrs,
+ BaseAttrsNode);
+}; // struct Conv3DTransposeAttrs
+
/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
ffi::Array<int64_t> pool_size;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e56f975c62..74c8bfe690 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1364,7 +1364,9 @@ class ConvTranspose(OnnxOpConverter):
data_layout = "NCHW"
kernel_layout = "IOHW"
elif ndim == 5:
- raise NotImplementedError("Relax ConvTranspose3d not supported
yet")
+ op = relax.op.nn.conv3d_transpose
+ data_layout = "NCDHW"
+ kernel_layout = "IODHW"
else:
raise NotImplementedError("Ndim > 5 not supported for
convolution.")
diff --git a/python/tvm/relax/op/nn/__init__.py
b/python/tvm/relax/op/nn/__init__.py
index 00481245d0..57128282b9 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -34,6 +34,7 @@ from .nn import (
conv2d,
conv2d_transpose,
conv3d,
+ conv3d_transpose,
cross_entropy_with_logits,
dropout,
gelu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e30ba550c7..c31a974402 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -293,6 +293,10 @@ def conv3d(
out_dtype : Optional[Union[str, DataType]]
Specifies the output data type for mixed precision conv2d.
+ See Also
+ --------
+ conv3d_transpose : Transposed 3D convolution; paired layouts default to
``NCDHW`` / ``IODHW``.
+
Returns
-------
result : relax.Expr
@@ -512,6 +516,108 @@ def conv2d_transpose(
)
+def conv3d_transpose(
+ data: Expr,
+ weight: Expr,
+ strides: int | tuple[int, int, int] = (1, 1, 1),
+ padding: int | tuple[int, ...] = (0, 0, 0),
+ output_padding: int | tuple[int, int, int] = (0, 0, 0),
+ dilation: int | tuple[int, int, int] = (1, 1, 1),
+ groups: int = 1,
+ data_layout: str = "NCDHW",
+ kernel_layout: str = "IODHW",
+ out_layout: str | None = None,
+ out_dtype: str | DataType | None = None,
+) -> Expr:
+ r"""Three dimensional transposed convolution operator.
+
+ This operator is intended to be the gradient operator of conv3d. That
means, if
+
+ `out = conv3d(data, weight, strides, padding, dilation)`,
+
+ The gradient w.r.t. data can be calculated as follows:
+
+ `data_grad = conv3d_transpose(out_grad, weight, strides, padding,
output_padding, dilation)`,
+
+ where `output_padding` is a parameter used to determine the output shape.
+
+ In the default case, where `data_layout == "NCDHW"` and `kernel_layout ==
"IODHW"`, `data` has
+ shape `(N, in_channel, in_d, in_h, in_w)`, `weight` has shape
+ `(in_channel, out_channel, weight_d, weight_h, weight_w)`, with
`in_channel % groups == 0`.
+ The output shape is `(N, out_channel * groups, out_d, out_h, out_w)`.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ weight : relax.Expr
+ The weight expressions.
+
+ strides : Union[int, Tuple[int, int, int]]
+ The strides of convolution. It is required to have length either 1 or
3.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding of convolution on both sides of inputs before convolution.
+ It is required to have length either 1, 3 or 6.
+
+ output_padding : Union[int, Tuple[int, ...]], optional
+ Used to disambiguate the output shape.
+
+ dilation : Union[int, Tuple[int, int, int]]
+ Specifies the dilation rate to be used for dilated convolution.
+ It is required to have length either 1 or 3.
+
+ groups : int
+ Number of groups to split the input into for grouped convolution.
+ The number of input and output channels should be divisible by the
number of groups.
+
+ data_layout : str
+ Layout of the input.
+
+ kernel_layout : str
+ Layout of the weight.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ out_dtype : Optional[Union[str, DataType]]
+ Specifies the output data type for mixed precision conv3d_transpose.
+
+ See Also
+ --------
+ conv3d : Forward 3D convolution (default ``OIDHW`` weights vs. ``IODHW``
here).
+ conv2d_transpose : 2D analogue; legalization supports the same TOPI subset
(canonical layout, dilation 1).
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(strides, int):
+ strides = (strides, strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding, padding, padding)
+ if isinstance(output_padding, int):
+ output_padding = (output_padding, output_padding, output_padding)
+
+ return _ffi_api.conv3d_transpose( # type: ignore
+ data,
+ weight,
+ strides,
+ padding,
+ output_padding,
+ dilation,
+ groups,
+ data_layout,
+ kernel_layout,
+ out_layout,
+ out_dtype,
+ )
+
+
def pad(
data: Expr,
pad_width: list[int] | tuple[int, ...],
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index d85c439d3a..7602af7e58 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -71,6 +71,11 @@ class Conv2DTransposeAttrs(Attrs):
"""Attributes for nn.conv2d_transpose"""
+@tvm_ffi.register_object("relax.attrs.Conv3DTransposeAttrs")
+class Conv3DTransposeAttrs(Attrs):
+ """Attributes for nn.conv3d_transpose"""
+
+
@tvm_ffi.register_object("relax.attrs.Pool2DAttrs")
class Pool2DAttrs(Attrs):
"""Attributes for nn.max_pool2d"""
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index 4234aa831e..157ec8b148 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -200,7 +200,7 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
)
return call
dilation = call.attrs.dilation
- if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1:
+ if len(dilation) != 2 or any(d != 1 for d in dilation):
logging.info(
"TOPI conv2d_transpose does not support dilations other than 1, "
"and thus cannot be legalized by TOPI"
@@ -220,6 +220,42 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) ->
Expr:
)
+@register_legalize("relax.nn.conv3d_transpose")
+def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr:
+ # Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI
supports the layout/dilation.
+ if call.attrs.out_layout != call.attrs.data_layout:
+ logging.info(
+ "TOPI conv3d_transpose does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+ if call.attrs.data_layout != "NCDHW" or call.attrs.kernel_layout !=
"IODHW":
+ logging.info(
+ "TOPI conv3d_transpose does not support input layout other than
NCDHW, "
+ "and kernel layout other than IODHW, so cannot be legalized by
TOPI"
+ )
+ return call
+ dilation = call.attrs.dilation
+ if len(dilation) != 3 or any(d != 1 for d in dilation):
+ logging.info(
+ "TOPI conv3d_transpose does not support dilations other than 1, "
+ "and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.group_conv3d_transpose_ncdhw,
+ call.args[0],
+ call.args[1],
+ strides=call.attrs.strides,
+ padding=call.attrs.padding,
+ out_dtype=call.struct_info.dtype,
+ output_padding=call.attrs.output_padding,
+ groups=call.attrs.groups,
+ primfunc_name_hint="conv3d_transpose",
+ )
+
+
@register_legalize("relax.nn.pad")
def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
pad_mode = call.attrs.pad_mode
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 81c9fdf313..dcda15d39d 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -37,6 +37,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
Conv3DAttrs::RegisterReflection();
Conv1DTransposeAttrs::RegisterReflection();
Conv2DTransposeAttrs::RegisterReflection();
+ Conv3DTransposeAttrs::RegisterReflection();
}
/* relax.nn.conv1d */
@@ -887,10 +888,6 @@ StructInfo InferStructInfoConv2dTranspose(const Call&
call, const BlockBuilder&
<< "Conv2dTranspose expects the output padding less than
the strides, but the "
"output padding is"
<< attrs->output_padding << " while the strides are" <<
attrs->strides);
- } else if (!(attrs->output_padding[0] < attrs->strides[0] &&
- attrs->output_padding[1] < attrs->strides[1])) {
- // Todo(relax-team): Trust the input padding at this moment, and revisit
- // this condition with runtime shape check
}
PrimExpr input_h = data_NCHW_shape[2];
@@ -1009,5 +1006,247 @@ TVM_REGISTER_OP("relax.nn.conv2d_transpose")
.set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv2dTranspose)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.conv3d_transpose */
+
+Expr conv3d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
+ ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
+ ffi::Optional<DataType> out_dtype) {
+ padding = GetCompletePadding3D(std::move(padding));
+ if (output_padding.size() == 1) {
+ output_padding.push_back(output_padding[0]);
+ output_padding.push_back(output_padding[0]);
+ }
+ if (strides.size() == 1) {
+ strides.push_back(strides[0]);
+ strides.push_back(strides[0]);
+ }
+ if (dilation.size() == 1) {
+ dilation.push_back(dilation[0]);
+ dilation.push_back(dilation[0]);
+ }
+
+ TVM_FFI_ICHECK_GT(groups, 0)
+ << "The number of groups in convolution is expected to be positive.
However, "
+ "the given number of groups is "
+ << groups;
+ TVM_FFI_ICHECK_EQ(output_padding.size(), 3)
+ << "The input output_padding length is expected to be 3. "
+ "However, the given output_padding is "
+ << output_padding;
+ TVM_FFI_ICHECK_EQ(strides.size(), 3)
+ << "The input strides length is expected to be 3. However, the given
strides is " << strides;
+ TVM_FFI_ICHECK_EQ(dilation.size(), 3)
+ << "The input dilation length is expected to be 3. However, the given
dilation is "
+ << dilation;
+
+ auto attrs = ffi::make_object<Conv3DTransposeAttrs>();
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->output_padding = std::move(output_padding);
+ attrs->dilation = std::move(dilation);
+ attrs->groups = groups;
+ attrs->data_layout = data_layout;
+ attrs->kernel_layout = std::move(kernel_layout);
+ attrs->out_layout = out_layout.value_or(data_layout);
+ attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void()));
+ const Op& op = Op::Get("relax.nn.conv3d_transpose");
+ return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.nn.conv3d_transpose", conv3d_transpose);
+}
+
+StructInfo InferStructInfoConv3dTranspose(const Call& call, const
BlockBuilder& ctx) {
+ ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call,
ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo weight_sinfo = input_sinfo[1];
+
+ const auto* attrs = call->attrs.as<Conv3DTransposeAttrs>();
+ auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx,
attrs->data_layout, //
+ /*tgt_layout=*/"NCDHW",
//
+ /*tensor_name=*/"data");
+ auto [weight_layout, weight2IODHW] = CheckTensorLayout(call, ctx,
attrs->kernel_layout, //
+
/*tgt_layout=*/"IODHW", //
+
/*tensor_name=*/"kernel");
+ auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx,
attrs->out_layout, //
+ /*tgt_layout=*/"NCDHW",
//
+ /*tensor_name=*/"output");
+
+ ffi::Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ ffi::Optional<ShapeExpr> weight_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+ DataType out_dtype = attrs->out_dtype.is_void()
+ ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo,
weight_sinfo)
+ : attrs->out_dtype;
+ ffi::Optional<VDevice> vdevice =
+ InferBinaryArithOpOutVDevice(call, ctx, data_sinfo, weight_sinfo);
+ if (!data_shape.defined() || !weight_shape.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
+ }
+
+ ffi::Array<PrimExpr> data_NCDHW_shape =
data2NCDHW.ForwardShape(data_shape.value()->values);
+ ffi::Array<PrimExpr> weight_IODHW_shape =
weight2IODHW.ForwardShape(weight_shape.value()->values);
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr input_channel_data = data_NCDHW_shape[1];
+ PrimExpr input_channel_kernel = weight_IODHW_shape[0];
+ if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "Conv3dTranspose expects the channel size of the data should equal
to the input channel "
+ "size of the weight. However, the data channel size is "
+ << input_channel_data << " while the weight input channel size is "
+ << input_channel_kernel);
+ } else if (!analyzer->CanProveEqual(input_channel_data,
input_channel_kernel)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv3dTranspose expects the number of input channels
to be divisible by "
+ "the number of groups. However, the number of input
channels is "
+ << input_channel_kernel << " while the number of groups
is " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel,
attrs->groups), 0)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (attrs->output_padding[0] >= attrs->strides[0] ||
+ attrs->output_padding[1] >= attrs->strides[1] ||
+ attrs->output_padding[2] >= attrs->strides[2]) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv3dTranspose expects the output padding less than
the strides, but the "
+ "output padding is"
+ << attrs->output_padding << " while the strides are" <<
attrs->strides);
+ }
+
+ PrimExpr input_d = data_NCDHW_shape[2];
+ PrimExpr input_h = data_NCDHW_shape[3];
+ PrimExpr input_w = data_NCDHW_shape[4];
+ PrimExpr kernel_d = weight_IODHW_shape[2];
+ PrimExpr kernel_h = weight_IODHW_shape[3];
+ PrimExpr kernel_w = weight_IODHW_shape[4];
+ PrimExpr padding_d = Integer(attrs->padding[0]) + Integer(attrs->padding[3]);
+ PrimExpr padding_h = Integer(attrs->padding[1]) + Integer(attrs->padding[4]);
+ PrimExpr padding_w = Integer(attrs->padding[2]) + Integer(attrs->padding[5]);
+
+ std::vector<PrimExpr> out_NCDHW_shape;
+ out_NCDHW_shape.resize(5);
+ out_NCDHW_shape[0] = data_NCDHW_shape[0];
+ out_NCDHW_shape[1] = weight_IODHW_shape[1] * attrs->groups;
+
+ PrimExpr out_d = (input_d - 1) * Integer(attrs->strides[0]) - padding_d +
+ Integer(attrs->dilation[0]) * (kernel_d - 1) +
+ Integer(attrs->output_padding[0]) + 1;
+ PrimExpr out_h = (input_h - 1) * Integer(attrs->strides[1]) - padding_h +
+ Integer(attrs->dilation[1]) * (kernel_h - 1) +
+ Integer(attrs->output_padding[1]) + 1;
+ PrimExpr out_w = (input_w - 1) * Integer(attrs->strides[2]) - padding_w +
+ Integer(attrs->dilation[2]) * (kernel_w - 1) +
+ Integer(attrs->output_padding[2]) + 1;
+ out_NCDHW_shape[2] = analyzer->Simplify(out_d);
+ out_NCDHW_shape[3] = analyzer->Simplify(out_h);
+ out_NCDHW_shape[4] = analyzer->Simplify(out_w);
+
+ ffi::Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
+}
+
+InferLayoutOutput InferLayoutConv3dTranspose(
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ const auto* attrs = call->attrs.as<Conv3DTransposeAttrs>();
+ LayoutDecision data_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ LayoutDecision weight_layout = GetLayoutDecision(var_layout_map,
call->args[1]);
+ LayoutDecision output_layout;
+ ObjectPtr<Conv3DTransposeAttrs> new_attrs =
ffi::make_object<Conv3DTransposeAttrs>(*attrs);
+
+ auto it = desired_layouts.find("relax.nn.conv3d_transpose");
+ if (it != desired_layouts.end()) {
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+ Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2]
: (*it).second[0];
+
+ Layout input_layout = Layout(attrs->data_layout);
+ Layout kernel_layout = Layout(attrs->kernel_layout);
+ Layout out_layout = Layout(attrs->out_layout);
+
+ if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
+ desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
+ desired_output_layout.ndim_primal() == out_layout.ndim()) {
+ data_layout = TransposeLike(InitialLayout(5), attrs->data_layout,
desired_data_layout);
+ weight_layout = TransposeLike(InitialLayout(5), attrs->kernel_layout,
desired_weight_layout);
+ output_layout = TransposeLike(InitialLayout(5), attrs->out_layout,
desired_output_layout);
+ new_attrs->data_layout = (*it).second[0];
+ new_attrs->kernel_layout = (*it).second[1];
+ new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] :
(*it).second[0];
+ return InferLayoutOutput({data_layout, weight_layout}, {output_layout},
Attrs(new_attrs));
+ } else {
+ auto data_si = GetStructInfo(call->args[0]);
+ auto kernel_si = GetStructInfo(call->args[1]);
+ TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+ TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+ ffi::Optional<ShapeExpr> data_shape =
+ ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+ ffi::Optional<ShapeExpr> kernel_shape =
+ ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+ bool can_data_proved =
+ CanProveLayoutTransform(input_layout, desired_data_layout,
data_shape.value()->values);
+ bool can_kernel_proved = CanProveLayoutTransform(kernel_layout,
desired_weight_layout,
+
kernel_shape.value()->values);
+
+ if (can_data_proved && can_kernel_proved) {
+ data_layout = TransposeSubLayoutLike(InitialLayout(5), input_layout,
desired_data_layout);
+ weight_layout =
+ TransposeSubLayoutLike(InitialLayout(5), kernel_layout,
desired_weight_layout);
+ output_layout = TransposeSubLayoutLike(InitialLayout(5), out_layout,
desired_output_layout);
+ new_attrs->data_layout = (*it).second[0];
+ new_attrs->kernel_layout = (*it).second[1];
+ new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] :
(*it).second[0];
+ return InferLayoutOutput({data_layout, weight_layout},
{output_layout}, Attrs(new_attrs));
+ } else {
+ data_layout = LayoutDecision(InitialLayout(5));
+ weight_layout = LayoutDecision(InitialLayout(5));
+ }
+ }
+ }
+
+ output_layout = data_layout;
+ new_attrs->data_layout =
+ TransposeLike(attrs->data_layout, InitialLayout(5),
data_layout->layout).name();
+ new_attrs->kernel_layout =
+ TransposeLike(attrs->kernel_layout, InitialLayout(5),
weight_layout->layout).name();
+ new_attrs->out_layout =
+ TransposeLike(attrs->out_layout, InitialLayout(5),
output_layout->layout).name();
+ return InferLayoutOutput({data_layout, weight_layout}, {output_layout},
Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType&
out_dtype) {
+ const auto* conv3d_transpose_attrs = call->attrs.as<Conv3DTransposeAttrs>();
+ return Downcast<Call>(
+ conv3d_transpose(call->args[0], call->args[1],
conv3d_transpose_attrs->strides,
+ conv3d_transpose_attrs->padding,
conv3d_transpose_attrs->output_padding,
+ conv3d_transpose_attrs->dilation,
conv3d_transpose_attrs->groups,
+ conv3d_transpose_attrs->data_layout,
conv3d_transpose_attrs->kernel_layout,
+ conv3d_transpose_attrs->out_layout, out_dtype));
+}
+
+TVM_REGISTER_OP("relax.nn.conv3d_transpose")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_attrs_type<Conv3DTransposeAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoConv3dTranspose)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutConv3dTranspose)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kAlways)
+ .set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv3dTranspose)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index a5704d3f70..b08eb8a83f 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -95,6 +95,18 @@ Expr conv2d_transpose(Expr data, Expr weight,
ffi::Array<int64_t> strides,
ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
ffi::Optional<DataType> out_dtype);
+/*!
+ * \brief Three dimensional transposed convolution operator.
+ *
+ * This operator is intended to be the backward operator of conv3d. It can be
used to calculate the
+ * gradient of the result of conv3d w.r.t. the input of conv3d.
+ */
+Expr conv3d_transpose(Expr data, Expr weight, ffi::Array<int64_t> strides,
+ ffi::Array<int64_t> padding, ffi::Array<int64_t>
output_padding,
+ ffi::Array<int64_t> dilation, int groups, ffi::String
data_layout,
+ ffi::String kernel_layout, ffi::Optional<ffi::String>
out_layout,
+ ffi::Optional<DataType> out_dtype);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index e2067bad23..c6b4df6aaa 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1223,7 +1223,11 @@ def test_conv_transpose(stride: int, dilation: int, pad:
int, bias: bool, output
def _verify_conv_transpose(input_shape, weight_shape):
nd = len(weight_shape) - 2
output_shape = [input_shape[0], weight_shape[0]] + [
- (input_shape[i] - 1) * stride - 2 * pad + dilation *
(weight_shape[i] - 1) + 1
+ (input_shape[i] - 1) * stride
+ - 2 * pad
+ + dilation * (weight_shape[i] - 1)
+ + output_pad
+ + 1
for i in range(2, len(input_shape))
]
bias_shape = [output_shape[1]]
@@ -1257,6 +1261,9 @@ def test_conv_transpose(stride: int, dilation: int, pad:
int, bias: bool, output
# ConvTranspose2D
_verify_conv_transpose([3, 4, 32, 32], [4, 4, 3, 3])
_verify_conv_transpose([3, 4, 32, 32], [4, 2, 3, 3]) # group=2
+ # ConvTranspose3D
+ _verify_conv_transpose([3, 4, 12, 12, 12], [4, 4, 3, 3, 3])
+ _verify_conv_transpose([3, 4, 12, 12, 12], [4, 2, 3, 3, 3]) # group=2
def test_pow():
diff --git a/tests/python/relax/test_op_nn_convolution.py
b/tests/python/relax/test_op_nn_convolution.py
index 07b9469abf..bf0abb09b0 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -41,6 +41,8 @@ def test_conv3d_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
assert relax.op.nn.conv3d(x, w).op == Op.get("relax.nn.conv3d")
+ wt = relax.Var("wt", R.Tensor((3, 4, 3, 3, 3), "float32"))
+ assert relax.op.nn.conv3d_transpose(x, wt).op ==
Op.get("relax.nn.conv3d_transpose")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -1609,6 +1611,78 @@ def
test_conv2d_transpose_infer_struct_info_mixed_precision():
)
+def test_conv3d_transpose_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x0, w0),
+ relax.TensorStructInfo((2, 4, 30, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x0, w0, padding=1),
+ relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x0, w0, strides=2, output_padding=1),
+ relax.TensorStructInfo((2, 4, 58, 58, 58), "float32"),
+ )
+
+
+def test_conv3d_transpose_infer_struct_info_ndhwc_out_layout():
+ bb = relax.BlockBuilder()
+ x_ndhwc = relax.Var("x_nd", R.Tensor((2, 28, 28, 28, 3), "float32"))
+ x_ncdhw = relax.Var("x_nc", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x_ndhwc, w0, data_layout="NDHWC"),
+ relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"),
+ )
+ # Default data_layout is NCDHW; use NCDHW-shaped input when only
out_layout is NDHWC.
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x_ncdhw, w0, out_layout="NDHWC"),
+ relax.TensorStructInfo((2, 30, 30, 30, 4), "float32"),
+ )
+
+
+def test_conv3d_transpose_infer_struct_info_groups():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 128, 28, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 16, 3, 3, 3), "float32"))
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d_transpose(x0, w0, groups=8),
+ relax.TensorStructInfo((2, 128, 30, 30, 30), "float32"),
+ )
+
+
+def test_conv3d_transpose_wrong_output_padding():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv3d_transpose(x0, w0, strides=2,
output_padding=2))
+ with pytest.raises(TVMError):
+ bb.normalize(
+ relax.op.nn.conv3d_transpose(
+ x0, w0, strides=(2, 2, 2), output_padding=(2, 2, 2)
+ )
+ )
+
+
+def test_conv3d_transpose_unequal_input_channel():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((4, 4, 3, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv3d_transpose(x0, w0))
+
+
def test_conv3d_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 53a4fa7b1c..603da2b48c 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -489,6 +489,114 @@ def test_conv2d_transpose():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_conv3d_transpose():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv3dTranspose:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4,
3, 3, 3), "float32")):
+ gv = R.nn.conv3d_transpose(x, w)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w:
R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6),
dtype="float32"):
+ gv = R.call_tir(Expected.conv3d_transpose, (x, w),
out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def conv3d_transpose(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(4), T.int64(4)), "float32"), w: T.Buffer((T.int64(3), T.int64(4),
T.int64(3), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2),
T.int64(4), T.int64(6), T.int64(6), T.int64(6)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ data_dilate = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(4), T.int64(4), T.int64(4)))
+ data_pad = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(8), T.int64(8), T.int64(8)))
+ kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3),
T.int64(3), T.int64(3), T.int64(3)))
+ for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(4), T.int64(4)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0,
i1, i2, i3, i4])
+ T.reads(x[v_i0, v_i1, v_i2, v_i3, v_i4])
+ T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4])
+ data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4] = x[v_i0, v_i1,
v_i2, v_i3, v_i4]
+ for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3),
T.int64(8), T.int64(8), T.int64(8)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0,
i1, i2, i3, i4])
+ T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 -
T.int64(2), v_i4 - T.int64(2)])
+ T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
+ data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] =
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(6) and T.int64(2) <= v_i3
and v_i3 < T.int64(6) and T.int64(2) <= v_i4 and v_i4 < T.int64(6),
data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 -
T.int64(2)], T.float32(0.0))
+ for o, i, d, h, w_1 in T.grid(T.int64(4), T.int64(3), T.int64(3),
T.int64(3), T.int64(3)):
+ with T.block("kernel_transform"):
+ v_o, v_i, v_d, v_h, v_w = T.axis.remap("SSSSS", [o, i, d,
h, w_1])
+ T.reads(w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h,
T.int64(2) - v_w])
+ T.writes(kernel_transform[v_o, v_i, v_d, v_h, v_w])
+ kernel_transform[v_o, v_i, v_d, v_h, v_w] = w[v_i, v_o,
T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w]
+ for b, c, d, h, w_1, dc, dd, dh, dw in T.grid(T.int64(2),
T.int64(4), T.int64(6), T.int64(6), T.int64(6), T.int64(3), T.int64(3),
T.int64(3), T.int64(3)):
+ with T.block("compute"):
+ v_b, v_c, v_d, v_h, v_w, v_dc, v_dd, v_dh, v_dw =
T.axis.remap("SSSSSRRRR", [b, c, d, h, w_1, dc, dd, dh, dw])
+ T.reads(data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w +
v_dw], kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+ T.writes(compute[v_b, v_c, v_d, v_h, v_w])
+ with T.init():
+ compute[v_b, v_c, v_d, v_h, v_w] = T.float32(0.0)
+ compute[v_b, v_c, v_d, v_h, v_w] = compute[v_b, v_c, v_d,
v_h, v_w] + data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w + v_dw] *
kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv3dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv3d_transpose_with_out_dtype():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv3dTranspose:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 4, 4), "float32"), w: R.Tensor((3, 4,
3, 3, 3), "float32")):
+ gv = R.nn.conv3d_transpose(x, w, out_dtype="float16")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 4, 4), dtype="float32"), w:
R.Tensor((3, 4, 3, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 6, 6, 6),
dtype="float16"):
+ gv = R.call_tir(Expected.conv3d_transpose, (x, w),
out_sinfo=R.Tensor((2, 4, 6, 6, 6), dtype="float16"))
+ return gv
+
+ @T.prim_func(private=True)
+ def conv3d_transpose(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(4), T.int64(4)), "float32"), w: T.Buffer((T.int64(3), T.int64(4),
T.int64(3), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2),
T.int64(4), T.int64(6), T.int64(6), T.int64(6)), "float16")):
+ T.func_attr({"tir.noalias": True})
+ data_dilate = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(4), T.int64(4), T.int64(4)))
+ data_pad = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(8), T.int64(8), T.int64(8)))
+ kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3),
T.int64(3), T.int64(3), T.int64(3)))
+ for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(4), T.int64(4)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0,
i1, i2, i3, i4])
+ T.reads(x[v_i0, v_i1, v_i2, v_i3, v_i4])
+ T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4])
+ data_dilate[v_i0, v_i1, v_i2, v_i3, v_i4] = x[v_i0, v_i1,
v_i2, v_i3, v_i4]
+ for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(3),
T.int64(8), T.int64(8), T.int64(8)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0,
i1, i2, i3, i4])
+ T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 -
T.int64(2), v_i4 - T.int64(2)])
+ T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
+ data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] =
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(6) and T.int64(2) <= v_i3
and v_i3 < T.int64(6) and T.int64(2) <= v_i4 and v_i4 < T.int64(6),
data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2), v_i4 -
T.int64(2)], T.float32(0.0))
+ for o, i, d, h, w_1 in T.grid(T.int64(4), T.int64(3), T.int64(3),
T.int64(3), T.int64(3)):
+ with T.block("kernel_transform"):
+ v_o, v_i, v_d, v_h, v_w = T.axis.remap("SSSSS", [o, i, d,
h, w_1])
+ T.reads(w[v_i, v_o, T.int64(2) - v_d, T.int64(2) - v_h,
T.int64(2) - v_w])
+ T.writes(kernel_transform[v_o, v_i, v_d, v_h, v_w])
+ kernel_transform[v_o, v_i, v_d, v_h, v_w] = w[v_i, v_o,
T.int64(2) - v_d, T.int64(2) - v_h, T.int64(2) - v_w]
+ for b, c, d, h, w_1, dc, dd, dh, dw in T.grid(T.int64(2),
T.int64(4), T.int64(6), T.int64(6), T.int64(6), T.int64(3), T.int64(3),
T.int64(3), T.int64(3)):
+ with T.block("compute"):
+ v_b, v_c, v_d, v_h, v_w, v_dc, v_dd, v_dh, v_dw =
T.axis.remap("SSSSSRRRR", [b, c, d, h, w_1, dc, dd, dh, dw])
+ T.reads(data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w +
v_dw], kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+ T.writes(compute[v_b, v_c, v_d, v_h, v_w])
+ with T.init():
+ compute[v_b, v_c, v_d, v_h, v_w] = T.float16(0.0)
+ compute[v_b, v_c, v_d, v_h, v_w] = compute[v_b, v_c, v_d,
v_h, v_w] + T.Cast("float16", data_pad[v_b, v_dc, v_d + v_dd, v_h + v_dh, v_w +
v_dw]) * T.Cast("float16", kernel_transform[v_c, v_dc, v_dd, v_dh, v_dw])
+ # fmt: on
+
+ mod = LegalizeOps()(Conv3dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_conv2d_transpose_with_out_dtype():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py
b/tests/python/relax/test_tvmscript_parser_op_nn.py
index b9aeb9efe5..88707ff98d 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -108,6 +108,44 @@ def test_conv2d_transpose():
_check(foo, bb.get()["foo"])
+def test_conv3d():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3, 8, 8, 8), "float16"), w: R.Tensor((6, 3, 3, 3, 3),
"float16")
+ ) -> R.Tensor((2, 6, 6, 6, 6), "float16"):
+ gv: R.Tensor((2, 6, 6, 6, 6), "float16") = R.nn.conv3d(x, w,
out_dtype="float16")
+ return gv
+
+ x = relax.Var("x", R.Tensor([2, 3, 8, 8, 8], "float16"))
+ w = relax.Var("w", R.Tensor([6, 3, 3, 3, 3], "float16"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, w]):
+ gv = bb.emit(relax.op.nn.conv3d(x, w, out_dtype="float16"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_conv3d_transpose():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3, 8, 8, 8), "float16"), w: R.Tensor((3, 6, 3, 3, 3),
"float16")
+ ) -> R.Tensor((2, 6, 10, 10, 10), "float16"):
+ gv: R.Tensor((2, 6, 10, 10, 10), "float16") = R.nn.conv3d_transpose(
+ x, w, out_dtype="float16"
+ )
+ return gv
+
+ x = relax.Var("x", R.Tensor([2, 3, 8, 8, 8], "float16"))
+ w = relax.Var("w", R.Tensor([3, 6, 3, 3, 3], "float16"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, w]):
+ gv = bb.emit(relax.op.nn.conv3d_transpose(x, w, out_dtype="float16"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_max_pool2d():
@R.function
def foo(x: R.Tensor((1, 1, 32, 32), dtype="float32")) -> R.Tensor(