This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch v1.2.0 in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 9e97a96fab3bc6c6174c9541a69bf1accd95f731 Author: Da Zheng <zhengda1...@gmail.com> AuthorDate: Fri Apr 27 10:35:12 2018 -0700 invalidate MKLDNN memory for reused NDArrays. (#10706) * Revert "Revert "invalidate outputs for imperative."" This reverts commit b428937968adf177e0260361c972e502e839edb5. * invalidate mkldnn memory. * enable test. --- src/executor/attach_op_execs_pass.cc | 9 +++++++++ src/imperative/imperative_utils.h | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index e4d4955..3c8fb83 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -113,6 +113,9 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(out_array, req); +#endif PreFCompute(is_gpu); fcompute_(state_, op_ctx, in_data_, req, out_data_); PostFCompute(is_gpu); @@ -146,6 +149,9 @@ class StatefulComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(out_array, req); +#endif fcompute_(state_, op_ctx, in_array, req, out_array); } @@ -178,6 +184,9 @@ class FComputeExecutor : public StorageFallbackOpExecutor { void Run(RunContext rctx, bool is_gpu) override { using namespace common; op_ctx.run_ctx = rctx; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(out_array, req); +#endif PreFCompute(is_gpu); fcompute_(attrs_, op_ctx, in_data_, req, out_data_); PostFCompute(is_gpu); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 0d6525d..86683f9 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -29,6 +29,7 @@ #include "../c_api/c_api_common.h" #include "../common/utils.h" #include "../common/exec_utils.h" +#include "../operator/nn/mkldnn/mkldnn_base-inl.h" #ifndef MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_ #define MXNET_IMPERATIVE_IMPERATIVE_UTILS_H_ @@ -365,6 +366,9 @@ inline void PushFCompute(const FCompute& fn, std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; // mapping from index in input_blobs to index in pre_temp_dst std::unordered_map<uint32_t, uint32_t> in_temp_idx_map; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(outputs, req); +#endif // setup blobs SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr, &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, @@ -402,6 +406,9 @@ inline void PushFComputeEx(const FComputeEx& fn, DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(outputs, req); +#endif fn(attrs, opctx, inputs, req, outputs); if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) { rctx.get_stream<gpu>()->Wait(); @@ -445,6 +452,9 @@ inline void PushOperator(const OpStatePtr& state, const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { OpContext opctx{is_train, rctx, on_complete, requested}; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(outputs, req); +#endif fcompute_ex(state, opctx, inputs, req, outputs); if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) { rctx.get_stream<gpu>()->Wait(); @@ -475,6 +485,9 @@ inline void PushOperator(const OpStatePtr& state, std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; // mapping from index in input_blobs to index in pre_temp_dst std::unordered_map<uint32_t, uint32_t> in_temp_idx_map; +#if MXNET_USE_MKLDNN == 1 + InvalidateOutputs(outputs, req); +#endif // populate input blobs and output blobs SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr, &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, -- To stop receiving notification emails like this one, please contact anirudh2...@apache.org.