zhanghang1989 closed pull request #11785: [MXNET-683] [WIP] Adding Inplace 
Activated Batch Normalization
URL: https://github.com/apache/incubator-mxnet/pull/11785
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py 
b/python/mxnet/gluon/contrib/nn/basic_layers.py
index c656cd2d4e1..b3685ac507b 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -19,7 +19,7 @@
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
 __all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
-           'SyncBatchNorm']
+           'SyncBatchNorm', 'InplaceABN']
 
 import warnings
 from .... import nd, test_utils
@@ -235,3 +235,81 @@ def _get_num_devices(self):
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, 
running_var,
                                        name='fwd', **self._kwargs)
+
+
+class InplaceABN(BatchNorm):
+    """Inplace Activated Batch normalization (InplaceABN)
+
+    Inplace ABN acts the same as standard BatchNorm with LeakyReLU activation.
+    It saves memory by recomputing feature maps in the backward pass instead of storing them.
+    Parameters
+    ----------
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+    sync : bool, default False
+        Synchronizing across GPUs, see 
:class:`mxnet.gluon.contrib.nn.SyncBatchNorm`
+        for detail.
+    num_devices : int, default number of visible GPUs
+    slope : float, default 0.01
+        slope for LeakyReLU.
+    momentum: float, default 0.9
+        Momentum for the moving average.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    use_global_stats: bool, default False
+        If True, use global moving statistics instead of local batch-norm. 
This will force
+        change batch-norm into a scale shift operator.
+        If False, use local batch-norm.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    running_mean_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the moving mean.
+    running_variance_initializer: str or `Initializer`, default 'ones'
+        Initializer for the moving variance.
+
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    Reference:
+        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: 
Accelerating \
+          deep network training by reducing internal covariate shift." *ICML 
2015*
+        .. [2] Bulò, Samuel Rota, Lorenzo Porzi, and Peter Kontschieder. \
+          In-Place Activated BatchNorm for Memory-Optimized Training of DNNs 
CVPR 2018
+
+    """
+    def __init__(self, in_channels=0, sync=False, num_devices=None, slope=0.01,
+                 momentum=0.9, epsilon=1e-5, center=True, 
use_global_stats=False,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 running_mean_initializer='zeros',
+                 running_variance_initializer='ones', **kwargs):
+        super(InplaceABN, self).__init__(1, momentum, epsilon, center, True, 
use_global_stats,
+                                         beta_initializer, gamma_initializer,
+                                         running_mean_initializer, 
running_variance_initializer,
+                                         in_channels, **kwargs)
+        num_devices = 1 if not sync else self._get_num_devices() \
+            if num_devices is None else num_devices
+        self._kwargs = {'eps': epsilon, 'momentum': momentum,
+                        'use_global_stats': use_global_stats, 'sync' : sync,
+                        'ndev': num_devices, 'slope' : slope, 'key': 
self.prefix}
+
+    def _get_num_devices(self):
+        warnings.warn("Caution using InplaceABN: "
+                      "if not using all the GPUs, please mannually set 
num_devices",
+                      UserWarning)
+        num_devices = len(test_utils.list_gpus())
+        num_devices = num_devices if num_devices > 0 else 1
+        return num_devices
+
+    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
+        return F.contrib.InplaceABN(x, gamma, beta, running_mean, running_var,
+                                    name='fwd', **self._kwargs)
diff --git a/src/operator/contrib/inplace_abn-inl.h 
b/src/operator/contrib/inplace_abn-inl.h
new file mode 100644
index 00000000000..5a86b92b348
--- /dev/null
+++ b/src/operator/contrib/inplace_abn-inl.h
@@ -0,0 +1,503 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file inplace_abn-inl.h
+ * \brief Inplace Activated BatchNorm
+ * modified from sync_batch_norm-inl.h
+ * \author Hang Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_INPLACE_ABN_INL_H_
+#define MXNET_OPERATOR_CONTRIB_INPLACE_ABN_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <condition_variable>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include "../operator_common.h"
+#include "../mshadow_op.h"
+#include "sync_batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+namespace inplaceabn {
+enum InplaceABNInputs {kData, kGamma, kBeta};
+enum InplaceABNOutputs {kOut, kMean, kVar};
+enum InplaceABNAuxiliary {kMovingMean, kMovingVar};
+enum BatchNormBackResource {kTempSpace};
+}  // namespace inplaceabn
+
+struct InplaceABNParam : public dmlc::Parameter<InplaceABNParam> {
+  float eps;
+  float momentum;
+  float slope;
+  bool use_global_stats;
+  bool output_mean_var;
+  bool sync;
+  int ndev;
+  std::string key;
+  DMLC_DECLARE_PARAMETER(InplaceABNParam) {
+    DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
+    .describe("Epsilon to prevent div 0");
+    DMLC_DECLARE_FIELD(momentum).set_default(0.9f)
+    .describe("Momentum for moving average");
+    DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
+    .describe("Whether use global moving statistics instead of local 
batch-norm. "
+              "This will force change batch-norm into a scale shift 
operator.");
+    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
+    .describe("Output All,normal mean and var");
+    DMLC_DECLARE_FIELD(slope).set_default(0.01f)
+    .describe("Init slope for the activation. (For leaky and elu only)");
+    DMLC_DECLARE_FIELD(sync).set_default(false)
+    .describe("Syncrhonize the BatchNorm using global mean and var.");
+    DMLC_DECLARE_FIELD(ndev).set_default(1)
+    .describe("The count of GPU devices");
+    DMLC_DECLARE_FIELD(key)
+      .set_default("")
+      .describe("Hash key for synchronization");
+  }
+};
+
+// Global variables for Synchronizations
+static GlobalSharedRank<int> inpabn_global_shared_rank_forward;
+static GlobalSharedRank<int> inpabn_global_shared_rank_backward;
+static GlobalShared<Barrier> inpabn_global_shared_barrier_forward;
+static GlobalShared<Barrier> inpabn_global_shared_barrier_backward;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
inp_abn_global_shared_mean;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
inpabn_global_shared_var;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
inpabn_global_shared_grad;
+static GlobalShared<SharedND<mshadow::Tensor<cpu, 1, real_t>>> 
inpabn_global_shared_prod;
+
+template<typename xpu>
+class InplaceABN : public Operator {
+ public:
+  explicit InplaceABN(InplaceABNParam param) {
+    this->param_ = param;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 3U);
+    CHECK_EQ(aux_states.size(), 2U);
+    if (ctx.is_train) {
+      CHECK_EQ(out_data.size(), 3U);
+      CHECK_EQ(req.size(), 3U);
+      // CHECK_EQ(req[inplaceabn::kOut], kWriteInplace);
+    } else {
+      CHECK_GE(out_data.size(), 1U);
+      CHECK_GE(req.size(), 1U);
+      // CHECK_EQ(req[inplaceabn::kOut], kWriteInplace);
+    }
+
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    const real_t scale = 
static_cast<real_t>(in_data[inplaceabn::kData].shape_[1]) /
+      static_cast<real_t>(in_data[inplaceabn::kData].shape_.Size());
+    Tensor<xpu, 4> data;
+    Tensor<xpu, 4> out;
+    if (in_data[inplaceabn::kData].ndim() == 2) {
+      Shape<4> dshape = Shape4(in_data[inplaceabn::kData].shape_[0],
+                               in_data[inplaceabn::kData].shape_[1], 1, 1);
+      data = in_data[inplaceabn::kData].get_with_shape<xpu, 4, real_t>(dshape, 
s);
+      out = out_data[inplaceabn::kOut].get_with_shape<xpu, 4, real_t>(dshape, 
s);
+    } else {
+      data = in_data[inplaceabn::kData].get<xpu, 4, real_t>(s);
+      out = out_data[inplaceabn::kOut].get<xpu, 4, real_t>(s);
+    }
+    Tensor<xpu, 1> gamma = in_data[inplaceabn::kGamma].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> beta = in_data[inplaceabn::kBeta].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> moving_mean = aux_states[inplaceabn::kMovingMean].get<xpu, 
1, real_t>(s);
+    Tensor<xpu, 1> moving_var = aux_states[inplaceabn::kMovingVar].get<xpu, 1, 
real_t>(s);
+
+    // whether use global statistics
+    if (ctx.is_train && !param_.use_global_stats) {
+      // get the mean and var
+      Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, 
real_t>(s);
+      Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, 
real_t>(s);
+      CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] 
== kWriteTo);
+      CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == 
kWriteTo);
+      // E(x) and E(x^2)
+      mean = scale * sumall_except_dim<1>(data);
+      var = scale * sumall_except_dim<1>(F<mshadow_op::square>(data));
+      // whether use synchronized batch normalization
+      if (param_.sync) {
+        // get my rank
+        Barrier *global_barrier = 
inpabn_global_shared_barrier_forward.Register(param_.key, param_.ndev);
+        int myRank = inpabn_global_shared_rank_forward.Register(param_.key, 
param_.ndev);
+        SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedMean =
+          inp_abn_global_shared_mean.Register(param_.key, param_.ndev);
+        SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedVar =
+          inpabn_global_shared_var.Register(param_.key, param_.ndev);
+        // copy to cpu, push and pull
+        Tensor<cpu, 1, real_t>* mean_cpu_ptr = 
sharedMean->Retrieve(mean.shape_, myRank);
+        Tensor<cpu, 1, real_t>* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, 
myRank);
+        mshadow::Copy(*mean_cpu_ptr, mean, s);
+        mshadow::Copy(*var_cpu_ptr, var, s);
+        sharedMean->SetReady(myRank);
+        sharedVar->SetReady(myRank);
+        global_barrier->Wait();
+        Tensor<cpu, 1, real_t> mean_cpu = sharedMean->Pop(myRank);
+        Tensor<cpu, 1, real_t> var_cpu = sharedVar->Pop(myRank);
+        // copy back to gpu
+        mshadow::Copy(mean, mean_cpu, s);
+        mshadow::Copy(var, var_cpu, s);
+      }
+      var = var - F<mshadow_op::square>(mean);
+      // batch normalization
+      /*
+      Tensor<xpu, 4> tmp = 
ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
+          out.shape_, s);
+      tmp = broadcast<1>(gamma, out.shape_) *
+         (data - broadcast<1>(mean, data.shape_)) /
+         F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, 
data.shape_)) +
+         broadcast<1>(beta, out.shape_);
+      */
+      // update running mean and var
+      moving_mean = moving_mean * param_.momentum + mean * (1.0f - 
param_.momentum);
+      moving_var = moving_var * param_.momentum + var / scale / (1.0f / scale 
- 1.0f)
+          * (1.0f - param_.momentum);
+      Assign(out, req[inplaceabn::kOut], broadcast<1>(gamma, out.shape_) *
+         (data - broadcast<1>(mean, data.shape_)) /
+         F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, 
data.shape_)) +
+         broadcast<1>(beta, out.shape_));
+      // leaky relu forward
+      MXNET_ASSIGN_REQ_SWITCH(req[inplaceabn::kOut], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, 
xpu>::Launch(
+          s, out.size(0) * out.size(1) * out.size(2) * out.size(3), out.dptr_,
+          out.dptr_, real_t(param_.slope));
+      });
+    } else {
+      /*
+      Tensor<xpu, 4> tmp = 
ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
+          out.shape_, s);
+      */
+      Assign(out, req[inplaceabn::kOut],
+             broadcast<1>(gamma / F<mshadow_op::square_root>(moving_var + 
param_.eps),
+                          data.shape_) * data +
+             broadcast<1>(beta - (gamma * moving_mean) /
+                          F<mshadow_op::square_root>(moving_var + param_.eps), 
data.shape_));
+      // leaky relu forward
+      MXNET_ASSIGN_REQ_SWITCH(req[inplaceabn::kOut], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, Req>, 
xpu>::Launch(
+          s, out.size(0) * out.size(1) * out.size(2) * out.size(3), out.dptr_,
+          out.dptr_, real_t(param_.slope));
+      });
+    }
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U);
+    CHECK_EQ(in_data.size(), 3U);
+    CHECK_EQ(out_data.size(), 3U);
+    CHECK_EQ(in_grad.size(), 3U);
+
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> out, grad_out, grad_in;
+    const real_t scale = 
static_cast<real_t>(out_grad[inplaceabn::kOut].shape_[1]) /
+      static_cast<real_t>(out_grad[inplaceabn::kOut].shape_.Size());
+    if (out_data[inplaceabn::kOut].ndim() == 2) {
+      Shape<4> dshape = Shape4(out_grad[inplaceabn::kOut].shape_[0],
+                               out_grad[inplaceabn::kOut].shape_[1], 1, 1);
+      // data is the output
+      out = out_data[inplaceabn::kOut].get_with_shape<xpu, 4, real_t>(dshape, 
s);
+      grad_out = out_grad[inplaceabn::kOut].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+      grad_in = in_grad[inplaceabn::kData].get_with_shape<xpu, 4, 
real_t>(dshape, s);
+    } else {
+      // data is the output
+      out = out_data[inplaceabn::kOut].get<xpu, 4, real_t>(s);
+      grad_out = out_grad[inplaceabn::kOut].get<xpu, 4, real_t>(s);
+      grad_in = in_grad[inplaceabn::kData].get<xpu, 4, real_t>(s);
+    }
+    Tensor<xpu, 1> mean = out_data[inplaceabn::kMean].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> var = out_data[inplaceabn::kVar].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> gamma = in_data[inplaceabn::kGamma].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> beta = in_data[inplaceabn::kBeta].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> ggamma = in_grad[inplaceabn::kGamma].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> gbeta = in_grad[inplaceabn::kBeta].get<xpu, 1, real_t>(s);
+    // Tensor<xpu, 1> moving_mean = 
aux_states[inplaceabn::kMovingMean].get<xpu, 1, real_t>(s);
+    Tensor<xpu, 1> moving_var = aux_states[inplaceabn::kMovingVar].get<xpu, 1, 
real_t>(s);
+    // get the work space
+    size_t data_size = out.shape_[0] * out.shape_[1] * out.shape_[2] * 
out.shape_[3];
+    size_t mean_size = mean.shape_[0];
+    Tensor<xpu, 1> workspace = 
ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
+        mshadow::Shape1(2*(data_size + mean_size)), s);
+    real_t *data_ptr = workspace.dptr_;
+    real_t *grad_ptr = workspace.dptr_ + data_size;
+    real_t *sum_grad_ptr = workspace.dptr_ + 2 * data_size;
+    real_t *sum_prod_ptr = workspace.dptr_ + 2 * data_size + mean_size;
+    Tensor<xpu, 4> data_y(data_ptr, out.shape_, s);
+    Tensor<xpu, 4> grad_y(grad_ptr, out.shape_, s);
+    Tensor<xpu, 1> sumGrad(sum_grad_ptr, mean.shape_, s);
+    Tensor<xpu, 1> sumProd(sum_prod_ptr, mean.shape_, s);
+    /*
+    // hacky, inplace memory
+    Tensor<xpu, 4> data_y(out.dptr_, out.shape_, s);
+    Tensor<xpu, 4> grad_y(grad_out.dptr_, out.shape_, s);
+    Tensor<xpu, 2> workspace = 
ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
+          mshadow::Shape2(2, mean.shape_[0]), s);
+    Tensor<xpu, 1> sumGrad = workspace[0];
+    Tensor<xpu, 1> sumProd = workspace[1];
+    */
+    // recover y and dl/dy
+    mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::xelu, kWriteTo>, 
xpu>::Launch(
+      s, data_size, data_y.dptr_, out.dptr_, real_t(1.0f / param_.slope));
+    mxnet_op::Kernel<mxnet_op::op_with_req<
+      mxnet_op::backward_grad_tuned<mshadow_op::xelu_grad>, kWriteTo>, 
xpu>::Launch(
+        s, data_size, grad_y.dptr_, grad_out.dptr_,
+        out.dptr_, real_t(param_.slope));
+
+    if (ctx.is_train && !param_.use_global_stats) {
+      // sum dl/dy and (dl/dy * y) over all axes except axis 1
+      sumGrad = sumall_except_dim<1>(grad_y);
+      sumProd = sumall_except_dim<1>(grad_y * data_y);
+      Assign(ggamma, req[inplaceabn::kGamma], (sumProd - beta * sumGrad) / 
gamma);
+      Assign(gbeta, req[inplaceabn::kBeta], 1.0f * sumGrad);
+      // whether use synchronized batch normalization
+      if (param_.sync) {
+        // get my rank
+        Barrier *global_barrier = 
inpabn_global_shared_barrier_backward.Register(
+           param_.key, param_.ndev);
+        int myRank = inpabn_global_shared_rank_backward.Register(param_.key, 
param_.ndev);
+        SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedGrad =
+          inpabn_global_shared_grad.Register(param_.key, param_.ndev);
+        SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedProd =
+          inpabn_global_shared_prod.Register(param_.key, param_.ndev);
+        // copy to cpu, push and pull
+        Tensor<cpu, 1, real_t>* grad_cpu_ptr = 
sharedGrad->Retrieve(sumGrad.shape_, myRank);
+        Tensor<cpu, 1, real_t>* prod_cpu_ptr = 
sharedProd->Retrieve(sumGrad.shape_, myRank);
+        mshadow::Copy(*grad_cpu_ptr, sumGrad, s);
+        mshadow::Copy(*prod_cpu_ptr, sumProd, s);
+        sharedGrad->SetReady(myRank);
+        sharedProd->SetReady(myRank);
+        global_barrier->Wait();
+        Tensor<cpu, 1, real_t> grad_cpu = sharedGrad->Pop(myRank);
+        Tensor<cpu, 1, real_t> prod_cpu = sharedProd->Pop(myRank);
+        // copy back to gpu
+        mshadow::Copy(sumGrad, grad_cpu, s);
+        mshadow::Copy(sumProd, prod_cpu, s);
+      }
+      // assign
+      Assign(grad_in, req[inplaceabn::kData],
+             grad_y * broadcast<1>(1.0f * gamma / 
F<mshadow_op::square_root>(var + param_.eps),
+                            out.shape_) -
+             scale * broadcast<1>(1.0f / gamma/ F<mshadow_op::square_root>(var 
+ param_.eps),
+                                  out.shape_) *
+               (broadcast<1>(sumProd, out.shape_) - broadcast<1>(beta, 
out.shape_) *
+                broadcast<1>(sumGrad, out.shape_)) * data_y -
+             scale * broadcast<1>(1.0f * gamma * sumGrad /
+                                  F<mshadow_op::square_root>(var + param_.eps),
+                                  out.shape_) -
+             scale * broadcast<1>(1.0f * beta / gamma /
+                                  F<mshadow_op::square_root>(var + param_.eps),
+                                  out.shape_) *
+               (broadcast<1>(sumProd, out.shape_) -
+                broadcast<1>(beta * sumGrad, out.shape_)));
+    } else {
+      // use global statistics with freeze moving mean and var.
+      Assign(ggamma, req[inplaceabn::kGamma],
+             sumall_except_dim<1>((grad_y * (data_y - broadcast<1>(beta, 
out.shape_))) *
+               broadcast<1>(1.0f / gamma, out.shape_)));
+      Assign(gbeta, req[inplaceabn::kBeta], sumall_except_dim<1>(grad_y));
+      Assign(grad_in, req[inplaceabn::kData], (grad_y * broadcast<1>(gamma, 
out.shape_)) *
+             broadcast<1>(
+               1.0f / F<mshadow_op::square_root>(moving_var + param_.eps), 
out.shape_));
+    }
+  }
+
+ private:
+  InplaceABNParam param_;
+};  // class InplaceABN
+
+template<typename xpu>
+Operator *CreateOp(InplaceABNParam param, int dtype);
+
+#if DMLC_USE_CXX11
+class InplaceABNProp : public OperatorProperty {
+ public:
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) 
override {
+    param_.Init(kwargs);
+  }
+
+  std::map<std::string, std::string> GetParams() const override {
+    return param_.__DICT__();
+  }
+
+  bool InferShape(std::vector<TShape> *in_shape,
+                  std::vector<TShape> *out_shape,
+                  std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
+    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
+    const TShape &dshape = in_shape->at(0);
+    if (dshape.ndim() == 0) return false;
+    in_shape->at(1) = TShape(Shape1(dshape[1]));
+    in_shape->at(2) = TShape(Shape1(dshape[1]));
+    out_shape->clear();
+    out_shape->push_back(dshape);
+    out_shape->push_back(Shape1(dshape[1]));
+    out_shape->push_back(Shape1(dshape[1]));
+
+    aux_shape->clear();
+    aux_shape->push_back(Shape1(dshape[1]));
+    aux_shape->push_back(Shape1(dshape[1]));
+    return true;
+  }
+
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    using namespace mshadow;
+    CHECK_GE(in_type->size(), 1U);
+    int dtype = (*in_type)[0];
+    CHECK_NE(dtype, -1) << "First input must have specified type";
+    // For float16 input type beta, gamma, mean, and variance are stored in 
float32.
+    // For other input types, these parameters have the same type as input
+    // NOTE: This requirement is from cuDNN (v. 4 and 5)
+    int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype;
+    for (index_t i = 1; i < in_type->size(); ++i) {
+      if ((*in_type)[i] == -1) {
+        (*in_type)[i] = dtype_param;
+      } else {
+        UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
+      }
+    }
+    for (index_t i = 0; i < aux_type->size(); ++i) {
+      if ((*aux_type)[i] != -1) {
+        UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
+      }
+    }
+    int n_aux = this->ListAuxiliaryStates().size();
+    aux_type->clear();
+    for (int i = 0; i < n_aux; ++i ) aux_type->push_back(dtype_param);
+    int n_out = this->ListOutputs().size();
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (int i = 1; i < n_out; ++i ) out_type->push_back(dtype_param);
+    return true;
+  }
+
+  OperatorProperty* Copy() const override {
+    auto ptr = new InplaceABNProp();
+    ptr->param_ = param_;
+    return ptr;
+  }
+
+  std::string TypeString() const override {
+    return "_contrib_InplaceABN";
+  }
+
+  std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const override {
+    return {out_grad[inplaceabn::kOut],
+            out_data[inplaceabn::kOut],
+            out_data[inplaceabn::kMean],
+            out_data[inplaceabn::kVar],
+            // in_data[inplaceabn::kData],
+            in_data[inplaceabn::kGamma],
+            in_data[inplaceabn::kBeta]
+           };
+  }
+
+  std::vector<ResourceRequest> ForwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
+  std::vector<ResourceRequest> BackwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
+  int NumVisibleOutputs() const override {
+    if (param_.output_mean_var) {
+      return 3;
+    }
+    return 1;
+  }
+
+  int NumOutputs() const override {
+    return 3;
+  }
+
+  std::vector<std::string> ListArguments() const override {
+    return {"data", "gamma", "beta"};
+  }
+
+  std::vector<std::string> ListOutputs() const override {
+    return {"output", "mean", "var"};
+  }
+
+  std::vector<std::string> ListAuxiliaryStates() const override {
+    return {"moving_mean", "moving_var"};
+  }
+
+  Operator* CreateOperator(Context ctx) const override {
+      LOG(FATAL) << "Not Implemented.";
+      return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+      std::vector<int> *in_type) const override;
+
+  inline const InplaceABNParam& getParam() const {
+    return param_;
+  }
+
+  /*
+  std::vector<std::pair<int, void*> > ForwardInplaceOption(
+      const std::vector<int> &in_data,
+      const std::vector<void*> &out_data) const {
+    return {{in_data[inplaceabn::kData], out_data[inplaceabn::kOut]}};
+  }
+
+  std::vector<std::pair<int, void*> > BackwardInplaceOption(
+      const std::vector<int> &out_grad,
+      const std::vector<int> &in_data,
+      const std::vector<int> &out_data,
+      const std::vector<void*> &in_grad) const override {
+    return {{out_grad[inplaceabn::kOut], in_grad[inplaceabn::kData]}};
+  }
+  */
+
+ private:
+  InplaceABNParam param_;
+};  // class InplaceABNProp
+
+#endif  // DMLC_USE_CXX11
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CONTRIB_INPLACE_ABN_INL_H_
diff --git a/src/operator/contrib/inplace_abn.cc 
b/src/operator/contrib/inplace_abn.cc
new file mode 100644
index 00000000000..2faa61fff78
--- /dev/null
+++ b/src/operator/contrib/inplace_abn.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file inplace_abn.cc
+ * \brief Inplace Activated BatchNorm, modified from SyncBatchNorm (itself modified from BatchNormV1)
+ * \author Hang Zhang
+*/
+
+#include "inplace_abn-inl.h"
+#include <nnvm/op_attr_types.h>
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>(InplaceABNParam param, int dtype) {
+  return new InplaceABN<cpu>(param);
+}
+
+// DO_BIND_DISPATCH comes from operator_common.h
+Operator *InplaceABNProp::CreateOperatorEx(Context ctx, std::vector<TShape> 
*in_shape,
+    std::vector<int> *in_type) const {
+    std::vector<TShape> out_shape, aux_shape;
+    std::vector<int> out_type, aux_type;
+    CHECK(InferType(in_type, &out_type, &aux_type));
+    CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+    DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+}
+
+DMLC_REGISTER_PARAMETER(InplaceABNParam);
+
+MXNET_REGISTER_OP_PROPERTY(_contrib_InplaceABN, InplaceABNProp)
+.describe(R"code(Inplace Activated Batch normalization [1]_.
+
+Inplace ABN acts the same as standard BatchNorm with LeakyReLU activation.
+It saves the memory by recalculating featuremaps.
+Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
+well as offset ``beta``.
+Standard BN [2]_ implementation only normalize the data within each device.
+For synchronizing the Batch Normalization using global batch,
+We follow the sync-onece implmentation described in the paper [3]_ .
+
+Assume the input has more than one dimension and we normalize along axis 1.
+We first compute the mean and variance along this axis:
+
+.. math::
+
+  data\_mean[i] = mean(data[:,i,:,...]) \\
+  data\_var[i] = var(data[:,i,:,...])
+
+Then compute the normalized output, which has the same shape as input, as 
following:
+
+.. math::
+
+  out[:,i,:,...] = \frac{data[:,i,:,...] - 
data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
+
+Both *mean* and *var* returns a scalar by treating the input as a vector.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both 
``data_mean`` and
+``data_var`` as well, which are needed for the backward pass.
+
+Besides the inputs and the outputs, this operator accepts two auxiliary
+states, ``moving_mean`` and ``moving_var``, which are *k*-length
+vectors. They are global statistics for the whole dataset, which are updated
+by::
+
+  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+  moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+If ``use_global_stats`` is set to be true, then ``moving_mean`` and
+``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
+the output. It is often used during inference.
+
+Both ``gamma`` and ``beta`` are learnable parameters. 
+
+Reference:
+  .. [1] Bulò, Samuel Rota, Lorenzo Porzi, and Peter Kontschieder. "In-place 
activated batchnorm for memory-optimized training of DNNs." CVPR (2018).
+  .. [2] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: 
Accelerating
+  deep network training by reducing internal covariate shift." *ICML 2015*
+  .. [3] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
+  Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic 
Segmentation." *CVPR 2018*
+)code" ADD_FILELINE)
+.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
+.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
+.add_argument("beta", "NDArray-or-Symbol", "beta array")
+.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
+.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
+.add_arguments(InplaceABNParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_InplaceABN)
+.set_attr<nnvm::FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
+    [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) {
+      if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return;
+      if (index == 3) {
+        var->attrs.dict["__init__"] = "[\"zero\", {}]";
+      } else if (index == 4) {
+        var->attrs.dict["__init__"] = "[\"one\", {}]";
+      }
+    });
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/inplace_abn.cu 
b/src/operator/contrib/inplace_abn.cu
new file mode 100644
index 00000000000..15198ac31b3
--- /dev/null
+++ b/src/operator/contrib/inplace_abn.cu
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file inplace_abn.cu
+ * \brief GPU operator creation for Inplace Activated BatchNorm (InplaceABN)
+ * \author Hang Zhang
+*/
+
+#include "inplace_abn-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<gpu>(InplaceABNParam param, int dtype) {
+  return new InplaceABN<gpu>(param);
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 5612b0a647e..60b649a672b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1921,6 +1921,213 @@ def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
     assert mx.context.num_gpus() > 0
 
