This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new acbf930 added erf backend
new a171023 Merge pull request #777 from dcslin/erf
acbf930 is described below
commit acbf93063c5674338813b7217bb50a31c99ef7c3
Author: root <[email protected]>
AuthorDate: Tue Aug 4 05:09:00 2020 +0000
added erf backend
---
include/singa/core/tensor.h | 2 ++
src/api/core_tensor.i | 1 +
src/core/tensor/math_kernel.cu | 11 +++++++++++
src/core/tensor/math_kernel.h | 1 +
src/core/tensor/tensor.cc | 1 +
src/core/tensor/tensor_math.h | 5 +++++
src/core/tensor/tensor_math_cpp.h | 5 +++++
src/core/tensor/tensor_math_cuda.h | 14 ++++++++++++++
test/python/test_api.py | 13 +++++++++++++
9 files changed, 53 insertions(+)
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index cf74ef7..47a73b9 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -348,6 +348,7 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
// =============Element-wise operations====================================
Tensor Abs(const Tensor &in);
+Tensor Erf(const Tensor &in);
Tensor Ceil(const Tensor &in);
Tensor Floor(const Tensor &in);
Tensor Round(const Tensor &in);
@@ -376,6 +377,7 @@ Tensor Atanh(const Tensor &in);
Tensor Transform(const Tensor &in);
void Abs(const Tensor &in, Tensor *out);
+void Erf(const Tensor &in, Tensor *out);
void Ceil(const Tensor &in, Tensor *out);
void Floor(const Tensor &in, Tensor *out);
void Round(const Tensor &in, Tensor *out);
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index fc64b2d..6146608 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -177,6 +177,7 @@ namespace singa{
Tensor Round(const Tensor &t);
Tensor RoundE(const Tensor &t);
Tensor Exp(const Tensor &t);
+ Tensor Erf(const Tensor &t);
Tensor Log(const Tensor &t);
Tensor ReLU(const Tensor &t);
Tensor Sigmoid(const Tensor &t);
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index ad0b63b..3777a06 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -117,6 +117,13 @@ __global__ void KernelExp(const size_t n, const float *in, float *out) {
}
}
+__global__ void KernelErf(const size_t n, const float *in, float *out) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+ i += blockDim.x * gridDim.x) {
+ out[i] = erff(in[i]);
+ }
+}
+
__global__ void KernelCeil2(const size_t n, const float *in, float *out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
@@ -582,6 +589,10 @@ void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
+void erf(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelErf <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
void ceil2(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelCeil2 <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 91bf526..206fa1a 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -44,6 +44,7 @@ void set(const size_t n, const float v, float *out, cudaStream_t s);
void abs(const size_t n, const float *in, float *out, cudaStream_t s);
void sign(const size_t n, const float *in, float *out, cudaStream_t s);
void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void erf(const size_t n, const float *in, float *out, cudaStream_t s);
void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
void floor(const size_t n, const float *in, float *out, cudaStream_t s);
void round(const size_t n, const float *in, float *out, cudaStream_t s);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 475aab5..faec534 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -807,6 +807,7 @@ template void Tensor::GetValue<int>(int *value, const size_t num) const;
void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }
GenUnaryTensorFn(Abs);
+GenUnaryTensorFn(Erf);
GenUnaryTensorFn(Ceil);
GenUnaryTensorFn(Floor);
GenUnaryTensorFn(Round);
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 5484da9..3236e7c 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
LOG(FATAL) << "Abs Not Implemented";
}
+template <typename DType, typename Lang>
+void Erf(const Tensor &in, Tensor *out, Context *ctx) {
+ LOG(FATAL) << "Erf Not Implemented";
+}
+
template <typename DTypeSrc, typename DTypeDst, typename Lang>
void CastCopy(const Tensor *src, Tensor *dst, Context *ctx) {
LOG(FATAL) << "CastCopy Not Implemented";
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 452b98d..5be46c6 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -241,6 +241,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
}
template <>
+void Erf<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+ traverse_unary<float>(in, out, [](float x) { return erff(x); });
+}
+
+template <>
void CastCopy<float, int, lang::Cpp>(const Tensor *src, Tensor *dst,
Context *ctx) {
int *dst_array = static_cast<int *>(dst->block()->mutable_data());
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index b16582d..f3a3173 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -363,6 +363,20 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
}
template <>
+void Erf<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+ const float* inPtr = static_cast<const float*>(in.block()->data());
+ float* outPtr = static_cast<float*>(out->block()->mutable_data());
+ const size_t num = in.Size();
+
+ if (in.stride() == out->stride()) {
+ cuda::erf(num, inPtr, outPtr, ctx->stream);
+ } else { // else we transform in to out to store first
+ Transform<float, lang::Cuda>(in, out, ctx);
+ cuda::erf(num, outPtr, outPtr, ctx->stream);
+ }
+}
+
+template <>
void Ceil<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
const float* inPtr = static_cast<const float*>(in.block()->data());
float* outPtr = static_cast<float*>(out->block()->mutable_data());
diff --git a/test/python/test_api.py b/test/python/test_api.py
index 11f52e3..e307dc9 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -340,6 +340,19 @@ class TestAPI(unittest.TestCase):
def test_transpose_and_arithmetic_op_broadcast_cpu(self):
self._transpose_and_arithmetic_op_broadcast_helper(cpu_dev)
+ def _erf(self, dev=cpu_dev):
+ np1 = np.random.random((2, 3)).astype(np.float32)
+
+ x1 = tensor.from_numpy(np1)
+ x1.to_device(dev)
+ y1 = tensor.from_raw_tensor(singa_api.Erf(x1.data))
+
+ # from scipy.special import erf
+ # np.testing.assert_array_almost_equal(erf(np1), tensor.to_numpy(y1))
+
+ def test_erf_cpu(self):
+ self._erf(cpu_dev)
+
@unittest.skipIf(not singa_api.USE_CUDA, 'CUDA is not enabled')
def test_transpose_and_arithmetic_op_broadcast_gpu(self):
self._transpose_and_arithmetic_op_broadcast_helper(gpu_dev)