zhreshold commented on a change in pull request #11502: [MXNET-614] Adding Synchronized Batch Normalization
URL: https://github.com/apache/incubator-mxnet/pull/11502#discussion_r201927974
 
 

 ##########
 File path: src/operator/contrib/sync_batch_norm-inl.h
 ##########
 @@ -0,0 +1,592 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file sync_batch_norm-inl.h
+ * \brief Synchronized BatchNorm modified from BatchNormV1
+ * \author Hang Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_
+#define MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_
+
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "../operator_common.h"
#include "../mshadow_op.h"
+
+namespace mxnet {
+namespace op {
+
namespace syncbatchnorm {
// Positions of the operator inputs in in_data / in_grad.
enum BatchNormOpInputs {
  kData = 0,
  kGamma = 1,
  kBeta = 2
};
// Positions of the operator outputs in out_data.
enum BatchNormOpOutputs {
  kOut = 0,
  kMean = 1,
  kVar = 2
};
// Positions of the auxiliary states in aux_states.
enum BatchNormOpAuxiliary {
  kMovingMean = 0,
  kMovingVar = 1
};
// Index of the temporary-space resource in OpContext::requested.
enum BatchNormBackResource {
  kTempSpace = 0
};
}  // namespace syncbatchnorm
+
+struct SyncBatchNormParam : public dmlc::Parameter<SyncBatchNormParam> {
+  float eps;
+  float momentum;
+  bool fix_gamma;
+  bool use_global_stats;
+  bool output_mean_var;
+  int ndev;
+  std::string key;
+  DMLC_DECLARE_PARAMETER(SyncBatchNormParam) {
+    DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
+    .describe("Epsilon to prevent div 0");
+    DMLC_DECLARE_FIELD(momentum).set_default(0.9f)
+    .describe("Momentum for moving average");
+    DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
+    .describe("Fix gamma while training");
+    DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
+    .describe("Whether use global moving statistics instead of local 
batch-norm. "
+              "This will force change batch-norm into a scale shift 
operator.");
+    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
+    .describe("Output All,normal mean and var");
+    DMLC_DECLARE_FIELD(ndev).set_default(1)
+      .describe("The count of GPU devices");
+    DMLC_DECLARE_FIELD(key)
+      .set_default("")
+      .describe("Hash key for synchronization, please set the same hash key 
for same layer, "
+                "Block.prefix is typically used as in 
:class:`gluon.nn.contrib.SyncBatchNorm`.");
+  }
+};
+
+// Modified from https://github.com/brucechin/SharedTensor
+template<class T>
+class SharedND {
+ private:
+  int num_devices_;
+  T mean_;
+  T *data_;
+  bool *flag_;
+  bool mean_ready_ = false;
+  bool data_inited_ = false;
+  std::mutex mutex_;
+
+ public:
+  explicit SharedND(int ndev) :num_devices_(ndev) {
+    flag_ = new bool[ndev];
+    data_ = new T[ndev];
+    memset(flag_, false, ndev * sizeof(bool));
+  }
+
+  ~SharedND() {
+    mshadow::FreeSpace(&mean_);
+    delete [] flag_;
+    delete [] data_;
+  }
+
+  void Init(mshadow::Shape<1> shape) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!data_inited_) {
+      for (int i = 0; i < num_devices_; i++) {
+        data_[i] = mshadow::NewTensor<cpu, real_t>(shape, 0.0f);
+      }
+      mean_ = mshadow::NewTensor<cpu, real_t>(shape, 0.0f);
+      data_inited_ = true;
+    }
+  }
+
+  T* Retrieve(mshadow::Shape<1> shape, int index) {
+    if (!data_inited_) {
+      Init(shape);
+    }
+    if (flag_[index] == false) {
+      return &data_[index];
+    } else {
+      return nullptr;
+    }
+  }
+
+  bool SetReady(int index) {
+    if (flag_[index] == false) {
+      flag_[index] = true;
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  T Pop(int index) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    while (!MeanReady()) {}
+    flag_[index] = false;
+    T tmp = mean_;
+    ResetMean();
+    return tmp;
+  }
+
+  bool MeanReady() {
+    if (mean_ready_) {
+      return true;
+    }
+    for (int i = 0; i < num_devices_; i++) {
+      if (!flag_[i]) {
+        return false;
+      }
+    }
+    for (int i = 1; i < num_devices_; i++) {
+      data_[0] += data_[i];
+    }
+    mean_ = data_[0] * 1.0f /  num_devices_;
+    mean_ready_ = true;
+    return true;
+  }
+
+  void ResetMean() {
+    for (int i = 0; i < num_devices_; i++) {
+      if (flag_[i]) return;
+    }
+    mean_ready_ = false;
+  }
+};
+
/*!
 * \brief Thread-safe registry that maps a key to a lazily created, shared T.
 *
 * Every caller passing the same key gets the same object. The first caller
 * constructs it as T(ndev); later ndev values for that key are ignored.
 * Objects live as long as the registry.
 */
template<class T>
class GlobalShared {
 public:
  T* Register(const std::string &key, int ndev) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::unique_ptr<T> &slot = registry_[key];
    if (!slot) {
      slot.reset(new T(ndev));
    }
    return slot.get();
  }

 private:
  std::mutex mutex_;
  // unique_ptr owns each entry, so no hand-written destructor is needed.
  std::map<std::string, std::unique_ptr<T>> registry_;
};
+
/*!
 * \brief Thread-safe per-key round-robin counter.
 *
 * For each key, successive calls return 0, 1, ..., ndev - 1, 0, 1, ...
 * Provided each step makes exactly ndev calls, every caller within a step
 * gets a distinct rank.
 */
template<class T>
class GlobalSharedRank {
 public:
  T Register(const std::string &key, int ndev) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = registry_.find(key);
    if (it == registry_.end()) {
      registry_.emplace(key, T(0));  // first caller for this key gets rank 0
      return T(0);
    }
    T &rank = it->second;
    rank = (rank == ndev - 1) ? T(0) : rank + 1;
    return rank;
  }

 private:
  std::mutex mutex_;
  // Stored by value: no per-key heap allocation and nothing to delete.
  std::map<std::string, T> registry_;
};
+
/*!
 * \brief Reusable barrier for a fixed number of threads.
 *
 * Wait() blocks until `count` threads have called it, then releases them all
 * and re-arms for the next round. Waiters watch a generation counter rather
 * than count_, so a released thread that immediately re-enters Wait() cannot
 * hide the release from threads that have not woken up yet.
 */
class Barrier {
 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::size_t count_;
  std::size_t total_count_;
  std::size_t generation_ = 0;

 public:
  explicit Barrier(std::size_t count) : count_{count}, total_count_{count} { }

  void Wait() {
    std::unique_lock<std::mutex> lock{mutex_};
    const std::size_t arrival_generation = generation_;
    if (--count_ == 0) {
      // Last arrival: open a new generation and wake everyone.
      ++generation_;
      count_ = total_count_;
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this, arrival_generation] {
        return generation_ != arrival_generation;
      });
    }
  }
};
+
+// Global variables for Synchronizations
+static GlobalSharedRank<int> global_shared_rank_forward;
+static GlobalSharedRank<int> global_shared_rank_backward;
+static GlobalShared<Barrier> global_shared_barrier_forward;
+static GlobalShared<Barrier> global_shared_barrier_backward;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
global_shared_mean;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
global_shared_var;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
global_shared_grad;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
global_shared_prod;
+
+template<typename xpu>
+class SyncBatchNorm : public Operator {
+ public:
+  explicit SyncBatchNorm(SyncBatchNormParam param) {
+    this->param_ = param;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 3U);
+    CHECK_EQ(aux_states.size(), 2U);
+    if (ctx.is_train) {
+      CHECK_EQ(out_data.size(), 3U);
+      CHECK_EQ(req.size(), 3U);
+    } else {
+      CHECK_GE(out_data.size(), 1U);
+      CHECK_GE(req.size(), 1U);
+      CHECK_EQ(req[syncbatchnorm::kOut], kWriteTo);
+    }
+
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    const real_t scale = 
static_cast<real_t>(in_data[syncbatchnorm::kData].shape_[1]) /
+      static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
+    Tensor<xpu, 4> data;
+    Tensor<xpu, 4> out;
+    if (in_data[syncbatchnorm::kData].ndim() == 2) {
+      Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
+                               in_data[syncbatchnorm::kData].shape_[1], 1, 1);
+      data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+      out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+    } else {
+      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+      out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+    }
+    Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, 
real_t>(s);
+    Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> moving_mean = 
aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 
1, real_t>(s);
+
+    if (param_.fix_gamma) slope = 1.f;
+
+    // whether use global statistics
+    if (ctx.is_train && !param_.use_global_stats) {
+      // get my rank
+      Barrier *global_barrier = 
global_shared_barrier_forward.Register(param_.key, param_.ndev);
+      int myRank = global_shared_rank_forward.Register(param_.key, 
param_.ndev);
+      // get the mean and var
+      Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, 
real_t>(s);
+      Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, 
real_t>(s);
+      CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] 
== kWriteTo);
+      CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == 
kWriteTo);
+      // E(x) and E(x^2)
+      mean = scale * sumall_except_dim<1>(data);
+      var = scale * sumall_except_dim<1>(F<mshadow_op::square>(data));
+      SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedMean =
+        global_shared_mean.Register(param_.key, param_.ndev);
+      SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedVar =
+        global_shared_var.Register(param_.key, param_.ndev);
+      // copy to cpu, push and pull
+      Tensor<cpu, 1, real_t>* mean_cpu_ptr = sharedMean->Retrieve(mean.shape_, 
myRank);
+      Tensor<cpu, 1, real_t>* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, 
myRank);
+      mshadow::Copy(*mean_cpu_ptr, mean, s);
+      mshadow::Copy(*var_cpu_ptr, var, s);
+      sharedMean->SetReady(myRank);
+      sharedVar->SetReady(myRank);
+      global_barrier->Wait();
+      Tensor<cpu, 1, real_t> mean_cpu = sharedMean->Pop(myRank);
+      Tensor<cpu, 1, real_t> var_cpu = sharedVar->Pop(myRank);
+      // copy back to gpu
+      mshadow::Copy(mean, mean_cpu, s);
+      mshadow::Copy(var, var_cpu, s);
+
+      var = var-F<mshadow_op::square>(mean);
+      Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope, out.shape_) *
+             (data - broadcast<1>(mean, data.shape_)) /
+             F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, 
data.shape_)) +
+             broadcast<1>(bias, out.shape_));
+    } else {
+      Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope /
+                                          
F<mshadow_op::square_root>(moving_var + param_.eps),
+                                          data.shape_) * data +
+             broadcast<1>(bias - (slope * moving_mean) /
+                          F<mshadow_op::square_root>(moving_var + param_.eps), 
data.shape_));
+    }
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U);
+    CHECK_EQ(in_data.size(), 3U);
+    CHECK_EQ(out_data.size(), 3U);
+    CHECK_EQ(in_grad.size(), 3U);
+
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> data, grad, grad_in;
+    const real_t scale = 
static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
+      static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
+    if (in_data[syncbatchnorm::kData].ndim() == 2) {
+      Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
+                               out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
+      data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+      grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+      grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+    } else {
+      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+      grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+      grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+    }
+
+    Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, 
real_t>(s);
+    Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, 
real_t>(s);
+    // Tensor<xpu, 1> bias = in_data[kBeta].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> gslope = in_grad[syncbatchnorm::kGamma].get<xpu, 1, 
real_t>(s);
+    Tensor<xpu, 1> gbias = in_grad[syncbatchnorm::kBeta].get<xpu, 1, 
real_t>(s);
+    // update moving avg
+    Tensor<xpu, 1> moving_mean = 
aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 
1, real_t>(s);
+
+    if (param_.fix_gamma) slope = 1.f;
+
+    if (ctx.is_train && !param_.use_global_stats) {
+      // get my rank
+      Barrier *global_barrier = 
global_shared_barrier_backward.Register(param_.key, param_.ndev);
+      int myRank = global_shared_rank_backward.Register(param_.key, 
param_.ndev);
+      // get requested temp space
+      Tensor<xpu, 2> workspace = 
ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
+          mshadow::Shape2(5, mean.shape_[0]), s);
+      Tensor<xpu, 1> gmean = workspace[0];
+      Tensor<xpu, 1> gvar = workspace[1];
+      // Tensor<xpu, 1> tmp = workspace[2];
 
 Review comment:
   Remove this unused, commented-out `tmp` declaration.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to