This is an automated email from the ASF dual-hosted git repository.
wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/dev by this push:
new 5b784f8 added ceil to tensor, and tests
new 001ba4c Merge pull request #619 from dcslin/ceil
5b784f8 is described below
commit 5b784f8d90d236420c2d705853fe80b8761f9537
Author: dcslin <[email protected]>
AuthorDate: Tue Mar 3 05:05:19 2020 +0000
added ceil to tensor, and tests
---
include/singa/core/tensor.h | 2 ++
python/singa/tensor.py | 11 +++++++++++
src/api/core_tensor.i | 1 +
src/core/tensor/math_kernel.cu | 11 +++++++++++
src/core/tensor/math_kernel.h | 1 +
src/core/tensor/tensor.cc | 1 +
src/core/tensor/tensor_math.h | 5 +++++
src/core/tensor/tensor_math_cpp.h | 5 +++++
src/core/tensor/tensor_math_cuda.h | 14 ++++++++++++++
test/python/test_api.py | 15 +++++++++++++++
test/python/test_tensor.py | 14 ++++++++++++++
11 files changed, 80 insertions(+)
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index 02817b4..6d4b86b 100644
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -342,6 +342,7 @@ void RepeatDataToFrom(bool broadcast_flag, const vector<size_t> &repeats,
// =============Element-wise operations====================================
Tensor Abs(const Tensor &in);
+Tensor Ceil(const Tensor &in);
Tensor Exp(const Tensor &in);
Tensor Log(const Tensor &in);
Tensor ReLU(const Tensor &in);
@@ -366,6 +367,7 @@ Tensor Atanh(const Tensor &in);
Tensor Transform(const Tensor &in);
void Abs(const Tensor &in, Tensor *out);
+void Ceil(const Tensor &in, Tensor *out);
void Exp(const Tensor &in, Tensor *out);
void Log(const Tensor &in, Tensor *out);
void ReLU(const Tensor &in, Tensor *out);
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index f835b44..b64c8b5 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -805,6 +805,17 @@ def exp(t):
return _call_singa_func(singa.Exp, t.data)
+def ceil(t):
+ '''
+ Args:
+ t (Tensor): input Tensor
+
+ Returns:
+ a new Tensor whose element y = ceil(x), x is an element of t
+ '''
+ return _call_singa_func(singa.Ceil, t.data)
+
+
def log(t):
'''
Args:
diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i
index 4550e6a..28e8ac5 100755
--- a/src/api/core_tensor.i
+++ b/src/api/core_tensor.i
@@ -171,6 +171,7 @@ namespace singa{
Tensor Transpose(const Tensor &in);
Tensor Abs(const Tensor &t);
+ Tensor Ceil(const Tensor &t);
Tensor Exp(const Tensor &t);
Tensor Log(const Tensor &t);
Tensor ReLU(const Tensor &t);
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 5ac23cc..7d9df0c 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -117,6 +117,13 @@ __global__ void KernelExp(const size_t n, const float *in, float *out) {
}
}
+__global__ void KernelCeil2(const size_t n, const float *in, float *out) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+ i += blockDim.x * gridDim.x) {
+ out[i] = std::ceil(in[i]);
+ }
+}
+
__global__ void KernelLog(const size_t n, const float *in, float *out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
@@ -510,6 +517,10 @@ void exp(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelExp <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
+void ceil2(const size_t n, const float *in, float *out, cudaStream_t s) {
+ KernelCeil2 <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
+}
+
void log(const size_t n, const float *in, float *out, cudaStream_t s) {
KernelLog <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
}
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index af5f938..0b9f2fa 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -44,6 +44,7 @@ void set(const size_t n, const float v, float *out, cudaStream_t s);
void abs(const size_t n, const float *in, float *out, cudaStream_t s);
void sign(const size_t n, const float *in, float *out, cudaStream_t s);
void exp(const size_t n, const float *in, float *out, cudaStream_t s);
+void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
void log(const size_t n, const float *in, float *out, cudaStream_t s);
void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
void square(const size_t n, const float *in, float *out, cudaStream_t s);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 715be80..7665009 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -719,6 +719,7 @@ template void Tensor::GetValue<int>(int *value, const size_t num);
void fn(const Tensor &in, Tensor *out) { EltwiseUnaryTensorFn(fn, in, out); }
GenUnaryTensorFn(Abs);
+GenUnaryTensorFn(Ceil);
GenUnaryTensorFn(Exp);
GenUnaryTensorFn(Log);
GenUnaryTensorFn(ReLU);
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index d9440ad..35ef665 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
LOG(FATAL) << "Abs Not Implemented";
}
+template <typename DType, typename Lang>
+void Ceil(const Tensor &in, Tensor *out, Context *ctx) {
+ LOG(FATAL) << "Ceil Not Implemented";
+}
+
/// out[i] = in[i] + x
template <typename DType, typename Lang>
void Add(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 9550d24..7ce7d14 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -240,6 +240,11 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
traverse_unary<float>(in, out, [](float x) { return fabs(x); });
}
+template <>
+void Ceil<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
+ traverse_unary<float>(in, out, [](float x) { return std::ceil(x); });
+}
+
#ifdef USE_DNNL
template <>
void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index ea7e9a5..3c043ab 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -347,6 +347,20 @@ void Exp<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
}
template <>
+void Ceil<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
+ const float* inPtr = static_cast<const float*>(in.block()->data());
+ float* outPtr = static_cast<float*>(out->block()->mutable_data());
+ const size_t num = in.Size();
+
+ if (in.stride() == out->stride()) {
+ cuda::ceil2(num, inPtr, outPtr, ctx->stream);
+ } else { // else we transform in to out to store first
+ Transform<float, lang::Cuda>(in, out, ctx);
+ cuda::ceil2(num, outPtr, outPtr, ctx->stream);
+ }
+}
+
+template <>
void GE<float, lang::Cuda>(const Tensor& in, const float x, Tensor* out,
Context* ctx) {
float* outPtr = static_cast<float*>(out->block()->mutable_data());
diff --git a/test/python/test_api.py b/test/python/test_api.py
index 3b847aa..d62b58f 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -590,6 +590,21 @@ class TestAPI(unittest.TestCase):
np.testing.assert_array_almost_equal(
tensor.to_numpy(_cTensor_to_pyTensor(t3_ct)), np3)
+ def test_ceil(self):
+
+ for dev in [cpu_dev, gpu_dev]:
+
+ np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+ np1 = np1 * 10
+ np2 = np.ceil(np1)
+
+ t1 = tensor.Tensor(device=dev, data=np1)
+
+ t2_ct = singa_api.Ceil(t1.data)
+
+ np.testing.assert_array_almost_equal(
+ tensor.to_numpy(_cTensor_to_pyTensor(t2_ct)), np2)
+
if __name__ == '__main__':
unittest.main()
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 872d650..4bf592c 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -324,6 +324,20 @@ class TestTensorMethods(unittest.TestCase):
np.testing.assert_array_almost_equal((tensor.to_numpy(sg_tensor_ret)),
np1[1:3, :, 1:, :-1])
+ def test_ceil(self):
+
+ for dev in [cpu_dev, gpu_dev]:
+
+ np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+ np1 = np1 * 10
+ np2 = np.ceil(np1)
+
+ t1 = tensor.Tensor(device=dev, data=np1)
+
+ t2 = tensor.ceil(t1)
+
+ np.testing.assert_array_almost_equal(tensor.to_numpy(t2), np2)
+
if __name__ == '__main__':
unittest.main()