bartekkuncer commented on a change in pull request #20825:
URL: https://github.com/apache/incubator-mxnet/pull/20825#discussion_r787896195
##########
File path: src/operator/nn/dnnl/dnnl_pooling.cc
##########
@@ -233,17 +234,30 @@ void InitPoolingPrimitiveParams(const PoolingParam& param,
 dnnl::pooling_forward::primitive_desc GetPoolingFwdPdesc(const PoolingParam& param,
                                                          const bool is_train,
                                                          const dnnl::memory::desc& data_md,
-                                                         const dnnl::memory::desc& out_md) {
-  CHECK(param.kernel.ndim() == 1 || param.kernel.ndim() == 2 || param.kernel.ndim() == 3)
-      << "Not Implemented";
+                                                         const dnnl::memory::desc& out_md,
+                                                         const bool use_adaptive_pooling) {
+  CHECK((param.kernel.ndim() >= 1 && param.kernel.ndim() <= 3) || use_adaptive_pooling)
+      << "Not Implemented";  // to be changed
Review comment:
to be changed when?
##########
File path: src/operator/nn/dnnl/dnnl_pooling.cc
##########
@@ -384,22 +400,40 @@ DNNLPoolingBwd& GetPoolingBwd(const PoolingParam& param,
   return it->second;
 }
-void DNNLPoolingGradCompute(const OpContext& ctx,
-                            const PoolingParam& param,
-                            const NDArray& out_grad,
-                            const NDArray& in_data,
-                            const NDArray* workspace,
-                            const OpReqType req,
-                            const NDArray& in_grad) {
-  if (req == kNullOp) {
+void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  if (req[0] == kNullOp) {
     return;
   }
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+
+  const NDArray& out_grad = inputs[0];
+  const NDArray* workspace = nullptr;
+  const NDArray* in_data = nullptr;
+  if (DNNLRequireWorkspace(param)) {
+    // The first two elements are the gradient of the outputs in forward.
Review comment:
```suggestion
    // The first two elements are the gradients of the outputs in forward.
```
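For readers following the signature change in this hunk: the new `DNNLPoolingGradCompute` takes the packed `inputs`/`req`/`outputs` vectors instead of named arrays, so the gradient, input data, and workspace all have to be unpacked by position. Below is a minimal sketch of the unpacking this hunk leads into; the concrete `inputs` indices are assumptions for illustration, not the PR's verified layout:

```cpp
// Hypothetical continuation of the hunk above. The exact positions of
// in_data and workspace inside `inputs` are assumed, not taken from the PR.
if (DNNLRequireWorkspace(param)) {
  // Assumed layout: [out_grad, ws_grad, in_data, out_data, workspace].
  in_data   = &inputs[2];
  workspace = &inputs[4];
} else {
  // Assumed layout: [out_grad, in_data, out_data].
  in_data = &inputs[1];
}
// From here the backward primitive would be fetched and executed, e.g.:
// DNNLPoolingBwd& bwd = GetPoolingBwd(param, *in_data, out_grad, ...);
```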
##########
File path: src/operator/nn/dnnl/dnnl_pooling-inl.h
##########
@@ -28,10 +28,13 @@
 #include <dnnl.hpp>
 #include <utility>
+#include <vector>
 #include "../pooling-inl.h"
 #include "./dnnl_base-inl.h"
+#define DIV_ROUND_UP(a, b) (((a) + (b - 1)) / b)
Review comment:
Does `a` have to be in brackets?
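For context on why the brackets around `a` matter: without them, an argument containing a lower-precedence operator mis-expands. A standalone sketch (not from the PR):

```cpp
#include <cstdio>

#define DIV_ROUND_UP_BAD(a, b) ((a + (b - 1)) / b)    // `a` left bare
#define DIV_ROUND_UP(a, b)     (((a) + (b - 1)) / b)  // as in the PR

int main() {
  // 1 << 4 is 16, so ceil(16 / 8) should be 2.
  // BAD expands to ((1 << 4 + 7) / 8) == (1 << 11) / 8 == 256.
  printf("%d\n", DIV_ROUND_UP_BAD(1 << 4, 8));  // 256 (wrong)
  printf("%d\n", DIV_ROUND_UP(1 << 4, 8));      // 2 (right)
  return 0;
}
```

By the same logic, the trailing bare `/ b` mis-expands for an argument like `y + 1`, so the fully defensive form would be `(((a) + ((b) - 1)) / (b))`.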
##########
File path: src/operator/nn/dnnl/dnnl_pooling.cc
##########
@@ -357,21 +364,30 @@ DNNLPoolingBwd& GetPoolingBwd(const PoolingParam& param,
   auto dst_md = dnnl::memory::desc(dst_dims, get_data_type(data_md), any);
   // fwd hint
-  auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md);
+  auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md, use_adaptive_pooling);
   // creat bwd desc
Review comment:
create?
##########
File path: src/operator/contrib/adaptive_avg_pooling.cc
##########
@@ -25,11 +25,13 @@
 // #include "elemwise_op_common.h"
 #include "../elemwise_op_common.h"
 #if MXNET_USE_ONEDNN == 1
+#include "../nn/dnnl/dnnl_base-inl.h"
 #include "../nn/dnnl/dnnl_pooling-inl.h"
 #endif  // MXNET_USE_ONEDNN
 #define START_IND(a, b, c) static_cast<int>(std::floor(static_cast<float>(a * c) / b))
 #define END_IND(a, b, c) static_cast<int>(std::ceil(static_cast<float>((a + 1) * c) / b))
+#define DIV_ROUND_UP(a, b) (((a) + (b - 1)) / b)
Review comment:
Does `a` have to be in brackets here as well?
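For context on what `START_IND`/`END_IND` compute: adaptive average pooling maps each output cell to a half-open window of input cells, and these macros are that mapping. A standalone sketch of the same arithmetic (illustrative only, mirroring the definitions above):

```cpp
#include <cmath>
#include <cstdio>

// Same arithmetic as START_IND/END_IND: output cell `a` of `b` total
// output cells covers input indices [start, end) out of `c` input cells.
static int start_ind(int a, int b, int c) {
  return static_cast<int>(std::floor(static_cast<float>(a * c) / b));
}
static int end_ind(int a, int b, int c) {
  return static_cast<int>(std::ceil(static_cast<float>((a + 1) * c) / b));
}

int main() {
  const int in_size = 10, out_size = 4;
  for (int a = 0; a < out_size; ++a) {
    // 10 -> 4 yields overlapping windows [0,3) [2,5) [5,8) [7,10).
    printf("out %d <- in [%d, %d)\n",
           a, start_ind(a, out_size, in_size), end_ind(a, out_size, in_size));
  }
  return 0;
}
```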