This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5c87bfe09b [Unity][Relax][Op] Add Conv3D Operator (#16385)
5c87bfe09b is described below

commit 5c87bfe09b274b252c29a61d4a7c742689114886
Author: Josh Fromm <[email protected]>
AuthorDate: Sat Jan 13 18:06:04 2024 -0800

    [Unity][Relax][Op] Add Conv3D Operator (#16385)
---
 include/tvm/relax/attrs/nn.h                     |  45 ++++++
 python/tvm/relax/frontend/onnx/onnx_frontend.py  |  48 +++---
 python/tvm/relax/frontend/torch/fx_translator.py |  29 ++++
 python/tvm/relax/op/nn/__init__.py               |   1 +
 python/tvm/relax/op/nn/nn.py                     | 100 ++++++++++++
 python/tvm/relax/transform/legalize_ops/nn.py    |  41 +++++
 src/relax/op/nn/convolution.cc                   | 176 +++++++++++++++++++++
 src/relax/op/nn/convolution.h                    |   5 +
 src/relax/op/op_common.h                         |  25 +++
 tests/python/relax/test_frontend_from_fx.py      |  79 ++++++++++
 tests/python/relax/test_frontend_onnx.py         |  15 +-
 tests/python/relax/test_op_nn_convolution.py     | 187 +++++++++++++++++++++++
 12 files changed, 725 insertions(+), 26 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 424874bd75..dd63a70bc4 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -117,6 +117,51 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   }
 };  // struct Conv2dAttrs
 
+/*! \brief Attributes used in Conv3d operator */
+struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> dilation;
+  int groups;
+  String data_layout;
+  String kernel_layout;
+  String out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv3DAttrs, "relax.attrs.Conv3DAttrs") {
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded."
+        " Padding supports both symmetric and asymmetric as:"
+        " one int : same padding used on all sides;"
+        " three int : back, bottom, right will use same padding as front, top, left;"
+        " six int : padding width in the order of (front, top, left, back, bottom, right)");
+    TVM_ATTR_FIELD(dilation).describe(
+        "Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).describe(
+        "Number of groups to split the input into for grouped convolution. The number of input and "
+        "output channels should be divisible by the number of groups.");
+    TVM_ATTR_FIELD(data_layout)
+        .describe(
+            "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+            " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+            " dimensions respectively. Convolution is applied on the 'D', 'H', and"
+            " 'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout)
+        .describe(
+            "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+            " 'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height, and"
+            " width dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+            " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+            " dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype).describe(
+        "Output data type, set to explicit type under mixed precision setting");
+  }
+};  // struct Conv3dAttrs
+
 /*! \brief Attributes used in Conv1DTranspose operator */
 struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
   Array<IntImm> strides;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 0b5aa4f7ec..702501ce16 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -685,32 +685,32 @@ class Conv(OnnxOpConverter):
             ndim = len(inputs[0].struct_info.shape)
 
         if ndim == 3:
-            conv_out = bb.emit_te(
-                topi.nn.conv1d,
-                inputs[0],
-                inputs[1],
-                attr.get("strides", 1),
-                attr.get("pads", 0),
-                attr.get("dilation", 1),
-                "NCHW",
-                "OIHW",
-            )
+            op = relax.op.nn.conv1d
+            data_layout = "NCW"
+            kernel_layout = "OIW"
         elif ndim == 4:
-            conv_out = bb.normalize(
-                relax.op.nn.conv2d(
-                    data=inputs[0],
-                    weight=inputs[1],
-                    strides=attr.get("strides", 1),
-                    padding=attr.get("pads", 0),
-                    dilation=attr.get("dilation", 1),
-                    groups=attr.get("group", 1),
-                    data_layout="NCHW",
-                    kernel_layout="OIHW",
-                )
-            )
+            op = relax.op.nn.conv2d
+            data_layout = "NCHW"
+            kernel_layout = "OIHW"
+        elif ndim == 5:
+            op = relax.op.nn.conv3d
+            data_layout = "NCDHW"
+            kernel_layout = "OIDHW"
         else:
