This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new f0953f5  added astype api, support float2int, int2float, added tests
     new d0acfa3  Merge pull request #625 from dcslin/astype
f0953f5 is described below

commit f0953f54bb641d4c063a08b8ad5e1db4a1ad731b
Author: dcslin <[email protected]>
AuthorDate: Tue Mar 10 10:34:47 2020 +0000

    added astype api, support float2int, int2float, added tests
---
 python/singa/tensor.py             | 15 ++++++--
 src/core/tensor/math_kernel.cu     | 22 +++++++++++
 src/core/tensor/math_kernel.h      |  4 ++
 src/core/tensor/tensor.cc          | 66 +++++++++++++++++++++++++++++++-
 src/core/tensor/tensor_math.h      |  5 +++
 src/core/tensor/tensor_math_cpp.h  | 16 ++++++++
 src/core/tensor/tensor_math_cuda.h | 17 +++++++++
 test/python/test_api.py            | 27 +++++++++++++
 test/python/test_tensor.py         | 15 ++++++++
 test/singa/test_tensor.cc          | 78 ++++++++++++++++++++++++++++++++++++--
 10 files changed, 256 insertions(+), 9 deletions(-)

diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index b64c8b5..64d88c2 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -233,14 +233,23 @@ class Tensor(object):
         self.device = t.device
         self.dtype = t.dtype
 
-    '''
     def as_type(self, dtype):
-        Change the data type.
+        '''Change the data type.
 
         Args:
             dtype:
