ciyongch commented on a change in pull request #16141: [mkldnn-v1.0] Add MKL-DNN Convolution
URL: https://github.com/apache/incubator-mxnet/pull/16141#discussion_r324465554
##########
File path: src/operator/nn/mkldnn/mkldnn_convolution-inl.h
##########
@@ -79,54 +79,63 @@ struct MKLDNNConvFullParam {
MKLDNNPostEltwiseParam postsum_act_param;
};
-mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
- const bool is_train,
- const NDArray &data,
- const NDArray &weights,
- const NDArray *bias,
- const NDArray &output);
+std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
+ const ConvolutionParam &param, const bool is_train, const NDArray &data, const NDArray &weight,
+ const NDArray *bias, const NDArray &output);
class MKLDNNConvForward {
public:
- mkldnn::convolution_forward::primitive_desc fwd_pd;
-
MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train,
const NDArray &data,
- const NDArray &weights, const NDArray *bias, const NDArray &output);
+ const NDArray &weight, const NDArray *bias, const NDArray &output);
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
- const mkldnn::memory *bias, const mkldnn::memory &output);
+ const mkldnn::convolution_forward &GetFwd() const { return *fwd_; }
- void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
- this->data_->set_data_handle(data.get_data_handle());
- this->out_->set_data_handle(output.get_data_handle());
- }
-
- const mkldnn::convolution_forward &GetFwd() const {
- return *fwd_;
- }
+ const mkldnn::convolution_forward::primitive_desc &GetPd() const { return *pd_; }
private:
std::shared_ptr<mkldnn::convolution_forward> fwd_;
- std::shared_ptr<mkldnn::memory> data_;
- std::shared_ptr<mkldnn::memory> weight_;
- std::shared_ptr<mkldnn::memory> bias_;
- std::shared_ptr<mkldnn::memory> out_;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> pd_;
};
typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
-MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
- const bool is_train, const NDArray &data,
- const NDArray &weights, const NDArray *bias,
- const NDArray &output);
-
void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
const OpContext &ctx,
MKLDNNConvForward *fwd,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);
+void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &out_data);
+
+class MKLDNNConvBackward {
+ public:
+ MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, const NDArray &weight,
+ const NDArray *bias, const NDArray &output);
+
+ const mkldnn::convolution_backward_data &GetBwdData() const { return *bwd_data_; }
+
+ const mkldnn::convolution_backward_weights &GetBwdWeights() const { return *bwd_weight_; }
+
+ const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const {
+ return *bwd_data_pd_;
+ }
+
+ const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() const {
+ return *bwd_weights_pd_;
+ }
+
+ private:
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> bwd_data_pd_;
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> bwd_weights_pd_;
+ std::shared_ptr<mkldnn::convolution_backward_data> bwd_data_;
+ std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight_;
+};
+
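Review comment:
Please use `weight` or `weights` consistently. For example, `bwd_weights_pd_`, `GetBwdWeights()` and `GetWeightsPd()` sit next to the member `bwd_weight_`, and the forward constructor parameter was renamed from `weights` to `weight`.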
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services