csullivan commented on a change in pull request #7515:
URL: https://github.com/apache/tvm/pull/7515#discussion_r583864932
##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -82,6 +82,119 @@ class SimplifyReshape : public SimplifyPattern {
DFPattern x_;
};
+/*!
+ * \brief SimplifyConvPad matches an nn.pad followed by an operator that has its own
+ * padding attribute (currently nn.conv1d, nn.conv2d and nn.conv3d) and merges the
+ * explicit padding into that operator's padding attribute.
+ */
+class SimplifyConvPad : public SimplifyPattern {
+ public:
+  SimplifyConvPad() {
+    // Leaf wildcards: the tensor being padded and the convolution weight.
+    x_ = IsWildcard();
+    w_ = IsWildcard();
+
+    // Convolution operators that may consume the padded tensor.
+    conv1d_ = IsOp("nn.conv1d");
+    conv2d_ = IsOp("nn.conv2d");
+    conv3d_ = IsOp("nn.conv3d");
+
+    // nn.pad(x) used directly as the data argument of any of the convolutions above;
+    // the resulting pad -> conv pattern is exposed as this simplifier's pattern_.
+    pad_ = IsOp("nn.pad")({x_});
+    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    pattern_ = conv_;
+  }
+  /*!
+   * \brief Build convolution attributes that absorb an explicit pad.
+   *
+   * Creates a fresh attribute node of type T and copies every field of
+   * \p old_attrs into it, except `padding`, which becomes the element-wise
+   * sum of \p padding and the convolution's existing padding.
+   *
+   * \tparam T Attribute node type of the matched convolution; it must expose
+   *         strides, padding, dilation, groups, channels, kernel_size,
+   *         data_layout, kernel_layout, out_layout and out_dtype.
+   * \param old_attrs Attributes of the matched convolution; must be non-null.
+   * \param padding Padding taken from the preceding pad, with the same length
+   *        and element order as old_attrs->padding.
+   * \return New Attrs identical to \p old_attrs except for the combined padding.
+   */
+  template <typename T>
+  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr>& padding) const {
+    ICHECK(old_attrs);
+    // The paddings are summed index by index below, so both must describe the
+    // same number of edges; ICHECK_EQ reports both sizes on failure.
+    ICHECK_EQ(padding.size(), old_attrs->padding.size())
+        << "Number of dimensions to pad and convolution padding attributes should have the same "
+           "extent";
+
+    auto new_attrs = make_object<T>();
+    Array<PrimExpr> combined_padding;
+    for (size_t i = 0; i < padding.size(); ++i) {
+      combined_padding.push_back(padding[i] + old_attrs->padding[i]);
+    }
+    // Only padding changes; every other convolution attribute is carried over verbatim.
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->padding = combined_padding;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->groups = old_attrs->groups;
+    new_attrs->channels = old_attrs->channels;
+    new_attrs->kernel_size = old_attrs->kernel_size;
+    new_attrs->data_layout = old_attrs->data_layout;
+    new_attrs->kernel_layout = old_attrs->kernel_layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->out_dtype = old_attrs->out_dtype;
+    return Attrs(new_attrs);
+  }
+ template <typename T>
+ Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+ ICHECK(param);
+ ICHECK(attrs);
+ ICHECK(attrs->data_layout.size() == param->pad_width.size())
+ << "Data Layout and padding attributes should have the same extent";
+
+ std::string data_layout = attrs->data_layout;
+ std::set<char> image_dims({'H', 'W', 'D'});
+ Array<PrimExpr> padding;
Review comment:
Nice. Consider adding a comment here explaining that we bail out when padding is
applied to any dimension other than the spatial ones.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]