+        '''
+        if dtype == singa.kInt:
+            pass
+        elif dtype == singa.kFloat32:
+            pass
+        elif dtype == 'int':
+            dtype = singa.kInt
+        elif dtype == 'float':
+            dtype = singa.kFloat32
+        else:
+            raise TypeError("invalid data type %s" % dtype)
         self.data.AsType(dtype)
-    '''
 
     def to_device(self, device):
         '''Move the tensor data onto a given device.
diff --git a/src/core/tensor/math_kernel.cu b/src/core/tensor/math_kernel.cu
index 7d9df0c..2ce87ea 100644
--- a/src/core/tensor/math_kernel.cu
+++ b/src/core/tensor/math_kernel.cu
@@ -184,6 +184,20 @@ __global__ void KernelAbs(const size_t n, const float *in, float *out) {
   }
 }
 
+__global__ void KernelCastFloat2Int(const size_t n, const float *in, int *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = int(in[i]);
+  }
+}
+
+__global__ void KernelCastInt2Float(const size_t n, const int *in, float *out) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
+       i += blockDim.x * gridDim.x) {
+    out[i] = float(in[i]);
+  }
+}
+
 __global__ void KernelSoftplus(const size_t n, const float *in, float *out) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
        i += blockDim.x * gridDim.x) {
@@ -509,6 +523,14 @@ void abs(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelAbs <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
 
+void cast_float_2_int(const size_t n, const float *src, int *dst, cudaStream_t s) {
+  KernelCastFloat2Int <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, src, dst);
+}
+
+void cast_int_2_float(const size_t n, const int *src, float *dst, cudaStream_t s) {
+  KernelCastInt2Float <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, src, dst);
+}
+
 void sign(const size_t n, const float *in, float *out, cudaStream_t s) {
   KernelSign <<<ceil(n / CU1DBLOCKF), CU1DBLOCKF, 0, s>>> (n, in, out);
 }
diff --git a/src/core/tensor/math_kernel.h b/src/core/tensor/math_kernel.h
index 0b9f2fa..12398fb 100644
--- a/src/core/tensor/math_kernel.h
+++ b/src/core/tensor/math_kernel.h
@@ -45,6 +45,10 @@ void abs(const size_t n, const float *in, float *out, cudaStream_t s);
 void sign(const size_t n, const float *in, float *out, cudaStream_t s);
 void exp(const size_t n, const float *in, float *out, cudaStream_t s);
 void ceil2(const size_t n, const float *in, float *out, cudaStream_t s);
+void cast_float_2_int(const size_t n, const float *src, int *dst,
+                      cudaStream_t s);
+void cast_int_2_float(const size_t n, const int *src, float *dst,
+                      cudaStream_t s);
 void log(const size_t n, const float *in, float *out, cudaStream_t s);
 void sqrt(const size_t n, const float *in, float *out, cudaStream_t s);
 void square(const size_t n, const float *in, float *out, cudaStream_t s);
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 7665009..25e354e 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -105,11 +105,73 @@ Tensor Resize(const Tensor &in, const Shape &shape) {
   return out;
 }
 
+#define TYPE_TYPE_LANG_SWITCH(ldtype, LDType, rdtype, RDType, ltype, Lang,     \
+                              ...)                                             \
+  do {                                                                         \
+    const int _SwitchShift = 3;                                                \
+    int _SwitchHash =                                                          \
+        ((ldtype) << _SwitchShift * 2) + ((rdtype) << _SwitchShift) + (ltype); \
+    switch (_SwitchHash) {                                                     \
+      case (((kFloat32) << _SwitchShift * 2) + (kInt << _SwitchShift) +        \
+            kCuda): {                                                          \
+        typedef float LDType;                                                  \
+        typedef int RDType;                                                    \
+        typedef lang::Cuda Lang;                                               \
+        { __VA_ARGS__ }                                                        \
+        break;                                                                 \
+      }                                                                        \
+      case (((kInt) << _SwitchShift * 2) + (kFloat32 << _SwitchShift) +        \
+            kCuda): {                                                          \
+        typedef int LDType;                                                    \
+        typedef float RDType;                                                  \
+        typedef lang::Cuda Lang;                                               \
+        { __VA_ARGS__ }                                                        \
+        break;                                                                 \
+      }                                                                        \
+      case (((kFloat32) << _SwitchShift * 2) + (kInt << _SwitchShift) +        \
+            kCpp): {                                                           \
+        typedef float LDType;                                                  \
+        typedef int RDType;                                                    \
+        typedef lang::Cpp Lang;                                                \
+        { __VA_ARGS__ }                                                        \
+        break;                                                                 \
+      }                                                                        \
+      case (((kInt) << _SwitchShift * 2) + (kFloat32 << _SwitchShift) +        \
+            kCpp): {                                                           \
+        typedef int LDType;                                                    \
+        typedef float RDType;                                                  \
+        typedef lang::Cpp Lang;                                                \
+        { __VA_ARGS__ }                                                        \
+        break;                                                                 \
+      }                                                                        \
+      default:                                                                 \
+        LOG(FATAL) << "Unknown combination of left data type "                 \
+                   << DataType_Name(ldtype) << " and right data type "         \
+                   << DataType_Name(rdtype) << " and language "                \
+                   << LangType_Name(ltype);                                    \
+    }                                                                          \
+  } while (0)
+
 Tensor &Tensor::AsType(const DataType type) {
   if (data_type_ != type) {
-    if (block_ != nullptr && block_->DecRefCount() == 0)
+    if (block_ != nullptr && block_->DecRefCount() == 0) {
+      auto offset = Product(shape_);
+      auto new_block_ =
+          device_->NewBlock((int)(Product(shape_) * SizeOf(type)));
+      TYPE_TYPE_LANG_SWITCH(
+          data_type_, LDType, type, RDType, device_->lang(), Lang, {
+            device_->Exec(
+                [this, new_block_, offset, type](Context *ctx) {
+                  CastAsType<LDType, RDType, Lang>(this, new_block_, offset,
+                                                   ctx);
+                },
+                {}, {});
+          });
       device_->FreeBlock(block_);
-    block_ = device_->NewBlock((int)(Product(shape_) * SizeOf(type)));
+      block_ = new_block_;
+    } else {
+      block_ = device_->NewBlock((int)(Product(shape_) * SizeOf(type)));
+    }
     data_type_ = type;
   }
   return *this;
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 35ef665..bae7235 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -86,6 +86,11 @@ void Abs(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Abs Not Implemented";
 }
 
+template <typename DTypeSrc, typename DTypeDst, typename Lang>
+void CastAsType(const Tensor *src, Block *dst, int offset, Context *ctx) {
+  LOG(FATAL) << "CastAsType Not Implemented";
+}
+
 template <typename DType, typename Lang>
 void Ceil(const Tensor &in, Tensor *out, Context *ctx) {
   LOG(FATAL) << "Ceil Not Implemented";
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 7ce7d14..d54034b 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -241,6 +241,22 @@ void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
 }
 
 template <>
+void CastAsType<float, int, lang::Cpp>(const Tensor *src, Block *dst,
+                                       int offset, Context *ctx) {
+  int *dst_array = static_cast<int *>(dst->mutable_data());
+  const float *src_array = static_cast<const float *>(src->block()->data());
+  for (int i = 0; i < offset; ++i) dst_array[i] = (int)src_array[i];
+}
+
+template <>
+void CastAsType<int, float, lang::Cpp>(const Tensor *src, Block *dst,
+                                       int offset, Context *ctx) {
+  float *dst_array = static_cast<float *>(dst->mutable_data());
+  const int *src_array = static_cast<const int *>(src->block()->data());
+  for (int i = 0; i < offset; ++i) dst_array[i] = (float)src_array[i];
+}
+
+template <>
 void Ceil<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
   traverse_unary<float>(in, out, [](float x) { return std::ceil(x); });
 }
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
index 3c043ab..4fb481b 100644
--- a/src/core/tensor/tensor_math_cuda.h
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -170,6 +170,23 @@ void Abs<float, lang::Cuda>(const Tensor& in, Tensor* out, Context* ctx) {
 }
 
 template <>
+void CastAsType<float, int, lang::Cuda>(const Tensor* src, Block* dst,
+                                        int offset, Context* ctx) {
+  const float* srcPtr = static_cast<const float*>(src->block()->data());
+  int* dstPtr = static_cast<int*>(dst->mutable_data());
+  const size_t num = src->Size();
+  cuda::cast_float_2_int(num, srcPtr, dstPtr, ctx->stream);
+}
+
+template <>
+void CastAsType<int, float, lang::Cuda>(const Tensor* src, Block* dst,
+                                        int offset, Context* ctx) {
+  const int* srcPtr = static_cast<const int*>(src->block()->data());
+  float* dstPtr = static_cast<float*>(dst->mutable_data());
+  cuda::cast_int_2_float(offset, srcPtr, dstPtr, ctx->stream);
+}
+
+template <>
 void Set<float, lang::Cuda>(const float x, Tensor* out, Context* ctx) {
   float* outPtr = static_cast<float*>(out->block()->mutable_data());
 
diff --git a/test/python/test_api.py b/test/python/test_api.py
index d62b58f..00c9c50 100644
--- a/test/python/test_api.py
+++ b/test/python/test_api.py
@@ -605,6 +605,33 @@ class TestAPI(unittest.TestCase):
             np.testing.assert_array_almost_equal(
                 tensor.to_numpy(_cTensor_to_pyTensor(t2_ct)), np2)
 
+    def test_as_type(self):
+        np1 = np.random.random([3]).astype(np.float32)
+        np1 = np1 * 10 - 5
+        np2 = np1.astype(np.int32)
+        np3 = np2.astype(np.float32)
+
+        for dev in [cpu_dev, gpu_dev]:
+            t1 = tensor.Tensor(device=dev, data=np1)
+
+            t1_ct = t1.data
+
+            self.assertEqual(t1_ct.data_type(), singa_api.kFloat32)
+
+            t1_ct.AsType(singa_api.kInt)
+
+            self.assertEqual(t1_ct.data_type(), singa_api.kInt)
+
+            np.testing.assert_array_almost_equal(
+                tensor.to_numpy(_cTensor_to_pyTensor(t1_ct)), np2)
+
+            t1_ct.AsType(singa_api.kFloat32)
+
+            self.assertEqual(t1_ct.data_type(), singa_api.kFloat32)
+
+            np.testing.assert_array_almost_equal(
+                tensor.to_numpy(_cTensor_to_pyTensor(t1_ct)), np3)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 4bf592c..1588e19 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -338,6 +338,21 @@ class TestTensorMethods(unittest.TestCase):
 
             np.testing.assert_array_almost_equal(tensor.to_numpy(t2), np2)
 
+    def test_astype(self):
+        for dev in [cpu_dev, gpu_dev]:
+
+            np1 = np.random.random([5, 6, 7, 8]).astype(np.float32)
+            np1 = np1 * 10 - 5
+
+            np2 = np1.astype(np.int32)
+            np3 = np2.astype(np.float32)
+
+            t1 = tensor.Tensor(device=dev, data=np1)
+            t1.as_type('int')
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t1), np2)
+            t1.as_type('float')
+            np.testing.assert_array_almost_equal(tensor.to_numpy(t1), np3)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc
index b0d9fd0..70ec868 100644
--- a/test/singa/test_tensor.cc
+++ b/test/singa/test_tensor.cc
@@ -68,11 +68,81 @@ TEST(TensorClass, Reshape) {
   EXPECT_TRUE(o.shape() != t.shape());
 }
 
-TEST(TensorClass, AsType) {
-  Tensor t;
+#ifdef USE_CUDA
+
+TEST(TensorClass, FloatAsTypeIntCuda) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+
+  Tensor t(Shape{3}, cuda);
+  float data[] = {1.0f, 2.0f, 3.0f};
+  t.CopyDataFromHostPtr(data, 3);
+  EXPECT_EQ(singa::kFloat32, t.data_type());
+
+  t.AsType(singa::kInt);
+
+  EXPECT_EQ(singa::kInt, t.data_type());
+
+  t.ToHost();
+  const int* dptr2 = static_cast<const int*>(t.block()->data());
+  EXPECT_EQ(1, dptr2[0]);
+  EXPECT_EQ(2, dptr2[1]);
+  EXPECT_EQ(3, dptr2[2]);
+}
+
+TEST(TensorClass, IntAsTypeFloatCuda) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+
+  Tensor t(Shape{3}, cuda, singa::kInt);
+  int data[] = {1, 2, 3};
+  t.CopyDataFromHostPtr(data, 3);
+  EXPECT_EQ(singa::kInt, t.data_type());
+
+  t.AsType(singa::kFloat32);
+
+  EXPECT_EQ(singa::kFloat32, t.data_type());
+
+  t.ToHost();
+  const float* dptr2 = static_cast<const float*>(t.block()->data());
+  EXPECT_EQ(1.0f, dptr2[0]);
+  EXPECT_EQ(2.0f, dptr2[1]);
+  EXPECT_EQ(3.0f, dptr2[2]);
+}
+
+#endif  // USE_CUDA
+
+TEST(TensorClass, FloatAsTypeIntCPU) {
+  Tensor t(Shape{3});
+  float data[] = {1.0f, 2.0f, 3.0f};
+  t.CopyDataFromHostPtr(data, 3);
+  EXPECT_EQ(singa::kFloat32, t.data_type());
+  const float* dptr = static_cast<const float*>(t.block()->data());
+  EXPECT_FLOAT_EQ(1.0f, dptr[0]);
+  EXPECT_FLOAT_EQ(2.0f, dptr[1]);
+  EXPECT_FLOAT_EQ(3.0f, dptr[2]);
+
+  t.AsType(singa::kInt);
+
+  EXPECT_EQ(singa::kInt, t.data_type());
+  const int* dptr2 = static_cast<const int*>(t.block()->data());
+  EXPECT_EQ(1, dptr2[0]);
+  EXPECT_EQ(2, dptr2[1]);
+  EXPECT_EQ(3, dptr2[2]);
+}
+
+TEST(TensorClass, IntAsTypeFloatCPU) {
+  Tensor t(Shape{3}, singa::kInt);
+  int data[] = {1, 2, 3};
+  t.CopyDataFromHostPtr(data, 3);
+  EXPECT_EQ(singa::kInt, t.data_type());
+
+  t.AsType(singa::kFloat32);
+
   EXPECT_EQ(singa::kFloat32, t.data_type());
-  t.AsType(singa::kFloat16);
-  EXPECT_EQ(singa::kFloat16, t.data_type());
+
+  const float* dptr2 = static_cast<const float*>(t.block()->data());
+  EXPECT_EQ(1.0f, dptr2[0]);
+  EXPECT_EQ(2.0f, dptr2[1]);
+  EXPECT_EQ(3.0f, dptr2[2]);
 }
 
 TEST(TensorClass, ToDevice) {

Reply via email to