icemelon9 commented on a change in pull request #5186: [Relay][Topi][AutoTVM] Winograd support for Conv3D
URL: https://github.com/apache/incubator-tvm/pull/5186#discussion_r403298871
##########
File path: src/relay/op/nn/convolution.cc
##########
@@ -662,96 +454,101 @@
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
ConvInferCorrectLayout<Conv2DWinogradAttrs>);
// relay.nn.contrib_conv2d_winograd_weight_transform
-TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
-
-bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
- int num_inputs,
- const Attrs& attrs,
- const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 2);
- const auto* data = types[0].as<TensorTypeNode>();
- if (data == nullptr) return false;
-
- const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
- CHECK(param != nullptr);
-
- CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
-
- // each pad width element should be a pair of positive integers
- std::vector<IndexExpr> oshape {
- param->tile_size + data->shape[2] - 1,
- param->tile_size + data->shape[3] - 1,
- data->shape[0],
- data->shape[1],
- };
-
- reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
- data->dtype));
- return true;
-}
-
-Expr MakeConv2DWinogradWeightTransform(Expr weight,
- int tile_size) {
- auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
- attrs->tile_size = tile_size;
- static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
- return Call(op, {weight}, Attrs(attrs), {});
-}
-
+TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
-.set_body_typed(MakeConv2DWinogradWeightTransform);
-
+.set_body_typed([](Expr weight,
+ int tile_size) {
+ return MakeConvWinogradWeightTransform(
+ weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
+});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
-.describe(R"code(Weight transformation of winograd fast convolution algorithm.
+ .describe(R"code(Weight transformation of winograd fast convolution algorithm.
Review comment:
ditto
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services