-            raise NotImplementedError("Only 2d conv currently supported.")
-
+            raise NotImplementedError("Ndim > 5 not supported for 
convolution.")
+
+        conv_out = bb.normalize(
+            op(
+                data=inputs[0],
+                weight=inputs[1],
+                strides=attr.get("strides", 1),
+                padding=attr.get("pads", 0),
+                dilation=attr.get("dilation", 1),
+                groups=attr.get("group", 1),
+                data_layout=data_layout,
+                kernel_layout=kernel_layout,
+            )
+        )
         if inputs[2] is not None:
             bias = relax.op.reshape(
                 inputs[2],
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 1510df9548..5e581e81f3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -695,6 +695,34 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.add(conv1d, bias))
 
+    def _conv3d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+
+        conv3d = self.block_builder.emit(
+            relax.op.nn.conv3d(
+                x,
+                weight,
+                strides=module.stride,
+                padding=module.padding,
+                dilation=module.dilation,
+                groups=module.groups,
+                data_layout="NCDHW",
+                kernel_layout="OIDHW",
+                out_dtype="float32",
+            )
+        )
+
+        if module.bias is None:
+            return conv3d
+
+        bias = self.params[module.bias]
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+
+        return self.block_builder.emit(relax.op.add(conv3d, bias))
+
     def _conv2d_impl(
         self,
         x: relax.Expr,
@@ -1313,6 +1341,7 @@ class TorchFXImporter:
             nn.Linear: self._linear,
             nn.Conv1d: self._conv1d,
             nn.Conv2d: self._conv2d,
+            nn.Conv3d: self._conv3d,
             nn.ConvTranspose1d: self._conv1d_transpose,
             nn.ConvTranspose2d: self._conv2d_transpose,
             nn.MaxPool2d: self._max_pool2d,
diff --git a/python/tvm/relax/op/nn/__init__.py 
b/python/tvm/relax/op/nn/__init__.py
index 9f01086a69..d90b207314 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -25,6 +25,7 @@ from .nn import (
     conv1d_transpose,
     conv2d,
     conv2d_transpose,
+    conv3d,
     cross_entropy_with_logits,
     dropout,
     gelu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5adf38d7d6..5c18d31bf2 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -220,6 +220,106 @@ def conv2d(
     )
 
 
+def conv3d(
+    data: Expr,
+    weight: Expr,
+    strides: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+    dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+    groups: int = 1,
+    data_layout: str = "NCDHW",
+    kernel_layout: str = "OIDHW",
+    out_layout: Optional[str] = None,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    r"""3D convolution.
+
+    This operator takes the weight as the convolution kernel
+    and convolves it with data to produce an output.
+
+
+    In the default case, where the data_layout is `NCDHW`
+    and kernel_layout is `OIDHW`, conv3d takes in
+    a data Tensor with shape `(batch_size, in_channels, depth, height, width)`,
+    and a weight Tensor with shape `(channels, in_channels, kernel_d, kernel_h, kernel_w)`,
+    where `kernel_d`, `kernel_h`, and `kernel_w` are the lengths of the `D`, `H`,
+    and `W` kernel dimensions, to produce an output Tensor with the following rule:
+
+    .. math::
+
+        \mbox{out}[b, c, z, y, x] = \sum_{dz, dy, dx, k}
+           \mbox{data}[b, k, \mbox{strides}[0] * z + dz,
+           \mbox{strides}[1] * y  + dy,
+           \mbox{strides}[2] * x + dx] *
+           \mbox{weight}[c, k, dz, dy, dx]
+
+    Padding and dilation are applied to data and weight respectively before the computation.
+    This operator accepts data layout specification.
+    Semantically, the operator will convert the layout to the canonical layout
+    (`NCDHW` for data and `OIDHW` for weight), perform the computation,
+    then convert to the out_layout.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    weight : relax.Expr
+        The weight expressions.
+
+    strides : Union[int, Tuple[int, int, int]]
+        The strides of convolution. It is required to have length either 1 or 3.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding of convolution on both sides of inputs before convolution.
+        It is required to have length either 1, 3 or 6.
+
+    dilation : Union[int, Tuple[int, int, int]]
+        Specifies the dilation rate to be used for dilated convolution.
+        It is required to have length either 1 or 3.
+
+    groups : int
+        Number of groups to split the input into for grouped convolution.
+        The number of input and output channels should be divisible by the number of groups.
+
+    data_layout : str
+        Layout of the input.
+
+    kernel_layout : str
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed precision conv3d.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(strides, int):
+        strides = (strides, strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding, padding, padding)
+
+    return _ffi_api.conv3d(  # type: ignore
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
 def conv1d_transpose(
     data: Expr,
     weight: Expr,
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index f2453a67b6..186071f227 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -109,6 +109,47 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.conv3d")
+def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.data_layout:
+        logging.info(
+            "TOPI conv3d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+    if len(call.attrs.data_layout) != 5 or len(call.attrs.kernel_layout) != 5:
+        logging.info(
+            "Conv3D where data layout or kernel layout have channel chunk "
+            "cannot be legalized by TOPI at this moment."
+        )
+        return call
+    if call.attrs.groups != 1:
+        data_layout = tir.layout(call.attrs.data_layout)
+        kernel_layout = tir.layout(call.attrs.kernel_layout)
+        ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
+        oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
+        if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
+            logging.info(
+                "Conv3D where number of groups is more than one and input or 
output "
+                "channel size is symbolic cannot be legalized by TOPI at this 
moment."
+            )
+            return call
+
+    return bb.call_te(
+        topi.nn.conv,
+        inp=call.args[0],
+        filt=call.args[1],
+        stride=call.attrs.strides,
+        padding=call.attrs.padding,
+        dilation=call.attrs.dilation,
+        groups=call.attrs.groups,
+        data_layout=call.attrs.data_layout,
+        kernel_layout=call.attrs.kernel_layout,
+        out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None,
+        primfunc_name_hint="conv3d",
+    )
+
+
 @register_legalize("relax.nn.conv1d_transpose")
 def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.data_layout:
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index cea234060e..7c7718b837 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -354,6 +354,182 @@ TVM_REGISTER_OP("relax.nn.conv2d")
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", 
InferMixedPrecisionConv2d)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.conv3d */
+TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
+
+Expr conv3d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> 
padding,
+            Array<IntImm> dilation, int groups, String data_layout, String 
kernel_layout,
+            Optional<String> out_layout, DataType out_dtype) {
+  padding = GetCompletePadding3D(std::move(padding));
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+    dilation.push_back(dilation[0]);
+  }
+
+  CHECK_GT(groups, 0) << "The number of groups in convolution is expected to 
be positive. However, "
+                         "the given number of groups is "
+                      << groups;
+  CHECK_EQ(strides.size(), 3)
+      << "The input strides length is expected to be 3. However, the given 
strides is " << strides;
+  CHECK_EQ(dilation.size(), 3)
+      << "The input dilation length is expected to be 3. However, the given 
dilation is "
+      << dilation;
+  return MakeConv<Conv3DAttrs>(std::move(data), std::move(weight), 
std::move(strides),
+                               std::move(padding), std::move(dilation), 
groups, data_layout,
+                               std::move(kernel_layout), 
out_layout.value_or(data_layout),
+                               out_dtype, /*op_name=*/"relax.nn.conv3d");
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d);
+
+StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo weight_sinfo = input_sinfo[1];
+
+  const auto* attrs = call->attrs.as<Conv3DAttrs>();
+  auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, 
attrs->data_layout,  //
+                                                     /*tgt_layout=*/"NCDHW",   
      //
+                                                     /*tensor_name=*/"data");
+  auto [weight_layout, weight2OIDHW] = CheckTensorLayout(call, ctx, 
attrs->kernel_layout,  //
+                                                         
/*tgt_layout=*/"OIDHW",           //
+                                                         
/*tensor_name=*/"kernel");
+  auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx, 
attrs->out_layout,  //
+                                                   /*tgt_layout=*/"NCDHW",     
   //
+                                                   /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  Optional<ShapeExpr> weight_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+  DataType out_dtype = attrs->out_dtype.is_void()
+                           ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, 
weight_sinfo)
+                           : attrs->out_dtype;
+  Optional<VDevice> vdevice = InferBinaryArithOpOutVDevice(call, ctx, 
data_sinfo, weight_sinfo);
+  if (!data_shape.defined() || !weight_shape.defined()) {
+    return TensorStructInfo(out_dtype, out_layout.ndim(), vdevice);
+  }
+
+  Array<PrimExpr> data_NCDHW_shape = 
data2NCDHW.ForwardShape(data_shape.value()->values);
+  Array<PrimExpr> weight_OIDHW_shape = 
weight2OIDHW.ForwardShape(weight_shape.value()->values);
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr input_channel_data = data_NCDHW_shape[1];
+  PrimExpr input_channel_kernel = weight_OIDHW_shape[1];
+  if (analyzer->CanProve(input_channel_data != input_channel_kernel * 
attrs->groups)) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "The channel size of the data should equal to the product of input 
channel size of the "
+           "weight and the number of groups. However, the data channel size is 
"
+        << input_channel_data << " while the weight input channel size and 
number of groups are "
+        << input_channel_kernel << " and " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel 
* attrs->groups)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(floormod(weight_OIDHW_shape[0], attrs->groups) != 0)) 
{
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv3d expects the number of output channels to be 
divisible by the "
+                        "number of groups. However, the number of output 
channels is "
+                     << weight_OIDHW_shape[0] << " while the number of groups 
is "
+                     << attrs->groups);
+  } else if (!analyzer->CanProveEqual(floormod(weight_OIDHW_shape[0], 
attrs->groups), 0)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+
+  PrimExpr input_d = data_NCDHW_shape[2];
+  PrimExpr input_h = data_NCDHW_shape[3];
+  PrimExpr input_w = data_NCDHW_shape[4];
+  PrimExpr kernel_d = weight_OIDHW_shape[2];
+  PrimExpr kernel_h = weight_OIDHW_shape[3];
+  PrimExpr kernel_w = weight_OIDHW_shape[4];
+  PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
+  PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
+  PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+
+  std::vector<PrimExpr> out_NCDHW_shape;
+  out_NCDHW_shape.resize(5);
+  out_NCDHW_shape[0] = data_NCDHW_shape[0];
+  out_NCDHW_shape[1] = weight_OIDHW_shape[0];
+
+  PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d 
- 1) - 1;
+  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h 
- 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w 
- 1) - 1;
+  out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, 
attrs->strides[0]) + 1);
+  out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, 
attrs->strides[1]) + 1);
+  out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, 
attrs->strides[2]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
+}
+
+InferLayoutOutput InferLayoutConv3d(const Call& call,
+                                    const Map<String, Array<String>>& desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  const auto& it = desired_layouts.find("relax.nn.conv3d");
+  const auto* attrs = call->attrs.as<Conv3DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout, weight_layout, output_layout;
+  ObjectPtr<Conv3DAttrs> new_attrs = make_object<Conv3DAttrs>(*attrs);
+
+  if (it != desired_layouts.end()) {
+    // We have a desired layout for conv3d.
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
+        << "Axis swap only";
+    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(5), attrs->data_layout, desired_data_layout);
+    weight_layout = TransposeLike(InitialLayout(5), attrs->kernel_layout, desired_weight_layout);
+    output_layout = TransposeLike(InitialLayout(5), attrs->out_layout, desired_output_layout);
+    new_attrs->data_layout = (*it).second[0];
+    new_attrs->kernel_layout = (*it).second[1];
+    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+  } else {
+    // We don't have a desired layout for conv3d.
+    // We can just propagate the layout from the input.
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+    output_layout = data_layout;
+    new_attrs->data_layout =
+        TransposeLike(attrs->data_layout, InitialLayout(5), data_layout->layout).name();
+    new_attrs->kernel_layout =
+        TransposeLike(attrs->kernel_layout, InitialLayout(5), weight_layout->layout).name();
+    new_attrs->out_layout =
+        TransposeLike(attrs->out_layout, InitialLayout(5), output_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) {
+  const auto* conv3d_attrs = call->attrs.as<Conv3DAttrs>();
+  return Downcast<Call>(conv3d(call->args[0], call->args[1], 
conv3d_attrs->strides,
+                               conv3d_attrs->padding, conv3d_attrs->dilation, 
conv3d_attrs->groups,
+                               conv3d_attrs->data_layout, 
conv3d_attrs->kernel_layout,
+                               conv3d_attrs->out_layout, out_dtype));
+}
+
+TVM_REGISTER_OP("relax.nn.conv3d")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_attrs_type<Conv3DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv3d)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv3d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", 
InferMixedPrecisionConv3d)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
 
 Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, 
Array<IntImm> padding,
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 536d2af371..e4ddd44877 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -62,6 +62,11 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, 
Array<IntImm> padding
             Array<IntImm> dilation, int groups, String data_layout, String 
kernel_layout,
             Optional<String> out_layout, DataType out_dtype);
 
+/*! \brief 3D convolution */
+Expr conv3d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> 
padding,
+            Array<IntImm> dilation, int groups, String data_layout, String 
kernel_layout,
+            Optional<String> out_layout, DataType out_dtype);
+
 /*!
  * \brief One dimensional transposed convolution operator.
  *
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 290cdef0d5..f5eed7af06 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -395,6 +395,31 @@ inline Array<IntImm> GetCompletePadding2D(Array<IntImm> 
padding) {
   throw;
 }
 
+/*!
+ * \brief Complete the padding to a 6-length array.
+ * - If the padding length is 1, the same padding is used on all 
front/top/left/back/bottom/right
+ * sides
+ * - If the padding length is 3, front/back sides use padding[0], top/bottom 
sides use padding[1]
+ * and left/right use padding[2]
+ * - If the padding length is 6, padding is in the order of (front, top, left, 
back, bottom, right)
+ * \param padding The given padding to be completed
+ * \return The completed padding.
+ * \throws Throws error if the input padding length is neither 1, 3 or 6.
+ */
+inline Array<IntImm> GetCompletePadding3D(Array<IntImm> padding) {
+  if (padding.size() == 1) {
+    return {padding[0], padding[0], padding[0], padding[0], padding[0], 
padding[0]};
+  } else if (padding.size() == 3) {
+    return {padding[0], padding[1], padding[2], padding[0], padding[1], 
padding[2]};
+  } else if (padding.size() == 6) {
+    return padding;
+  }
+  LOG(FATAL) << "The input padding length is expected to be either 1, 3 or 6. 
However, the given "
+                "padding is "
+             << padding;
+  throw;
+}
+
 /*!
  * \brief Check if the given tensor layout can be converted to the given 
target layout.
  * If convertible, return the tensor layout and the bijective conversion in 
tir::Layout and
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 937214dca6..dfa5cad4a5 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -368,6 +368,85 @@ def test_conv2d_transpose():
     verify_model(model, input_info, binding, expected2)
 
 
+def test_conv3d():
+    class Conv3D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv3d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0, 0],
+                    dilation=[1],
+                    data_layout="NCDHW",
+                    kernel_layout="OIDHW",
+                    out_layout="NCDHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, 
lv2)
+                gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    class Conv3D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv3d(3, 6, 7, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+        ) -> R.Tensor((1, 6, 4, 4, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0, 0],
+                    dilation=[1],
+                    data_layout="NCDHW",
+                    kernel_layout="OIDHW",
+                    out_layout="NCDHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    input_info = [([1, 3, 10, 10, 10], "float32")]
+
+    model = Conv3D1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
+    verify_model(model, input_info, binding, expected1)
+
+    model = Conv3D2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, input_info, binding, expected2)
+
+
 def test_linear():
     # nn.Linear
     class Dense1(Module):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 748119f6f9..f9a7643aa5 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -74,7 +74,8 @@ def generate_random_inputs(
 def check_correctness(
     model: ModelProto,
     inputs: Optional[Dict[str, np.ndarray]] = None,
-    opset: int = None,
+    ir_version: int = 8,
+    opset: int = 14,
     atol: float = 1e-5,
 ) -> None:
     """Run an onnx model in both onnxruntime and TVM through our importer
@@ -86,12 +87,17 @@ def check_correctness(
         The input onnx model that should be tested.
     inputs: Optional[Dict[str, np.ndarray]]
         An optional dictionary containing values for each input in the onnx 
model.
+    ir_version: int
+        Which version of the onnx IR to use.
     opset: int
         The opset version to use for the onnx importer.
     atol: float
         Set the tolerance of correctness checking. Some ops may be show more
         arithmetic variance than others.
     """
+    # Configure model format.
+    if ir_version is not None:
+        model.ir_version = ir_version
     if opset is not None:
         model.opset_import[0].version = opset
 
@@ -563,9 +569,14 @@ def test_conv():
         )
 
         model = helper.make_model(graph, producer_name="conv_test")
-        check_correctness(model)
+        check_correctness(model, atol=1e-4)
 
+    # Conv1D
+    _verify_conv([3, 12, 32], [4, 12, 3], [3, 4, 30])
+    # Conv2D
     _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
+    # Conv3D
+    _verify_conv([3, 12, 32, 32, 32], [4, 12, 3, 3, 3], [3, 4, 30, 30, 30])
 
 
 def test_pow():
diff --git a/tests/python/relax/test_op_nn_convolution.py 
b/tests/python/relax/test_op_nn_convolution.py
index 6be1245fe2..55e35ee203 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -37,6 +37,12 @@ def test_conv2d_op_correctness():
     assert relax.op.nn.conv2d_transpose(x, w).op == 
Op.get("relax.nn.conv2d_transpose")
 
 
+def test_conv3d_op_correctness():
+    data = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    kernel = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
+    assert relax.op.nn.conv3d(data, kernel).op == Op.get("relax.nn.conv3d")
+
+
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
     ret = bb.normalize(call)
     tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
@@ -1565,5 +1571,186 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
 
 
+def test_conv3d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 28, 28, 28, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=5))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 28, 16), "float32"))
+    x6 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32", vdev0))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 3, 3, 3), "float32"))
+    w2 = relax.Var("w", R.Tensor("float32", ndim=5))
+    w3 = relax.Var("w", R.Tensor("float32"))
+    w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 3, 16), "float32"))
+    w5 = relax.Var("w", R.Tensor((4, 3, 3, 3, 3), "float32", vdev0))
+
+    _check_inference(
+        bb, relax.op.nn.conv3d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26, 26), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.conv3d(x6, w5), relax.TensorStructInfo((2, 4, 26, 26, 26), "float32", vdev0)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, out_dtype="float16"),
+        relax.TensorStructInfo((2, 4, 26, 26, 26), "float16"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, padding=1),
+        relax.TensorStructInfo((2, 4, 28, 28, 28), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3]),
+        relax.TensorStructInfo((2, 4, 28, 30, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, padding=[1, 2, 3, 4, 5, 6]),
+        relax.TensorStructInfo((2, 4, 31, 33, 35), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, strides=2),
+        relax.TensorStructInfo((2, 4, 13, 13, 13), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, strides=(2, 3, 4)),
+        relax.TensorStructInfo((2, 4, 13, 9, 7), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, dilation=2),
+        relax.TensorStructInfo((2, 4, 24, 24, 24), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, dilation=(3, 2, 1)),
+        relax.TensorStructInfo((2, 4, 22, 24, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x1, w0, data_layout="NDHWC"),
+        relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, out_layout="NDHWC"),
+        relax.TensorStructInfo((2, 26, 26, 26, 4), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w1, kernel_layout="IODHW"),
+        relax.TensorStructInfo((2, 4, 26, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(
+            x5, w4, data_layout="NCDHW16c", kernel_layout="OIDHW16i", out_layout="NDHWC16c"
+        ),
+        relax.TensorStructInfo((2, 26, 26, 26, 3, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.conv3d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv3d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv3d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv3d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.nn.conv3d(x4, w0), relax.TensorStructInfo(dtype="", ndim=5))
+
+
+def test_conv3d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    id_ = tir.Var("id", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    ki = tir.Var("ki", "int64")
+    ko = tir.Var("ko", "int64")
+    kd = tir.Var("kd", "int64")
+    kh = tir.Var("kh", "int64")
+    kw = tir.Var("kw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, id_, ih, iw, c16), "float32"))
+    w0 = relax.Var("w", R.Tensor((ko, ki, kd, kh, kw), "float32"))
+    w1 = relax.Var("w", R.Tensor((ko, c, kd, kh, kw), "float32"))
+    w2 = relax.Var("w", R.Tensor((ko, c, kd, kh, kw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0),
+        relax.TensorStructInfo((n, ko, id_ + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w1),
+        relax.TensorStructInfo((n, ko, id_ + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(
+            x1, w2, data_layout="NCDHW16c", kernel_layout="OIDHW16i", out_layout="NCDHW"
+        ),
+        relax.TensorStructInfo((n, ko, id_ + 1 - kd, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w0, strides=(2, 2, 2), padding=(1, 1, 1), dilation=(2, 2, 2)),
+        relax.TensorStructInfo(
+            (
+                n,
+                ko,
+                tvm.tir.floordiv(id_ + 3, 2) + 1 - kd,
+                tvm.tir.floordiv(ih + 3, 2) + 1 - kh,
+                tvm.tir.floordiv(iw + 3, 2) + 1 - kw,
+            ),
+            "float32",
+        ),
+    )
+
+
+def test_conv3d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6))
+    s2 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s3 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+    w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(bb, relax.op.nn.conv3d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=5))
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x1, w, data_layout="NCDHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=6),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x0, w, out_layout="NCDHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=6),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv3d(x2, w),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to