+def _check_batchnorm_result(input, num_devices=1, cuda=False):
+    from mxnet.gluon.utils import split_and_load
+    def _find_bn(module):
+        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module
+        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module.module
+
+        raise RuntimeError('BN not found')
+
+    def _syncParameters(bn1, bn2, ctx):
+        ctx = input.context
+        bn2.gamma.set_data(bn1.gamma.data(ctx))
+        bn2.beta.set_data(bn1.beta.data(ctx))
+        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
+        bn2.running_var.set_data(bn1.running_var.data(ctx))
+
+    input1 = input.copy()
+    input2 = input.copy()
+
+    if cuda:
+        input1 = input.as_in_context(mx.gpu(0))
+        ctx_list = [mx.gpu(i) for i in range(num_devices)]
+    else:
+        ctx_list = [mx.cpu(0) for _ in range(num_devices)]
+
+    nch = input.shape[1]
+    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
+    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices)
+
+    bn1.initialize(ctx=ctx_list[0])
+    bn2.initialize(ctx=ctx_list)
+
+    # using the same values for gamma and beta
+    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
+
+    input1.attach_grad()
+    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
+    for xi in inputs2:
+        xi.attach_grad()
+
+    with mx.autograd.record():
+        output1 = bn1(input1)
+        output2  = [bn2(xi) for xi in inputs2]
+        loss1 = (output1 ** 2).sum()
+        loss2 = [(output ** 2).sum() for output in output2]
+        mx.autograd.backward(loss1)
+        mx.autograd.backward(loss2)
+
+    output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0)
+    # assert forwarding
+    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
+    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
+
+def test_sync_batchnorm():
+    def get_num_devices():
+        for i in range(100):
+            try:
+                mx.nd.zeros((1,), ctx=mx.gpu(i))
+            except:
+                return i
+    # no need to use SyncBN with 1 gpu
+    if get_num_devices() < 2:
+        return
+    ndev = 2
+    # check with unsync version
+    for i in range(10):
+        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
+                                num_devices=ndev, cuda=True)
+
+class NormAct(mx.gluon.nn.BatchNorm):
+    def __init__(self, in_channels=0, slope=0.01,
+                 momentum=0.9, epsilon=1e-5, center=True, use_global_stats=False,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 running_mean_initializer='zeros',
+                 running_variance_initializer='ones', **kwargs):
+        super(NormAct, self).__init__(1, momentum, epsilon, center, True, use_global_stats,
+                                            beta_initializer, gamma_initializer,
+                                            running_mean_initializer, running_variance_initializer,
+                                            in_channels, **kwargs)
+        self.slope = slope
+
+    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
+        x = F.BatchNorm(x, gamma, beta, running_mean, running_var,
+                        name='fwd', **self._kwargs)
+        x = F.LeakyReLU(x, act_type='leaky', slope=self.slope, name='fwd')
+        return x
+
+def _check_inplace_abn(input, training=True, ndev=1):
+    ch = input.shape[1]
+    sync = ndev > 1
+    ctx_list = mx.gpu(0) if ndev <=1 else [mx.gpu(i) for i in range(ndev)]
+    layer1 = NormAct(in_channels=ch, slope=0.1)
+    layer2 = mx.gluon.contrib.nn.InplaceABN(in_channels=ch, slope=0.1)
+
+    layer1.initialize(ctx=ctx_list)
+    layer2.initialize(ctx=ctx_list)
+
+    input1 = input.copy()
+    input2 = input.copy()
+    input1.attach_grad()
+    input2.attach_grad()
+    if not training:
+        output1 = layer1(input1)
+        output2 = layer2(input2)
+        assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
+        return
+
+    with mx.autograd.record():
+        output1 = layer1(input1)
+        output2 = layer2(input2)
+        loss1 = (output1 ** 2).sum()
+        loss2 = (output2 ** 2).sum()
+        mx.autograd.backward(loss1)
+        mx.autograd.backward(loss2)
+
+    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(loss1.asnumpy(), loss2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(input1.grad.asnumpy(), input2.grad.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(layer1.running_mean.data(mx.gpu(0)).asnumpy(),
+                        layer2.running_mean.data(mx.gpu(0)).asnumpy(),
+                        atol=1e-5, rtol=1e-3)
+    assert_almost_equal(layer1.running_var.data(mx.gpu(0)).asnumpy(),
+                        layer2.running_var.data(mx.gpu(0)).asnumpy(),
+                        atol=1e-5, rtol=1e-3)
+    assert_almost_equal(layer1.gamma.data(mx.gpu(0)).grad.asnumpy(),
+                        layer2.gamma.data(mx.gpu(0)).grad.asnumpy(),
+                        atol=1e-5, rtol=1e-3)
+    assert_almost_equal(layer1.beta.data(mx.gpu(0)).grad.asnumpy(),
+                        layer2.beta.data(mx.gpu(0)).grad.asnumpy(),
+                        atol=1e-5, rtol=1e-3)
+
+def _check_inplace_abn2(input, training=True, ndev=1):
+    ch = input.shape[1]
+    sync = ndev > 1
+    ctx_list = mx.gpu(0) if ndev <=1 else [mx.gpu(i) for i in range(ndev)]
+    layer1 = mx.gluon.nn.Sequential() 
+    with layer1.name_scope():
+        layer1.add(mx.gluon.nn.Conv2D(in_channels=ch, channels=ch, kernel_size=1))
+        layer1.add(NormAct(in_channels=ch, slope=0.01))
+        layer1.add(mx.gluon.nn.Conv2D(in_channels=ch, channels=ch, kernel_size=1))
+    layer2 = mx.gluon.nn.Sequential()
+    with layer2.name_scope():
+        layer2.add(mx.gluon.nn.Conv2D(in_channels=ch, channels=ch, kernel_size=1))
+        layer2.add(mx.gluon.contrib.nn.InplaceABN(in_channels=ch, slope=0.01))
+        layer2.add(mx.gluon.nn.Conv2D(in_channels=ch, channels=ch, kernel_size=1))
+
+    layer1.initialize(ctx=ctx_list)
+    layer2.initialize(ctx=ctx_list)
+
+    def _syncParameters(conv1, conv2, ctx):
+        ctx = input.context
+        conv2.weight.set_data(conv1.weight.data(ctx))
+        conv2.bias.set_data(conv1.bias.data(ctx))
+    _syncParameters(layer1[0], layer2[0], mx.gpu(0))
+    _syncParameters(layer1[2], layer2[2], mx.gpu(0))
+
+    input1 = input.copy()
+    input2 = input.copy()
+    input1.attach_grad()
+    input2.attach_grad()
+    if not training:
+        output1 = layer1(input1)
+        output2 = layer2(input2)
+        assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-5, rtol=1e-3)
+        return
+
+    with mx.autograd.record():
+        output1 = layer1(input1)
+        output2 = layer2(input2)
+        loss1 = (output1 ** 2).sum()
+        loss2 = (output2 ** 2).sum()
+        mx.autograd.backward(loss1)
+        mx.autograd.backward(loss2)
+
+    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(loss1.asnumpy(), loss2.asnumpy(), atol=1e-5, rtol=1e-3)
+    assert_almost_equal(input1.grad.asnumpy(), input2.grad.asnumpy(), atol=1e-5, rtol=1e-3)
+
+def test_inpabn():
+    def get_num_devices():
+        for i in range(100):
+            try:
+                mx.nd.zeros((1,), ctx=mx.gpu(i))
+            except:
+                return i
+    for i in range(10):
+        target_shape = np.random.randint(2,32, size=(4,))
+        print(i, target_shape)
+        _check_inplace_abn(mx.nd.random.uniform(-2, 2, shape=tuple(target_shape)), True, 1)
+        _check_inplace_abn(mx.nd.random.uniform(-2, 2, shape=tuple(target_shape)), False, 1)
+        _check_inplace_abn2(mx.nd.random.uniform(-2, 2, shape=tuple(target_shape)), True, 1)
+    # no need to use SyncBN with 1 gpu
+    if get_num_devices() < 2:
+        return
+    ndev = 2
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to