This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5c87bfe09b [Unity][Relax][Op] Add Conv3D Operator (#16385)
5c87bfe09b is described below
commit 5c87bfe09b274b252c29a61d4a7c742689114886
Author: Josh Fromm <[email protected]>
AuthorDate: Sat Jan 13 18:06:04 2024 -0800
[Unity][Relax][Op] Add Conv3D Operator (#16385)
---
include/tvm/relax/attrs/nn.h | 45 ++++++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 48 +++---
python/tvm/relax/frontend/torch/fx_translator.py | 29 ++++
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 100 ++++++++++++
python/tvm/relax/transform/legalize_ops/nn.py | 41 +++++
src/relax/op/nn/convolution.cc | 176 +++++++++++++++++++++
src/relax/op/nn/convolution.h | 5 +
src/relax/op/op_common.h | 25 +++
tests/python/relax/test_frontend_from_fx.py | 79 ++++++++++
tests/python/relax/test_frontend_onnx.py | 15 +-
tests/python/relax/test_op_nn_convolution.py | 187 +++++++++++++++++++++++
12 files changed, 725 insertions(+), 26 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 424874bd75..dd63a70bc4 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -117,6 +117,51 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
}
}; // struct Conv2dAttrs
+/*! \brief Attributes used in Conv3d operator */
+struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> dilation;
+ int groups;
+ String data_layout;
+ String kernel_layout;
+ String out_layout;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(Conv3DAttrs, "relax.attrs.Conv3DAttrs") {
+ TVM_ATTR_FIELD(strides).describe("Specifies the strides of the
convolution.");
+ TVM_ATTR_FIELD(padding).describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (forward, back, top, left,
bottom, right)");
+ TVM_ATTR_FIELD(dilation).describe(
+ "Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).describe(
+ "Number of groups to split the input into for grouped convolution. The
number of input and "
+ "output channels should be divisible by the number of groups.");
+ TVM_ATTR_FIELD(data_layout)
+ .describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height,
and width"
+ "dimensions respectively. Convolution is applied on the 'D', 'H',
and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .describe(
+ "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+ "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel,
depth, height, and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height,
and width"
+ "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(out_dtype).describe(
+ "Output data type, set to explicit type under mixed precision
setting");
+ }
+}; // struct Conv3dAttrs
+
/*! \brief Attributes used in Conv1DTranspose operator */
struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
Array<IntImm> strides;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 0b5aa4f7ec..702501ce16 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -685,32 +685,32 @@ class Conv(OnnxOpConverter):
ndim = len(inputs[0].struct_info.shape)
if ndim == 3:
- conv_out = bb.emit_te(
- topi.nn.conv1d,
- inputs[0],
- inputs[1],
- attr.get("strides", 1),
- attr.get("pads", 0),
- attr.get("dilation", 1),
- "NCHW",
- "OIHW",
- )
+ op = relax.op.nn.conv1d
+ data_layout = "NCW"
+ kernel_layout = "OIW"
elif ndim == 4:
- conv_out = bb.normalize(
- relax.op.nn.conv2d(
- data=inputs[0],
- weight=inputs[1],
- strides=attr.get("strides", 1),
- padding=attr.get("pads", 0),
- dilation=attr.get("dilation", 1),
- groups=attr.get("group", 1),
- data_layout="NCHW",
- kernel_layout="OIHW",
- )
- )
+ op = relax.op.nn.conv2d
+ data_layout = "NCHW"
+ kernel_layout = "OIHW"
+ elif ndim == 5:
+ op = relax.op.nn.conv3d
+ data_layout = "NCDHW"
+ kernel_layout = "OIDHW"
else:
- raise NotImplementedError("Only 2d conv currently supported.")
-
+ raise NotImplementedError("Ndim > 5 not supported for
convolution.")
+
+ conv_out = bb.normalize(
+ op(
+ data=inputs[0],
+ weight=inputs[1],
+ strides=attr.get("strides", 1),
+ padding=attr.get("pads", 0),
+ dilation=attr.get("dilation", 1),
+ groups=attr.get("group", 1),
+ data_layout=data_layout,
+ kernel_layout=kernel_layout,
+ )
+ )
if inputs[2] is not None:
bias = relax.op.reshape(
inputs[2],
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 1510df9548..5e581e81f3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -695,6 +695,34 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.add(conv1d, bias))
+ def _conv3d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+
+ conv3d = self.block_builder.emit(
+ relax.op.nn.conv3d(
+ x,
+ weight,
+ strides=module.stride,
+ padding=module.padding,
+ dilation=module.dilation,
+ groups=module.groups,
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_dtype="float32",
+ )
+ )
+
+ if module.bias is None:
+ return conv3d
+
+ bias = self.params[module.bias]
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+
+ return self.block_builder.emit(relax.op.add(conv3d, bias))
+
def _conv2d_impl(
self,
x: relax.Expr,
@@ -1313,6 +1341,7 @@ class TorchFXImporter:
nn.Linear: self._linear,
nn.Conv1d: self._conv1d,
nn.Conv2d: self._conv2d,
+ nn.Conv3d: self._conv3d,
nn.ConvTranspose1d: self._conv1d_transpose,
nn.ConvTranspose2d: self._conv2d_transpose,
nn.MaxPool2d: self._max_pool2d,
diff --git a/python/tvm/relax/op/nn/__init__.py
b/python/tvm/relax/op/nn/__init__.py
index 9f01086a69..d90b207314 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -25,6 +25,7 @@ from .nn import (
conv1d_transpose,
conv2d,
conv2d_transpose,
+ conv3d,
cross_entropy_with_logits,
dropout,
gelu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5adf38d7d6..5c18d31bf2 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -220,6 +220,106 @@ def conv2d(
)
+def conv3d(
+ data: Expr,
+ weight: Expr,
+    strides: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+    dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+ groups: int = 1,
+ data_layout: str = "NCDHW",
+ kernel_layout: str = "OIDHW",
+ out_layout: Optional[str] = None,
+ out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+ r"""3D convolution.
+
+ This operator takes the weight as the convolution kernel
+ and convolves it with data to produce an output.
+
+
+ In the default case, where the data_layout is `NCDHW`
+ and kernel_layout is `OIDHW`, conv3d takes in
+ a data Tensor with shape `(batch_size, in_channels, depth, height, width)`,
+ and a weight Tensor with shape `(channels, in_channels, kernel_d,
kernel_h, kernel_w)`,
+ where `kernel_d`, `kernel_h`, and `kernel_w` are the lengths of the `D`,
`H`,
+ and `W` kernel dimensions, to produce an output Tensor with the following
rule:
+
+ .. math::
+
+ \mbox{out}[b, c, z, y, x] = \sum_{dz, dy, dx, k}
+ \mbox{data}[b, k, \mbox{strides}[0] * z + dz,
+ \mbox{strides}[1] * y + dy,
+ \mbox{strides}[2] * x + dx] *
+ \mbox{weight}[c, k, dz, dy, dx]
+
+ Padding and dilation are applied to data and weight respectively before
the computation.
+ This operator accepts data layout specification.
+ Semantically, the operator will convert the layout to the canonical layout
+ (`NCDHW` for data and `OIDHW` for weight), perform the computation,
+ then convert to the out_layout.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ weight : relax.Expr
+ The weight expressions.
+
+ strides : Union[int, Tuple[int, int, int]]
+ The strides of convolution. It is required to have length either 1 or
3.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding of convolution on both sides of inputs before convolution.
+ It is required to have length either 1, 3 or 6.
+
+ dilation : Union[int, Tuple[int, int, int]]
+ Specifies the dilation rate to be used for dilated convolution.
+ It is required to have length either 1 or 3.
+
+ groups : int
+ Number of groups to split the input into for grouped convolution.
+ The number of input and output channels should be divisible by the
number of groups.
+
+ data_layout : str
+ Layout of the input.
+
+ kernel_layout : str
+ Layout of the weight.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed precision conv3d.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(strides, int):
+ strides = (strides, strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding, padding, padding)
+
+ return _ffi_api.conv3d( # type: ignore
+ data,
+ weight,
+ strides,
+ padding,
+ dilation,
+ groups,
+ data_layout,
+ kernel_layout,
+ out_layout,
+ out_dtype,
+ )
+
+
def conv1d_transpose(
data: Expr,
weight: Expr,
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index f2453a67b6..186071f227 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -109,6 +109,47 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.conv3d")
+def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.data_layout:
+ logging.info(
+ "TOPI conv3d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+ if len(call.attrs.data_layout) != 5 or len(call.attrs.kernel_layout) != 5:
+ logging.info(
+ "Conv3D where data layout or kernel layout have channel chunk "
+ "cannot be legalized by TOPI at this moment."
+ )
+ return call
+ if call.attrs.groups != 1:
+ data_layout = tir.layout(call.attrs.data_layout)
+ kernel_layout = tir.layout(call.attrs.kernel_layout)
+ ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
+ oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
+ if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
+ logging.info(
+ "Conv3D where number of groups is more than one and input or
output "
+ "channel size is symbolic cannot be legalized by TOPI at this
moment."
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.conv,
+ inp=call.args[0],
+ filt=call.args[1],
+ stride=call.attrs.strides,
+ padding=call.attrs.padding,
+ dilation=call.attrs.dilation,
+ groups=call.attrs.groups,
+ data_layout=call.attrs.data_layout,
+ kernel_layout=call.attrs.kernel_layout,
+ out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
+ primfunc_name_hint="conv3d",
+ )
+
+
@register_legalize("relax.nn.conv1d_transpose")
def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index cea234060e..7c7718b837 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -354,6 +354,182 @@ TVM_REGISTER_OP("relax.nn.conv2d")
.set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv2d)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.conv3d */
+TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
+
+Expr conv3d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm>
padding,
+ Array<IntImm> dilation, int groups, String data_layout, String
kernel_layout,
+ Optional<String> out_layout, DataType out_dtype) {
+ padding = GetCompletePadding3D(std::move(padding));
+ if (strides.size() == 1) {
+ strides.push_back(strides[0]);
+ strides.push_back(strides[0]);
+ }
+ if (dilation.size() == 1) {
+ dilation.push_back(dilation[0]);
+ dilation.push_back(dilation[0]);
+ }
+
+ CHECK_GT(groups, 0) << "The number of groups in convolution is expected to
be positive. However, "
+ "the given number of groups is "
+ << groups;
+ CHECK_EQ(strides.size(), 3)
+ << "The input strides length is expected to be 3. However, the given
strides is " << strides;
+ CHECK_EQ(dilation.size(), 3)
+ << "The input dilation length is expected to be 3. However, the given
dilation is "
+ << dilation;
+ return MakeConv<Conv3DAttrs>(std::move(data), std::move(weight),
std::move(strides),
+ std::move(padding), std::move(dilation),
groups, data_layout,
+ std::move(kernel_layout),
out_layout.value_or(data_layout),
+ out_dtype, /*op_name=*/"relax.nn.conv3d");
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d);
+
+StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo weight_sinfo = input_sinfo[1];
+
+ const auto* attrs = call->attrs.as<Conv3DAttrs>();
+ auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx,
attrs->data_layout, //
+ /*tgt_layout=*/"NCDHW",
//
+ /*tensor_name=*/"data");
+ auto [weight_layout, weight2OIDHW] = CheckTensorLayout(call, ctx,
attrs->kernel_layout, //
+
/*tgt_layout=*/"OIDHW", //
+
/*tensor_name=*/"kernel");
+ auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx,
attrs->out_layout, //
+ /*tgt_layout=*/"NCDHW",
//
+ /*tensor_name=*/"output");
+
+ Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ Optional<ShapeExpr> weight_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+ DataType out_dtype = attrs->out_dtype.is_void()
+ ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo,
weight_sinfo)
+ : attrs->out_dtype;
+ Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx,
data_sinfo, weight_sinfo);
+ if (!data_shape.defined() || !weight_shape.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
+ }
+
+ Array<PrimExpr> data_NCDHW_shape =
data2NCDHW.ForwardShape(data_shape.value()->values);
+ Array<PrimExpr> weight_OIDHW_shape =
weight2OIDHW.ForwardShape(weight_shape.value()->values);
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr input_channel_data = data_NCDHW_shape[1];
+ PrimExpr input_channel_kernel = weight_OIDHW_shape[1];
+ if (analyzer->CanProve(input_channel_data != input_channel_kernel *
attrs->groups)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "The channel size of the data should equal to the product of input
channel size of the "
+ "weight and the number of groups. However, the data channel size is
"
+ << input_channel_data << " while the weight input channel size and
number of groups are "
+ << input_channel_kernel << " and " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel
* attrs->groups)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(floormod(weight_OIDHW_shape[0], attrs->groups) != 0))
{
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv3d expects the number of output channels to be
divisible by the "
+ "number of groups. However, the number of output
channels is "
+ << weight_OIDHW_shape[0] << " while the number of groups
is "
+ << attrs->groups);
+ } else if (!analyzer->CanProveEqual(floormod(weight_OIDHW_shape[0],
attrs->groups), 0)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+
+ PrimExpr input_d = data_NCDHW_shape[2];
+ PrimExpr input_h = data_NCDHW_shape[3];
+ PrimExpr input_w = data_NCDHW_shape[4];
+ PrimExpr kernel_d = weight_OIDHW_shape[2];
+ PrimExpr kernel_h = weight_OIDHW_shape[3];
+ PrimExpr kernel_w = weight_OIDHW_shape[4];
+ PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
+ PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
+ PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+
+ std::vector<PrimExpr> out_NCDHW_shape;
+ out_NCDHW_shape.resize(5);
+ out_NCDHW_shape[0] = data_NCDHW_shape[0];
+ out_NCDHW_shape[1] = weight_OIDHW_shape[0];
+
+ PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d
- 1) - 1;
+ PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h
- 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w
- 1) - 1;
+ out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d,
attrs->strides[0]) + 1);
+ out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[1]) + 1);
+ out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[2]) + 1);
+
+ Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
+}
+
+InferLayoutOutput InferLayoutConv3d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ const auto& it = desired_layouts.find("relax.nn.conv3d");
+ const auto* attrs = call->attrs.as<Conv3DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision data_layout, weight_layout, output_layout;
+ ObjectPtr<Conv3DAttrs> new_attrs = make_object<Conv3DAttrs>(*attrs);
+
+ if (it != desired_layouts.end()) {
+ // We have a desired layout for conv3d.
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+ Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2]
: (*it).second[0];
+ ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal())
<< "Axis swap only";
+ ICHECK_EQ(desired_weight_layout.ndim(),
desired_weight_layout.ndim_primal())
+ << "Axis swap only";
+ ICHECK_EQ(desired_output_layout.ndim(),
desired_output_layout.ndim_primal())
+ << "Axis swap only";
+ data_layout = TransposeLike(InitialLayout(5), attrs->data_layout,
desired_data_layout);
+ weight_layout = TransposeLike(InitialLayout(5), attrs->kernel_layout,
desired_weight_layout);
+ output_layout = TransposeLike(InitialLayout(5), attrs->out_layout,
desired_output_layout);
+ new_attrs->data_layout = (*it).second[0];
+ new_attrs->kernel_layout = (*it).second[1];
+ new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] :
(*it).second[0];
+ } else {
+    // We don't have a desired layout for conv3d.
+ // We can just propagate the layout from the input.
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ output_layout = data_layout;
+ new_attrs->data_layout =
+ TransposeLike(attrs->data_layout, InitialLayout(5),
data_layout->layout).name();
+ new_attrs->kernel_layout =
+ TransposeLike(attrs->kernel_layout, InitialLayout(5),
weight_layout->layout).name();
+ new_attrs->out_layout =
+ TransposeLike(attrs->out_layout, InitialLayout(5),
output_layout->layout).name();
+ }
+ return InferLayoutOutput({data_layout, weight_layout}, {output_layout},
Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) {
+ const auto* conv3d_attrs = call->attrs.as<Conv3DAttrs>();
+ return Downcast<Call>(conv3d(call->args[0], call->args[1],
conv3d_attrs->strides,
+ conv3d_attrs->padding, conv3d_attrs->dilation,
conv3d_attrs->groups,
+ conv3d_attrs->data_layout,
conv3d_attrs->kernel_layout,
+ conv3d_attrs->out_layout, out_dtype));
+}
+
+TVM_REGISTER_OP("relax.nn.conv3d")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_attrs_type<Conv3DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv3d)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv3d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kAlways)
+ .set_attr<FInferMixedPrecision>("FInferMixedPrecision",
InferMixedPrecisionConv3d)
+ .set_attr<Bool>("FPurity", Bool(true));
+
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides,
Array<IntImm> padding,
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 536d2af371..e4ddd44877 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -62,6 +62,11 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides,
Array<IntImm> padding
Array<IntImm> dilation, int groups, String data_layout, String
kernel_layout,
Optional<String> out_layout, DataType out_dtype);
+/*! \brief 3D convolution */
+Expr conv3d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm>
padding,
+ Array<IntImm> dilation, int groups, String data_layout, String
kernel_layout,
+ Optional<String> out_layout, DataType out_dtype);
+
/*!
* \brief One dimensional transposed convolution operator.
*
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 290cdef0d5..f5eed7af06 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -395,6 +395,31 @@ inline Array<IntImm> GetCompletePadding2D(Array<IntImm>
padding) {
throw;
}
+/*!
+ * \brief Complete the padding to a 6-length array.
+ * - If the padding length is 1, the same padding is used on all
front/top/left/back/bottom/right
+ * sides
+ * - If the padding length is 3, front/back sides use padding[0], top/bottom
sides use padding[1]
+ * and left/right use padding[2]
+ * - If the padding length is 6, padding is in the order of (front, top, left,
back, bottom, right)
+ * \param padding The given padding to be completed
+ * \return The completed padding.
+ * \throws Throws error if the input padding length is neither 1, 3 nor 6.
+ */
+inline Array<IntImm> GetCompletePadding3D(Array<IntImm> padding) {
+ if (padding.size() == 1) {
+ return {padding[0], padding[0], padding[0], padding[0], padding[0],
padding[0]};
+ } else if (padding.size() == 3) {
+ return {padding[0], padding[1], padding[2], padding[0], padding[1],
padding[2]};
+ } else if (padding.size() == 6) {
+ return padding;
+ }
+ LOG(FATAL) << "The input padding length is expected to be either 1, 3 or 6.
However, the given "
+ "padding is "
+ << padding;
+ throw;
+}
+
/*!
* \brief Check if the given tensor layout can be converted to the given
target layout.
* If convertible, return the tensor layout and the bijective conversion in
tir::Layout and
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 937214dca6..dfa5cad4a5 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -368,6 +368,85 @@ def test_conv2d_transpose():
verify_model(model, input_info, binding, expected2)
+def test_conv3d():
+ class Conv3D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv3d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0, 0],
+ dilation=[1],
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_layout="NCDHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+ lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1,
lv2)
+ gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ class Conv3D2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv3d(3, 6, 7, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+ ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0, 0],
+ dilation=[1],
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_layout="NCDHW",
+ out_dtype="float32",
+ )
+ gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ input_info = [([1, 3, 10, 10, 10], "float32")]
+
+ model = Conv3D1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, input_info, binding, expected1)
+
+ model = Conv3D2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, input_info, binding, expected2)
+
+
def test_linear():
# nn.Linear
class Dense1(Module):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 748119f6f9..f9a7643aa5 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -74,7 +74,8 @@ def generate_random_inputs(
def check_correctness(
model: ModelProto,
inputs: Optional[Dict[str, np.ndarray]] = None,
- opset: int = None,
+ ir_version: int = 8,
+ opset: int = 14,
atol: float = 1e-5,
) -> None:
"""Run an onnx model in both onnxruntime and TVM through our importer
@@ -86,12 +87,17 @@ def check_correctness(
The input onnx model that should be tested.
inputs: Optional[Dict[str, np.ndarray]]
An optional dictionary containing values for each input in the onnx
model.
+ ir_version: int
+ Which version of the onnx IR to use.
opset: int
The opset version to use for the onnx importer.
atol: float
Set the tolerance of correctness checking. Some ops may be show more
arithmetic variance than others.
"""
+ # Configure model format.
+ if ir_version is not None:
+ model.ir_version = ir_version
if opset is not None:
model.opset_import[0].version = opset
@@ -563,9 +569,14 @@ def test_conv():
)
model = helper.make_model(graph, producer_name="conv_test")
- check_correctness(model)
+ check_correctness(model, atol=1e-4)
+ # Conv1D
+ _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30])
+ # Conv2D
_verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
+ # Conv3D
+ _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30])
def test_pow():
diff --git a/tests/python/relax/test_op_nn_convolution.py
b/tests/python/relax/test_op_nn_convolution.py
index 6be1245fe2..55e35ee203 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -37,6 +37,12 @@ def test_conv2d_op_correctness():
assert relax.op.nn.conv2d_transpose(x, w).op ==
Op.get("relax.nn.conv2d_transpose")
+def test_conv3d_op_correctness():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
+ assert relax.op.nn.conv3d(x, w).op == Op.get("relax.nn.conv3d")
+
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
ret = bb.normalize(call)
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
@@ -1565,5 +1571,186 @@ def
test_conv2d_transpose_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+def test_conv3d_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 28, 28, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=5))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 28, 16), "float32"))
+ x6 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32", vdev0))
+ w0 = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=5))
+ w3 = relax.Var("w", R.Tensor("float32"))
+ w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 3, 16), "float32"))
+ w5 = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32", vdev0))
+
+ _check_inference(
+ bb, relax.op.nn.conv3d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26,
26), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv3d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26,
26), "float32", vdev0)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, out_dtype="float16"),
+ relax.TensorStructInfo((2, 4, 26, 26, 26), "float16"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, padding=1),
+ relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3]),
+ relax.TensorStructInfo((2, 4, 28, 30, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3, 4, 5, 6]),
+ relax.TensorStructInfo((2, 4, 31, 33, 35), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, strides=2),
+ relax.TensorStructInfo((2, 4, 13, 13, 13), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, strides=(2, 3, 4)),
+ relax.TensorStructInfo((2, 4, 13, 9, 7), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, dilation=2),
+ relax.TensorStructInfo((2, 4, 24, 24, 24), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, dilation=(3, 2, 1)),
+ relax.TensorStructInfo((2, 4, 22, 24, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x1, w0, data_layout="NDHWC"),
+ relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, out_layout="NDHWC"),
+ relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w1, kernel_layout="IODHW"),
+ relax.TensorStructInfo((2, 4, 26, 26, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(
+ x5, w4, data_layout="NCDHW16c", kernel_layout="OIDHW16i",
out_layout="NDHWC16c"
+ ),
+ relax.TensorStructInfo((2, 26, 26, 26, 3, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv3d(x2, w0),
relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv3d(x3, w0),
relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv3d(x0, w2),
relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv3d(x0, w3),
relax.TensorStructInfo(dtype="float32", ndim=5)
+ )
+ _check_inference(bb, relax.op.nn.conv3d(x4, w0),
relax.TensorStructInfo(dtype="", ndim=5))
+
+
+def test_conv3d_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ id = tir.Var("id", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ ki = tir.Var("ki", "int64")
+ ko = tir.Var("ko", "int64")
+ kd = tir.Var("kd", "int64")
+ kh = tir.Var("kh", "int64")
+ kw = tir.Var("kw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32"))
+ w0 = relax.Var("w", R.Tensor((ko, ki, kd, kh, kw), "float32"))
+ w1 = relax.Var("w", R.Tensor((ko, c, kd, kh, kw), "float32"))
+ w2 = relax.Var("w", R.Tensor((ko, c, kd, kh, kw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0),
+ relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw),
"float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w1),
+ relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw),
"float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(
+ x1, w2, data_layout="NCDHW16c", kernel_layout="OIDHW16i",
out_layout="NCDHW"
+ ),
+ relax.TensorStructInfo((n, ko, id + 1 - kd, ih + 1 - kh, iw + 1 - kw),
"float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w0, strides=(2, 2, 2), padding=(1, 1, 1),
dilation=(2, 2, 2)),
+ relax.TensorStructInfo(
+ (
+ n,
+ ko,
+ tvm.tir.floordiv(id + 3, 2) + 1 - kd,
+ tvm.tir.floordiv(ih + 3, 2) + 1 - kh,
+ tvm.tir.floordiv(iw + 3, 2) + 1 - kw,
+ ),
+ "float32",
+ ),
+ )
+
+
+def test_conv3d_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6))
+ s2 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+ s3 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+ w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(bb, relax.op.nn.conv3d(x0, w),
relax.TensorStructInfo(dtype="float32", ndim=5))
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x1, w, data_layout="NCDHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=6),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x0, w, out_layout="NCDHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=6),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv3d(x2, w),
+ relax.TensorStructInfo(dtype="float32", ndim=5),
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()