masahi commented on a change in pull request #5186: [Relay][Topi][AutoTVM]
Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403412011
##########
File path: src/relay/op/nn/convolution.cc
##########
@@ -198,141 +302,32 @@ with the layer input to produce a tensor of outputs.
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ConvInferCorrectLayout<Conv3DAttrs>);
+
// relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
-bool Conv2DTransposeRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 3);
- const auto* data = types[0].as<TensorTypeNode>();
- const auto* weight = types[1].as<TensorTypeNode>();
- if (data == nullptr) return false;
-
- static const Layout kNCHW("NCHW");
- static const Layout kOIHW("OIHW");
-
- const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
- CHECK(param != nullptr);
- const Layout in_layout(param->data_layout);
- const Layout kernel_layout(param->kernel_layout);
-
- const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
- CHECK(trans_in_layout.defined())
- << "Conv only support input layouts that are convertible from NCHW."
- << " But got " << in_layout;
-
- const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
- CHECK(trans_kernel_layout.defined())
- << "Conv only support kernel layouts that are convertible from OIHW."
- << " But got "<< kernel_layout;
-
- Layout out_layout(param->out_layout == "" ? param->data_layout :
param->out_layout);
- const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
- CHECK(trans_out_layout.defined())
- << "Conv only support output layouts that are convertible from NCHW."
- << " But got " << out_layout;
-
- IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-
- auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
- // infer weight if the kernel_size and channels are defined
- if (param->kernel_size.defined() && param->channels.defined()) {
- CHECK_EQ(param->kernel_size.size(), 2);
- CHECK_EQ(param->dilation.size(), 2);
-
- Array<IndexExpr> wshape({dshape_nchw[1],
- indexdiv(param->channels, param->groups),
- param->kernel_size[0],
- param->kernel_size[1]});
-
- wshape = trans_kernel_layout.BackwardShape(wshape);
- dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
- dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
- channels = param->channels;
-
- // assign result to reporter
- reporter->Assign(types[1], TensorType(wshape, data->dtype));
- } else {
- // use weight to infer the conv shape.
- if (weight == nullptr) return false;
- auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
- if (param->kernel_size.defined()) {
- CHECK_EQ(param->kernel_size.size(), 2);
- // check the size
- CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
- reporter->AssertEQ(param->kernel_size[1], wshape[3]))
- << "Conv2D: shape of weight is inconsistent with kernel_size, "
- << " kernel_size=" << param->kernel_size
- << " wshape=" << Array<IndexExpr>(wshape);
- }
- if (param->channels.defined()) {
- CHECK(reporter->AssertEQ(param->channels, wshape[1]))
- << "Conv2D: shape of weight is inconsistent with channels, "
- << " channels=" << param->channels
- << " wshape=" << Array<IndexExpr>(wshape);
- }
- CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups),
wshape[0]));
- channels = wshape[1];
- dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
- dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
- }
- // dilation
- Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
- IndexExpr pad_h, pad_w;
- GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
- pad_h + param->output_padding[0]));
- oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
- pad_w + param->output_padding[1]));
-
- DataType out_dtype = param->out_dtype;
- if (out_dtype.bits() == 0) {
- out_dtype = data->dtype;
- }
- oshape = trans_out_layout.BackwardShape(oshape);
- reporter->Assign(types[2], TensorType(oshape, out_dtype));
- return true;
-}
-
-
-Expr MakeConv2DTranspose(Expr data,
- Expr weight,
- Array<IndexExpr> strides,
- Array<IndexExpr> padding,
- Array<IndexExpr> dilation,
- int groups,
- IndexExpr channels,
- Array<IndexExpr> kernel_size,
- std::string data_layout,
- std::string kernel_layout,
- std::string out_layout,
- Array<IndexExpr> output_padding,
- DataType out_dtype) {
- auto attrs = make_object<Conv2DTransposeAttrs>();
- attrs->channels = std::move(channels);
- attrs->kernel_size = std::move(kernel_size);
- attrs->strides = std::move(strides);
- attrs->padding = std::move(padding);
- attrs->output_padding = std::move(output_padding);
- attrs->dilation = std::move(dilation);
- attrs->groups = groups;
- attrs->data_layout = std::move(data_layout);
- attrs->kernel_layout = std::move(kernel_layout);
- attrs->out_layout = std::move(out_layout);
- attrs->out_dtype = std::move(out_dtype);
- static const Op& op = Op::Get("nn.conv2d_transpose");
- return Call(op, {data, weight}, Attrs(attrs), {});
-}
-
-
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
-.set_body_typed(MakeConv2DTranspose);
+.set_body_typed([](Expr data,
+ Expr weight,
+ Array<IndexExpr> strides,
+ Array<IndexExpr> padding,
+ Array<IndexExpr> dilation,
+ int groups,
+ IndexExpr channels,
+ Array<IndexExpr> kernel_size,
+ std::string data_layout,
+ std::string kernel_layout,
+ std::string out_layout,
+ Array<IndexExpr> output_padding,
+ DataType out_dtype) {
+ return MakeConvTranspose<Conv2DTransposeAttrs>(
+ data, weight, strides, padding, dilation,
+ groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, output_padding, out_dtype,
"nn.conv2d_transpose");
+});
RELAY_REGISTER_OP("nn.conv2d_transpose")
-.describe(R"code(Transposed 2D convolution layer (sometimes called
Deconvolution).
+ .describe(R"code(Transposed 2D convolution layer (sometimes called
Deconvolution).
Review comment:
nit: indentation — the new `.describe(R"code(...)")` line is indented, but the other `RELAY_REGISTER_OP(...)` attribute chains in this file start flush-left; please remove the leading indent to match.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services