icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403299155
 
 

 ##########
 File path: src/relay/op/nn/convolution.cc
 ##########
 @@ -662,96 +454,101 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
         ConvInferCorrectLayout<Conv2DWinogradAttrs>);
 
 // relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
-                                      int num_inputs,
-                                      const Attrs& attrs,
-                                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  const auto* data = types[0].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-
-  const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
-  CHECK(param != nullptr);
-
-  CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
-  // each pad width element should be a pair of positive integers
-  std::vector<IndexExpr> oshape {
-      param->tile_size + data->shape[2] - 1,
-      param->tile_size + data->shape[3] - 1,
-      data->shape[0],
-      data->shape[1],
-  };
-
-  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
-                                                  data->dtype));
-  return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
-                                       int tile_size) {
-  auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
-  attrs->tile_size = tile_size;
-  static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
-  return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
 
 
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+                   int tile_size) {
+  return MakeConvWinogradWeightTransform(
+    weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
 
 RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+    .describe(R"code(Weight transformation of winograd fast convolution algorithm.
 
 Separate this into another operator in order to enable Precompute Pass to compute the
 weight transformation in advance.
 
 - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
 )code" TVM_ADD_FILELINE)
-.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
-.set_num_inputs(1)
-.add_argument("weight", "Tensor", "The weight tensor.")
-.set_support_level(10)
-.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
+    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
+    .set_num_inputs(1)
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_support_level(10)
+    .add_type_rel("Conv2DWinogradWeightTransform",
+                  Conv2DWinogradWeightTransformRel<ConvWinogradWeightTransformAttrs>);
+
+// relay.nn.contrib_conv3d_winograd_without_weight_transform
+TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
+.set_body_typed([](Expr data,
+                   Expr weight,
+                   int tile_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   Array<IndexExpr> dilation,
+                   int groups,
+                   IndexExpr channels,
+                   Array<IndexExpr> kernel_size,
+                   std::string data_layout,
+                   std::string kernel_layout,
+                   std::string out_layout,
+                   DataType out_dtype) {
+  return MakeConvWinograd<Conv3DWinogradAttrs>(
+    data, weight, tile_size, strides, padding, dilation,
+    groups, channels, kernel_size, data_layout,
+    kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
+});
+
+RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
+    .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
 
 Review comment:
   ditto

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to