This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 497a4fc  SINGA-505 SoftMax Backward to be bufferable
     new 8bf0c62  Merge pull request #588 from chrishkchris/SINGA-505
497a4fc is described below

commit 497a4fc86fd50ccaf6545b7ed9784b92ce55847e
Author: chrishkchris <chrishkch...@yahoo.com.hk>
AuthorDate: Tue Feb 11 14:39:32 2020 +0000

    SINGA-505 SoftMax Backward to be bufferable
---
 include/singa/core/tensor.h        |   1 +
 python/singa/autograd.py           | 104 +++++++-----------
 src/api/core_tensor.i              |   1 +
 src/core/tensor/tensor.cc          | 219 ++++++++++++++++++++-----------------
 src/core/tensor/tensor_math.h      |   6 +-
 src/core/tensor/tensor_math_cpp.h  |  63 +++++------
 src/core/tensor/tensor_math_cuda.h |  42 ++++++-
 test/python/test_api.py            |  21 ++--
 8 files changed, 243 insertions(+), 214 deletions(-)

diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 93cf44a..846c14c 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -514,6 +514,7 @@ void MultRow(const Tensor &v, Tensor *M);
 /// Do softmax for each row. 'in' could be a 1-d or 2-d Tensor.
 Tensor SoftMax(const Tensor &in);
 Tensor SoftMax(const Tensor &in, int axis);
+Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout);
 
 Tensor RowMax(const Tensor &in);
 /// Do softmax for each row. 'in' could be a 1-d or 2-d Tensor.
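
For reference, the gradient that the new SoftMaxBackward entry point is meant to return is the standard softmax Jacobian-vector product. This is inferred from the numpy code removed from python/singa/autograd.py below, not stated in the commit message itself. With y = SoftMax(x) along the chosen axis (the saved forward output passed in as fdout) and dy the incoming gradient, per row of the axis-coerced 2-d view:

    \frac{\partial L}{\partial x_i} = y_i \Big( \frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j}\, y_j \Big)
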
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 0c5f456..01e4d82 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -647,10 +647,14 @@ class Reshape(Operation):
         self._shape = x.shape()
         shape = self.shape
         # handle the shape with 0
-        shape = [self._shape[i] if i < len(self._shape) and shape[i] == 0 else shape[i] for i in range(len(shape))]
+        shape = [
+            self._shape[i]
+            if i < len(self._shape) and shape[i] == 0 else shape[i]
+            for i in range(len(shape))
+        ]
         # handle the shape with -1
         hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape)))
-        self.cache=[s if s != -1 else hidden_shape for s in shape]
+        self.cache = [s if s != -1 else hidden_shape for s in shape]
 
         return singa.Reshape(x, self.cache)
 
@@ -881,32 +885,10 @@ class SoftMax(Operation):
             dx (Ctensor): data for the dL / dx, L is the loss,
             x is the input of current Opertion
         """
-        # calculations are made on numpy array
-        if self.axis == 1:
-            dy = singa.DefaultTranspose(dy)
-        grad = ctensor2numpy(dy)
-        output = ctensor2numpy(self.output)
-        out_1 = np.einsum("ki,ki->ki", grad, output)
-        medium_out = np.einsum("ki,kj->kij", output, output)
-        out_2 = np.einsum("kij,kj->ki", medium_out, grad)
-        out = out_1 - out_2
-        dx = CTensor(out_1.shape)
-        dx.CopyFloatDataFromHostPtr(out.flatten())
-        """grad = Tensor(data=dy)
-        output = Tensor(data=self.output)
-        out_1 = einsum('ki,ki->ki', grad, output)
-        medium_out = einsum('ki,kj->kij', output, output)
-        out_2 = einsum('kij,kj->ki', medium_out, grad)
-        out = out_1 - out_2
-        dx = CTensor(out_1.data.shape)
-        dx.CopyFloatDataFromHostPtr(out.data.flatten())"""
-        if self.axis == 0:
-            return dx
-        elif self.axis == 1:
-            return singa.DefaultTranspose(dx)
+        return singa.SoftMaxBackward(dy, self.axis, self.output)
 
 
-def softmax(x, axis=0):
+def softmax(x, axis=1):
     return SoftMax(axis)(x)[0]
 
 
@@ -1236,16 +1218,13 @@ class _Conv2d(Operation):
 
     def backward(self, dy):
         assert training is True and hasattr(
-            self, "inputs"
-        ), "Please set training as True before do BP. "
-        
+            self, "inputs"), "Please set training as True before do BP. "
+
         if (type(self.handle) != singa.ConvHandle):
-            dx = singa.GpuConvBackwardx(
-                dy, self.inputs[1], self.inputs[0], self.handle
-            )
-            dW = singa.GpuConvBackwardW(
-                dy, self.inputs[0], self.inputs[1], self.handle
-            )
+            dx = singa.GpuConvBackwardx(dy, self.inputs[1], self.inputs[0],
+                                        self.handle)
+            dW = singa.GpuConvBackwardW(dy, self.inputs[0], self.inputs[1],
+                                        self.handle)
             if self.handle.bias_term:
                 db = singa.GpuConvBackwardb(dy, self.inputs[2], self.handle)
                 return dx, dW, db
@@ -1420,13 +1399,13 @@ class Conv2d(Layer):
 class SeparableConv2d(Layer):
 
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=1,
-            padding=0,
-            bias=False,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        bias=False,
     ):
         self.depthwise_conv = Conv2d(
             in_channels,
@@ -1600,9 +1579,8 @@ class _Pooling2d(Operation):
 
     def backward(self, dy):
         if (type(self.handle) != singa.PoolingHandle):
-            dx = singa.GpuPoolingBackward(
-                self.handle, dy, self.cache[0], self.cache[1]
-            )
+            dx = singa.GpuPoolingBackward(self.handle, dy, self.cache[0],
+                                          self.cache[1])
         else:
             dx = singa.CpuPoolingBackward(self.handle, dy, self.cache[0],
                                           self.cache[1])
@@ -2120,15 +2098,15 @@ class RNN_Base(Layer):
 class RNN(RNN_Base):
 
     def __init__(
-            self,
-            input_size,
-            hidden_size,
-            num_layers=1,
-            nonlinearity="tanh",
-            bias=True,
-            batch_first=False,
-            dropout=0,
-            bidirectional=False,
+        self,
+        input_size,
+        hidden_size,
+        num_layers=1,
+        nonlinearity="tanh",
+        bias=True,
+        batch_first=False,
+        dropout=0,
+        bidirectional=False,
     ):
         self.nonlinearity = nonlinearity
 
@@ -2181,15 +2159,15 @@ class RNN(RNN_Base):
 class LSTM(RNN_Base):
 
     def __init__(
-            self,
-            input_size,
-            hidden_size,
-            nonlinearity="tanh",
-            num_layers=1,
-            bias=True,
-            batch_first=False,
-            dropout=0,
-            bidirectional=False,
+        self,
+        input_size,
+        hidden_size,
+        nonlinearity="tanh",
+        num_layers=1,
+        bias=True,
+        batch_first=False,
+        dropout=0,
+        bidirectional=False,
     ):
         self.nonlinearity = nonlinearity
 
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index d54beed..4550e6a 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -201,6 +201,7 @@ namespace singa{
   Tensor Average(const Tensor &t, int axis);
   Tensor SoftMax(const Tensor &t);
   Tensor SoftMax(const Tensor &t, int axis);
+  Tensor SoftMaxBackward(const Tensor &t, int axis, const Tensor &fdout);
 
   Tensor Pow(const Tensor &base, const Tensor &exp);
 
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index c61d4fa..8b90932 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -627,13 +627,11 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
 float Tensor::l1() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
-    device_->Exec(
-        [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Asum<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
-        },
-        {this->block()}, {});
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret = DType(0);
+      Asum<DType, Lang>(*this, &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->block()}, {});
   });
   return nrm / Size();
 }
@@ -645,13 +643,11 @@ float Tensor::L1() const { return l1(); }
 float Tensor::l2() const {
   float nrm = 0.0f;
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
-    device_->Exec(
-        [&nrm, this](Context *ctx) {
-          DType ret = DType(0);
-          Nrm2<DType, Lang>(*this, &ret, ctx);
-          nrm = TypeCast<DType, float>(ret);
-        },
-        {this->block()}, {});
+    device_->Exec([&nrm, this](Context *ctx) {
+      DType ret = DType(0);
+      Nrm2<DType, Lang>(*this, &ret, ctx);
+      nrm = TypeCast<DType, float>(ret);
+    }, {this->block()}, {});
   });
   return nrm / Size();
 }
@@ -667,9 +663,9 @@ void Tensor::SetValue(const SType x) {
 
   TYPE_LANG_SWITCH(data_type_, DType, device_->lang(), Lang, {
     // TODO(wangwei) cast x to DType
-    device_->Exec(
-        [this, x, ptr](Context *ctx) { Set<DType, Lang>(x, this, ctx); }, {},
-        {ptr});
+    device_->Exec([this, x, ptr](Context *ctx) {
+      Set<DType, Lang>(x, this, ctx);
+    }, {}, {ptr});
   });
 }
 template void Tensor::SetValue<float>(const float x);
@@ -698,9 +694,9 @@ template void Tensor::GetValue<int>(int *value, const size_t num);
 #define EltwiseUnaryTensorFn(fn, t, ret)                               \
   do {                                                                 \
     TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, { \
-      ret->device()->Exec(                                             \
-          [t, ret](Context *ctx) { fn<DType, Lang>(t, ret, ctx); },    \
-          {t.block()}, {ret->block()});                                \
+      ret->device()->Exec([t, ret](Context *ctx) {                     \
+        fn<DType, Lang>(t, ret, ctx);                                  \
+      }, {t.block()}, {ret->block()});                                 \
     });                                                                \
   } while (0)
 
@@ -778,16 +774,55 @@ Tensor SoftMax(const Tensor &in, int axis) {
   SoftMax(in, retptr, axis);
   return ret;
 }
+void SoftMaxBackward(const Tensor &in, Tensor *out, int axis,
+                     const Tensor &fdout) {
+  // {a_0, a_1, ..., a_k-1, a_k, ... a_n-1}
+  // reshape to
+  // { a_0 * a_1 * ... a_k-1, a_k * ... a_n-1 }
+
+  // assert axis \in {-r, r-1}
+  CHECK_LE(axis, (int)in.shape().size() - 1);
+  CHECK_GE(axis, -1 * (int)in.nDim());
+
+  Shape original_shape = in.shape();
+  if (axis < 0) axis = in.shape().size() + axis;
+
+  Shape coerced_shape = {1, 1};
+  for (std::size_t i = 0, max = in.shape().size(); i != max; ++i) {
+    if (i < axis)
+      coerced_shape[0] *= in.shape()[i];
+    else
+      coerced_shape[1] *= in.shape()[i];
+  }
+
+  Tensor in_reshaped = Reshape(in, coerced_shape);
+  out->Reshape(coerced_shape);
+
+  do {
+    TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
+      out->device()->Exec([in, out, fdout](Context *ctx) {
+        SoftMaxBackward<DType, Lang>(in, out, fdout, ctx);
+      }, {in.block(), fdout.block()}, {out->block()});
+    });
+  } while (0);
+
+  out->Reshape(original_shape);
+}
+
+Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout) {
+  Tensor ret(in.shape(), in.device(), in.data_type());
+  auto *retptr = &ret;
+  SoftMaxBackward(in, retptr, axis, fdout);
+  return ret;
+}
 
 #define EltwiseBinaryTensorFn(fn, lhs, rhs, ret)                           \
   do {                                                                     \
     TYPE_LANG_SWITCH(lhs.data_type(), DType, lhs.device()->lang(), Lang, { \
       CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type()));                    \
-      ret->device()->Exec(                                                 \
-          [lhs, rhs, ret](Context *ctx) {                                  \
-            fn<DType, Lang>(lhs, rhs, ret, ctx);                           \
-          },                                                               \
-          {lhs.block(), rhs.block()}, {ret->block()});                     \
+      ret->device()->Exec([lhs, rhs, ret](Context *ctx) {                  \
+        fn<DType, Lang>(lhs, rhs, ret, ctx);                               \
+      }, {lhs.block(), rhs.block()}, {ret->block()});                      \
     });                                                                    \
   } while (0)
 
@@ -832,15 +867,15 @@ GenBinaryTensorFn(operator>, GT);
 GenBinaryTensorFn(operator>=, GE);
 GenBinaryTensorFn(ReLUBackward, ReLUBackward);
 
-#define EltwiseTensorScalarFn(fn, t, x, ret)                              \
-  do {                                                                    \
-    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {    \
-      static_assert(std::is_same<SType, DType>::value,                    \
-                    "The Scalar type must match the Tensor data type");   \
-      ret->device()->Exec(                                                \
-          [t, x, ret](Context *ctx) { fn<DType, Lang>(t, x, ret, ctx); }, \
-          {t.block()}, {ret->block()});                                   \
-    });                                                                   \
+#define EltwiseTensorScalarFn(fn, t, x, ret)                            \
+  do {                                                                  \
+    TYPE_LANG_SWITCH(t.data_type(), DType, t.device()->lang(), Lang, {  \
+      static_assert(std::is_same<SType, DType>::value,                  \
+                    "The Scalar type must match the Tensor data type"); \
+      ret->device()->Exec([t, x, ret](Context *ctx) {                   \
+        fn<DType, Lang>(t, x, ret, ctx);                                \
+      }, {t.block()}, {ret->block()});                                  \
+    });                                                                 \
   } while (0)
 
 #define GenTensorScalarFn(op, fn)                             \
@@ -880,11 +915,9 @@ void Div(const SType alpha, const Tensor &in, Tensor *out) {
   CHECK(in.shape() == out->shape());
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     // TODO(wangwei) type cast SType to DType;
-    in.device()->Exec(
-        [alpha, in, out](Context *ctx) {
-          Div<DType, Lang>(alpha, in, out, ctx);
-        },
-        {in.block()}, {out->block()});
+    in.device()->Exec([alpha, in, out](Context *ctx) {
+      Div<DType, Lang>(alpha, in, out, ctx);
+    }, {in.block()}, {out->block()});
   });
 }
 template void Div<float>(const float, const Tensor &, Tensor *);
@@ -919,13 +952,11 @@ float Sum<float>(const Tensor &in) {
   Tensor one(in.shape(), in.device(), in.data_type());
   one.SetValue(1.0f);
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    one.device()->Exec(
-        [in, one, &s](Context *ctx) {
-          DType ret = DType(0);
-          Dot<DType, Lang>(in, one, &ret, ctx);
-          s = ret;
-        },
-        {in.block(), one.block()}, {});
+    one.device()->Exec([in, one, &s](Context *ctx) {
+      DType ret = DType(0);
+      Dot<DType, Lang>(in, one, &ret, ctx);
+      s = ret;
+    }, {in.block(), one.block()}, {});
   });
   return s;
 }
@@ -950,24 +981,22 @@ Tensor SumAll(const Tensor &in) {
   auto *outPtr = &out;
   one.SetValue(1.0f);
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    one.device()->Exec([in, one, outPtr](Context * ctx) {
+    one.device()->Exec([in, one, outPtr](Context *ctx) {
       Dot<DType, Lang>(in, one, outPtr, ctx);
     }, {in.block(), one.block()}, {outPtr->block()});
   });
   return out;
 }
- 
+
 Tensor RowMax(const Tensor &in) {
   Tensor ret({in.shape(0)}, in.device(), in.data_type());
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
-    in.device()->Exec(
-        [&in, &ret](Context *ctx) {
-          // size_t nrow = 1;
-          // if (in.nDim() > 1) nrow = in.shape(0);
-          // size_t ncol = in.Size() / nrow;
-          RowMax<DType, Lang>(in, &ret, ctx);
-        },
-        {in.block()}, {ret.block()});
+    in.device()->Exec([&in, &ret](Context *ctx) {
+      // size_t nrow = 1;
+      // if (in.nDim() > 1) nrow = in.shape(0);
+      // size_t ncol = in.Size() / nrow;
+      RowMax<DType, Lang>(in, &ret, ctx);
+    }, {in.block()}, {ret.block()});
   });
   return ret;
 }
@@ -1179,9 +1208,9 @@ void MultColumn(const Tensor &v, Tensor *M) {
   CHECK_EQ(v.Size(), M->shape(0));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
-    v.device()->Exec(
-        [M, v](Context *ctx) { DGMM<DType, Lang>(false, *M, v, M, ctx); },
-        {M->block(), v.block()}, {M->block()});
+    v.device()->Exec([M, v](Context *ctx) {
+      DGMM<DType, Lang>(false, *M, v, M, ctx);
+    }, {M->block(), v.block()}, {M->block()});
   });
 }
 
@@ -1193,9 +1222,9 @@ void MultRow(const Tensor &v, Tensor *M) {
   CHECK_EQ(v.Size(), M->shape(1));
   CheckDataTypeAndLang(*M, v);
   TYPE_LANG_SWITCH(v.data_type(), DType, v.device()->lang(), Lang, {
-    v.device()->Exec(
-        [M, v](Context *ctx) { DGMM<DType, Lang>(true, *M, v, M, ctx); },
-        {M->block(), v.block()}, {M->block()});
+    v.device()->Exec([M, v](Context *ctx) {
+      DGMM<DType, Lang>(true, *M, v, M, ctx);
+    }, {M->block(), v.block()}, {M->block()});
   });
 }
 
@@ -1239,9 +1268,9 @@ template <typename SType>
 void Bernoulli(const SType p, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto prob = TypeCast<SType, DType>(p);
-    out->device()->Exec(
-        [prob, out](Context *ctx) { Bernoulli<DType, Lang>(prob, out, ctx); },
-        {}, {out->block()}, true);
+    out->device()->Exec([prob, out](Context *ctx) {
+      Bernoulli<DType, Lang>(prob, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }
 
@@ -1252,9 +1281,9 @@ void Uniform(const SType low, const SType high, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto l = TypeCast<SType, DType>(low);
     auto h = TypeCast<SType, DType>(high);
-    out->device()->Exec(
-        [l, h, out](Context *ctx) { Uniform<DType, Lang>(l, h, out, ctx); }, {},
-        {out->block()}, true);
+    out->device()->Exec([l, h, out](Context *ctx) {
+      Uniform<DType, Lang>(l, h, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }
 
@@ -1265,9 +1294,9 @@ void Gaussian(const SType mean, const SType std, Tensor *out) {
   TYPE_LANG_SWITCH(out->data_type(), DType, out->device()->lang(), Lang, {
     auto m = TypeCast<SType, DType>(mean);
     auto s = TypeCast<SType, DType>(std);
-    out->device()->Exec(
-        [m, s, out](Context *ctx) { Gaussian<DType, Lang>(m, s, out, ctx); },
-        {}, {out->block()}, true);
+    out->device()->Exec([m, s, out](Context *ctx) {
+      Gaussian<DType, Lang>(m, s, out, ctx);
+    }, {}, {out->block()}, true);
   });
 }
 template void Gaussian<float>(const float mean, const float std, Tensor *out);
@@ -1278,9 +1307,9 @@ template <typename SType>
 void Axpy(const SType alpha, const Tensor &in, Tensor *out) {
   TYPE_LANG_SWITCH(in.data_type(), DType, in.device()->lang(), Lang, {
     auto a = TypeCast<SType, DType>(alpha);
-    out->device()->Exec(
-        [a, in, out](Context *ctx) { Axpy<DType, Lang>(a, in, out, ctx); },
-        {in.block(), out->block()}, {out->block()});
+    out->device()->Exec([a, in, out](Context *ctx) {
+      Axpy<DType, Lang>(a, in, out, ctx);
+    }, {in.block(), out->block()}, {out->block()});
   });
 }
 
@@ -1307,22 +1336,18 @@ void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
     TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
-      C->device()->Exec(
-          [a, A, b, B, C](Context *ctx) {
-            GEMV<DType, Lang>(a, A, B, b, C, ctx);
-          },
-          {A.block(), B.block()}, {C->block()});
+      C->device()->Exec([a, A, b, B, C](Context *ctx) {
+        GEMV<DType, Lang>(a, A, B, b, C, ctx);
+      }, {A.block(), B.block()}, {C->block()});
     });
   } else {
     CHECK(!C->transpose());
     TYPE_LANG_SWITCH(A.data_type(), DType, A.device()->lang(), Lang, {
       auto a = TypeCast<SType, DType>(alpha);
       auto b = TypeCast<SType, DType>(beta);
-      C->device()->Exec(
-          [a, A, b, B, C](Context *ctx) {
-            GEMM<DType, Lang>(a, A, B, b, C, ctx);
-          },
-          {A.block(), B.block()}, {C->block()});
+      C->device()->Exec([a, A, b, B, C](Context *ctx) {
+        GEMM<DType, Lang>(a, A, B, b, C, ctx);
+      }, {A.block(), B.block()}, {C->block()});
     });
   }
 }
@@ -1349,14 +1374,11 @@ void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss) {
   if (p.nDim() == 2u) batchsize = p.shape(0);
   size_t dim = p.Size() / batchsize;
   TYPE_LANG_SWITCH(p.data_type(), DType, p.device()->lang(), Lang, {
-    p.device()->Exec(
-        [batchsize, dim, t, p, loss](Context *ctx) {
-          bool int_target = t.Size() == batchsize;
-          ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim,
-                                           p.block(), t.block(), loss->block(),
-                                           ctx);
-        },
-        {p.block(), t.block()}, {loss->block()});
+    p.device()->Exec([batchsize, dim, t, p, loss](Context *ctx) {
+      bool int_target = t.Size() == batchsize;
+      ComputeCrossEntropy<DType, Lang>(int_target, batchsize, dim, p.block(),
+                                       t.block(), loss->block(), ctx);
+    }, {p.block(), t.block()}, {loss->block()});
   });
 }
 
@@ -1367,14 +1389,11 @@ void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p) {
   if (p->nDim() == 2u) batchsize = p->shape(0);
   size_t dim = p->Size() / batchsize;
   TYPE_LANG_SWITCH(p->data_type(), DType, p->device()->lang(), Lang, {
-    p->device()->Exec(
-        [batchsize, dim, t, p](Context *ctx) {
-          bool int_target = t.Size() == batchsize;
-          SoftmaxCrossEntropyBwd<DType, Lang>(int_target, batchsize, dim,
-                                              p->block(), t.block(), p->block(),
-                                              ctx);
-        },
-        {p->block(), t.block()}, {p->block()});
+    p->device()->Exec([batchsize, dim, t, p](Context *ctx) {
+      bool int_target = t.Size() == batchsize;
+      SoftmaxCrossEntropyBwd<DType, Lang>(
+          int_target, batchsize, dim, p->block(), t.block(), p->block(), ctx);
+    }, {p->block(), t.block()}, {p->block()});
   });
 }
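
The SoftMaxBackward added to tensor.cc above collapses an arbitrary input shape into a 2-d {outer, inner} view before dispatching to the backend kernel, mirroring what the forward SoftMax does. An illustrative Python rendering of that coercion (the authoritative version is the C++ loop above):

    import numpy as np

    def coerce_to_2d(shape, axis):
        """{a_0, ..., a_n-1} -> {a_0 * ... * a_axis-1, a_axis * ... * a_n-1}."""
        if axis < 0:
            axis += len(shape)              # negative axes are allowed by the CHECKs
        outer = int(np.prod(shape[:axis]))  # np.prod([]) == 1, so axis=0 gives outer=1
        inner = int(np.prod(shape[axis:]))
        return [outer, inner]

    # e.g. coerce_to_2d([2, 3, 4], axis=1) -> [2, 12]
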
 
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index a9b5c70..aef4a59 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -369,8 +369,7 @@ void Dot(const Tensor &in1, const Tensor &in2, DType *out, Context *ctx) {
   LOG(FATAL) << "Dot Not Implemented";
 }
 template <typename DType, typename Lang>
-void Dot(const Tensor &in1, const Tensor &in2, Tensor *out,
-         Context *ctx) {
+void Dot(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Dot Not Implemented";
 }
 
@@ -404,7 +403,8 @@ void SoftMax(const Tensor &in, Tensor *out, Context *ctx) {
 }
 
 template <typename DType, typename Lang>
-void SoftMax(const Tensor &in, Tensor *out, Context *ctx, int axis) {
+void SoftMaxBackward(const Tensor &in, Tensor *out, const Tensor &fdout,
+                     Context *ctx) {
   LOG(FATAL) << "Not Implemented";
 }
 
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index b592ecc..fb42576 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -240,36 +240,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
 
 #ifdef USE_DNNL
 template <>
-void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx,
-                               int axis) {
-  CHECK_EQ(in.device()->lang(), kCpp);
-
-  CHECK_LE(axis, (int)in.shape().size() - 1);
-  CHECK_GE(axis, -1 * (int)in.nDim());
-
-  Shape original_shape = in.shape();
-  if (axis < 0) axis = in.shape().size() + axis;
-
-  Shape coerced_shape = {1, 1};
-  for (int i = 0; i < in.shape().size(); i++) {
-    if (i < axis)
-      coerced_shape[0] *= in.shape()[i];
-    else
-      coerced_shape[1] *= in.shape()[i];
-  }
-  Tensor in_reshaped = Reshape(in, coerced_shape);
-  out->Reshape(coerced_shape);
-
-  // optimise by minus x - x.max()
-  auto in_max = RowMax(in_reshaped);
-  in_max.Reshape({coerced_shape[0], 1});
-  in_reshaped = in_reshaped - in_max;
-
-  auto md = dnnl::memory::desc({coerced_shape[0], coerced_shape[1]},
+void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+  auto md = dnnl::memory::desc({in.shape()[0], in.shape()[1]},
                                dnnl::memory::data_type::f32,
                                dnnl::memory::format_tag::ab);
-  auto in_mem =
-      dnnl::memory(md, ctx->dnnl_engine, in_reshaped.block()->mutable_data());
+  auto in_mem = dnnl::memory(md, ctx->dnnl_engine, in.block()->mutable_data());
   auto out_mem =
       dnnl::memory(md, ctx->dnnl_engine, out->block()->mutable_data());
 
@@ -281,9 +256,35 @@ void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx,
   softmax.execute(ctx->dnnl_stream,
                   {{DNNL_ARG_SRC, in_mem}, {DNNL_ARG_DST, out_mem}});
   ctx->dnnl_stream.wait();
+}
+
+template <>
+void SoftMaxBackward<float, lang::Cpp>(const Tensor &in, Tensor *out,
+                                       const Tensor &fdout, Context *ctx) {
+  auto md = dnnl::memory::desc({in.shape()[0], in.shape()[1]},
+                               dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::ab);
+  auto in_mem = dnnl::memory(md, ctx->dnnl_engine, in.block()->mutable_data());
+  auto fdout_mem =
+      dnnl::memory(md, ctx->dnnl_engine, fdout.block()->mutable_data());
+  auto out_mem =
+      dnnl::memory(md, ctx->dnnl_engine, out->block()->mutable_data());
 
-  out->Reshape(original_shape);
+  auto softmax_desc =
+      dnnl::softmax_forward::desc(dnnl::prop_kind::forward_scoring, md, 1);
+  auto softmax_prim_desc =
+      dnnl::softmax_forward::primitive_desc(softmax_desc, ctx->dnnl_engine);
+
+  auto softmaxbwd_desc = dnnl::softmax_backward::desc(md, md, 1);
+  auto softmaxbwd_prim_desc = dnnl::softmax_backward::primitive_desc(
+      softmaxbwd_desc, ctx->dnnl_engine, softmax_prim_desc);
+  auto softmaxbwd = dnnl::softmax_backward(softmaxbwd_prim_desc);
+  softmaxbwd.execute(ctx->dnnl_stream, {{DNNL_ARG_DIFF_SRC, out_mem},
+                                        {DNNL_ARG_DIFF_DST, in_mem},
+                                        {DNNL_ARG_DST, fdout_mem}});
+  ctx->dnnl_stream.wait();
 }
+
 #endif  // USE_DNNL
 
 template <>
@@ -927,6 +928,8 @@ void RowMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   }
 }
 
+// =========Matrix operations ================================================
+/*
 template <>
 void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   CHECK_LE(in.nDim(), 2u)
@@ -947,8 +950,6 @@ void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   out->Reshape(in.shape());
 }
 
-// =========Matrix operations ================================================
-/*
 template <>
 void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
                               const Tensor& A, const Tensor& v, Tensor* out,
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 0a0b685..4b16af0 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -815,8 +815,8 @@ void Dot<float, lang::Cuda>(const Tensor& in1, const Tensor& in2, float* out,
   CUBLAS_CHECK(cublasSdot(handle, num, inPtr1, 1, inPtr2, 1, out));
 }
 template <>
-void Dot<float, lang::Cuda>(const Tensor& in1,
-                            const Tensor& in2, Tensor* out, Context* ctx) {
+void Dot<float, lang::Cuda>(const Tensor& in1, const Tensor& in2, Tensor* out,
+                            Context* ctx) {
   const float* inPtr1 = static_cast<const float*>(in1.block()->data());
   const float* inPtr2 = static_cast<const float*>(in2.block()->data());
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
@@ -828,8 +828,7 @@ void Dot<float, lang::Cuda>(const Tensor& in1,
 }
 
 template <>
-void Nrm2<float, lang::Cuda>(const Tensor& in, float* out,
-                             Context* ctx) {
+void Nrm2<float, lang::Cuda>(const Tensor& in, float* out, Context* ctx) {
   auto handle = ctx->cublas_handle;  // TODO(wangwei) set cudastream
   const float* inPtr = static_cast<const float*>(in.block()->data());
   const size_t num = in.Size();
@@ -937,6 +936,41 @@ void SoftMax<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
 }
 
 template <>
+void SoftMaxBackward<float, lang::Cuda>(const Tensor& in, Tensor* out,
+                                        const Tensor& fdout, Context* ctx) {
+  cudnnSoftmaxAlgorithm_t algorithm = CUDNN_SOFTMAX_FAST;
+  cudnnSoftmaxMode_t mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+
+  /*
+   * tensor tmp is for generating cudnn descriptor
+   *   as for cudnn softmax, it required shape of {N, C, 1, 1}
+   *   while helper func `generate_shape_cuda` generate shape of {1, 1, N, C}
+   *   Thus this part serve similar purpose as `generate_shape_cuda` but in
+   * reverse manner
+  */
+  CHECK_LE(in.shape().size(), 5)
+      << "Dimensions (shape) beyond 5 are currently not supported";
+  auto tmp = in;
+  while (tmp.shape().size() < 4) {
+    auto s = tmp.shape();
+    s.push_back(1);
+    tmp.Reshape(s);
+  }
+
+  const float* inPtr = static_cast<const float*>(in.block()->data());
+  const float* fdoutPtr = static_cast<const float*>(fdout.block()->data());
+  float* outPtr = static_cast<float*>(out->block()->mutable_data());
+
+  float alpha = 1.0;
+  float beta = 0.0;
+
+  check_cudnn(cudnnSoftmaxBackward(
+      ctx->cudnn_handle, algorithm, mode, (void*)(&alpha),
+      generate_tensor_nd_desc(tmp), fdoutPtr, generate_tensor_nd_desc(tmp),
+      inPtr, (void*)(&beta), generate_tensor_nd_desc(tmp), outPtr));
+}
+
+template <>
 void ComputeCrossEntropy<float, lang::Cuda>(bool int_target,
                                             const size_t batchsize,
                                             const size_t dim, const Block* p,
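
The test_api.py changes below only touch the batch-norm test; the new SoftMaxBackward binding itself is not exercised here. A sketch of the kind of check that could be added, written in the style of this test file (it assumes the file's existing singa_api and tensor imports, its _cTensor_to_pyTensor helper, and a device dev; none of this is part of the patch):

    def _softmax_backward_check(self, dev, axis=1):
        x = np.random.randn(3, 5).astype(np.float32)
        dy = np.random.randn(3, 5).astype(np.float32)

        x_c = tensor.Tensor(device=dev, data=x).data
        dy_c = tensor.Tensor(device=dev, data=dy).data

        y_c = singa_api.SoftMax(x_c, axis)                 # forward output, i.e. fdout
        dx_c = singa_api.SoftMaxBackward(dy_c, axis, y_c)  # binding added in core_tensor.i

        # NumPy reference: dx = y * (dy - sum(dy * y) along the softmax axis)
        y = tensor.to_numpy(_cTensor_to_pyTensor(y_c))
        dx_ref = y * (dy - np.sum(dy * y, axis=axis, keepdims=True))

        np.testing.assert_array_almost_equal(
            tensor.to_numpy(_cTensor_to_pyTensor(dx_c)), dx_ref, decimal=5)
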
diff --git a/test/python/test_api.py b/test/python/test_api.py
index 197f884..518c4f9 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -340,27 +340,22 @@ class TestAPI(unittest.TestCase):
             hndl = singa_api.BatchNormHandle(
                 m_0,
                 tensor.Tensor(device=dev, data=x_0).data)
-            (y_2_c, rm_2_c, rv_2_c, bm_2_c,
-             bv_2_c) = singa_api.CpuBatchNormForwardTraining(
-                 hndl,
-                 tensor.Tensor(device=dev, data=x_0).data,
-                 tensor.Tensor(device=dev, data=s_0).data,
-                 tensor.Tensor(device=dev, data=b_0).data,
-                 tensor.Tensor(device=dev, data=rm_0).data,
-                 tensor.Tensor(device=dev, data=rv_0).data)
+            (y_2_c, bm_2_c, bv_2_c) = singa_api.CpuBatchNormForwardTraining(
+                hndl,
+                tensor.Tensor(device=dev, data=x_0).data,
+                tensor.Tensor(device=dev, data=s_0).data,
+                tensor.Tensor(device=dev, data=b_0).data,
+                tensor.Tensor(device=dev, data=rm_0).data,
+                tensor.Tensor(device=dev, data=rv_0).data)
 
             np.testing.assert_array_almost_equal(
                 y_1, tensor.to_numpy(_cTensor_to_pyTensor(y_2_c)))
-            #np.testing.assert_array_almost_equal(
-            #    bm_1, tensor.to_numpy(_cTensor_to_pyTensor(bm_2_c)))
             np.testing.assert_array_almost_equal(
-                rm_1, tensor.to_numpy(_cTensor_to_pyTensor(rm_2_c)))
+                bm_1, tensor.to_numpy(_cTensor_to_pyTensor(bm_2_c)))
             #print(bv_1)
             #print(tensor.to_numpy(_cTensor_to_pyTensor(bv_2_c)))
             #np.testing.assert_array_almost_equal(
             #    bv_1, tensor.to_numpy(_cTensor_to_pyTensor(bv_2_c)), decimal=3)
-            np.testing.assert_array_almost_equal(
-                rv_1, tensor.to_numpy(_cTensor_to_pyTensor(rv_2_c)), decimal=4)
             return
 
         x_0 = np.array([1, 1, 1, 1, 2, 2, 2, 2, 10, 10, 10, 10, 20, 20, 20, 20],
