SINGA-186 Create Python Tensor class

- Add SWIG interfaces: singa.i
  . core_tensor.i for singa::Tensor
  . core_device.i for singa::Device
- Add generate_singa_wrapper.sh to generate the wrapper
- Add Python scripts: tensor.py, device.py
- Add an example unit test, unittest_python.py, in test/python

Notes:
- call the copy constructor of singa::Tensor in __init__
- rename transpose -> is_transpose(), which calls singa::Tensor::transpose();
  rename matrix_transpose -> transpose(), which calls singa::Tensor::T()
- Revised the way to generate Tensor copies in python (see the sketch below)
  . add copy() for shallow copy, which calls singa::Tensor(Tensor&)
  . add deepcopy() for deep copy, which calls singa::Tensor::Clone()
  Note: __init__() was revised;
        clone() and deepcopy() are now the same
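
A minimal sketch of the copy and transpose semantics above, using the
tensor.py API added in this commit (behavior matches test_tensor_copy in
unittest_python.py):

    from tensor import Tensor

    t = Tensor((2, 3))          # wraps a new singa::Tensor
    t += 1.23
    tc = t.copy()               # shallow copy via singa::Tensor(Tensor&)
    tdc = t.deepcopy()          # deep copy via singa::Tensor::Clone()
    t += 1.23
    # tc.toarray()[0, 0] is now 2.46 (shares memory with t);
    # tdc.toarray()[0, 0] is still 1.23
    tt = t.transpose()          # shallow copy via singa::Tensor::T()
    assert tt.is_transpose() and not t.is_transpose()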

TODO:
- need to think about a more efficient way to convert Tensor data to a
  numpy array (see the note below)
- example test is in test/python/example_test_device.py, which will be
  included in unittest_python.py
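
One incremental option for the conversion TODO above (a sketch only, not
part of this commit): np.fromiter avoids building the intermediate Python
list inside Tensor.toarray(); a truly zero-copy route would likely need a
SWIG typemap such as numpy.i.

    # inside Tensor.toarray(), once data_array and dt are set (sketch)
    n = self.singa_tensor.Size()
    data = np.fromiter((data_array[i] for i in range(n)), dtype=dt, count=n)
    return data.reshape(self.tuple_shape)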


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2a582df1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2a582df1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2a582df1

Branch: refs/heads/dev
Commit: 2a582df121d8b3fa45dad1e9c0435bc01bf79bea
Parents: 8900543
Author: chonho <[email protected]>
Authored: Wed Jun 8 01:03:05 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Fri Jun 17 22:27:01 2016 +0800

----------------------------------------------------------------------
 src/core/device/cpp_cpu.cc           |   2 +-
 src/python/core_device.i             |  60 +++++
 src/python/core_tensor.i             | 264 +++++++++++++++++++++
 src/python/device.py                 |  79 +++++++
 src/python/generate_singa_wrapper.sh |  43 ++++
 src/python/singa.i                   |  27 +++
 src/python/tensor.py                 | 365 ++++++++++++++++++++++++++++++
 test/python/example_test_device.py   |  36 +++
 test/python/unittest_python.py       | 139 ++++++++++++
 9 files changed, 1014 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/core/device/cpp_cpu.cc
----------------------------------------------------------------------
diff --git a/src/core/device/cpp_cpu.cc b/src/core/device/cpp_cpu.cc
index 44f614a..66f5d54 100644
--- a/src/core/device/cpp_cpu.cc
+++ b/src/core/device/cpp_cpu.cc
@@ -21,7 +21,7 @@ CppCPU defaultDevice(-1, 1);
 CppCPU::CppCPU(int id, int num_executors, string scheduler,
          string vm) : Device(id, num_executors, scheduler, vm) {
   lang_ = kCpp;
-  host_ = nullptr;
+  //host_ = nullptr;
 }
 
 void CppCPU::SetRandSeed(unsigned seed) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/core_device.i
----------------------------------------------------------------------
diff --git a/src/python/core_device.i b/src/python/core_device.i
new file mode 100644
index 0000000..ab9abd8
--- /dev/null
+++ b/src/python/core_device.i
@@ -0,0 +1,60 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module core_device
+%include "std_vector.i"
+%include "std_string.i"
+
+%{
+#include "singa/core/device.h"
+%}
+
+namespace singa{
+
+  %nodefault Device;
+  class Device {
+   public:
+    virtual void SetRandSeed(unsigned seed) = 0;
+    Device* host();
+    int id() const;
+  };
+
+  class CppCPU : public Device {
+   public:
+    CppCPU(int id = -1, int num_executors = 1,
+           std::string scheduler = "sync", std::string vm = "gc-only");
+    void SetRandSeed(unsigned seed) override;
+    /* (TODO) add necessary functions of CppCPU class
+    */
+  };
+
+  class CudaGPU : public Device {
+   public:
+    CudaGPU(int id = 0, int num_executors = 1,
+            std::string scheduler = "sync", std::string vm = "gc-only");
+    void SetRandSeed(unsigned seed) override;
+    /* (TODO) add necessary functions of CudaGPU class
+    */
+  };
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/core_tensor.i
----------------------------------------------------------------------
diff --git a/src/python/core_tensor.i b/src/python/core_tensor.i
new file mode 100644
index 0000000..a700602
--- /dev/null
+++ b/src/python/core_tensor.i
@@ -0,0 +1,264 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module core_tensor
+%include "std_vector.i"
+%include "std_string.i"
+
+%include "carrays.i"
+%array_class(float, floatArray);
+%array_class(int, intArray);
+%array_class(char, charArray);
+%array_class(double, doubleArray);
+
+%{
+#include "core/tensor/tensor_math.h"
+#include "singa/core/tensor.h"
+#include "singa/core/device.h"
+#include "singa/proto/core.pb.h"
+#include "singa/proto/model.pb.h"
+using singa::DataType;
+%}
+
+%template(Shape) std::vector<size_t>;
+
+namespace singa{
+
+  enum DataType {
+    kFloat32, kFloat16, kInt, kChar, kDouble
+  };
+
+  inline size_t Product(const std::vector<size_t> &shape,
+                        int start = 0, size_t len = 0);
+  inline size_t SizeOf(DataType t);
+
+  class Tensor {
+
+   public:
+    Tensor();
+    explicit Tensor(const std::vector<size_t> &shape,
+                    DataType dtype = kFloat32);
+    Tensor(const std::vector<size_t> &shape,
+           singa::Device *dev, DataType dtype = kFloat32);
+    Tensor(const Tensor &from);
+
+    //Blob *blob() const;
+    singa::Device *device() const;
+
+    template <typename DType> DType data() const;
+    %template(floatData) data<const float*>;
+    %template(intData) data<const int*>;
+    %template(charData) data<const char*>;
+    %template(doubleData) data<const double*>;
+
+    const DataType data_type() const;
+    const std::vector<size_t> &shape() const;
+    const size_t shape(size_t idx) const;
+    size_t nDim() const;
+    bool transpose() const;
+    size_t Size() const;
+    size_t MemSize() const;
+    void Reshape(const std::vector<size_t> &shape);
+    void ResetLike(const Tensor &t);
+    void AsType(DataType type);
+    void ToDevice(singa::Device *dev);
+    void ToHost();
+
+    template <typename SType> void SetValue(const SType x);
+    %template(floatSetValue) SetValue<float>;
+    // ...
+
+    /* no need to expose this function
+    template <typename DType> void CopyDataFromHostPtr(const DType *src,
+                                                       size_t num);
+    */
+
+    void CopyData(const Tensor &other);
+    Tensor Clone() const;
+    Tensor T() const;
+
+    /* python has no assignment operator as in c++:
+    Tensor &operator=(const Tensor &t); */
+    Tensor &operator+=(const Tensor &t);
+    Tensor &operator-=(const Tensor &t);
+    Tensor &operator*=(const Tensor &t);
+    Tensor &operator/=(const Tensor &t);
+
+
+    template <typename DType> Tensor &operator+=(const DType x);
+    %template(iAdd_f) operator+=<float>;
+    /* TODO(chonho-01) for other types */
+    // ...
+
+    template <typename DType> Tensor &operator-=(DType x);
+    %template(iSub_f) operator-=<float>;
+    /* TODO(chonho-01) for other types */
+    // ...
+
+    template <typename DType> Tensor &operator*=(DType x);
+    %template(iMul_f) operator*=<float>;
+    /* TODO(chonho-01) for other types */
+    // ...
+
+    template <typename DType> Tensor &operator/=(DType x);
+    %template(iDiv_f) operator/=<float>;
+    /* TODO(chonho-01) for other types */
+    // ...
+
+  };
+
+  /* TODO
+  inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2);
+  */
+  void CopyDataToFrom(Tensor *dst, const Tensor &src, size_t num,
+                      size_t src_offset = 0, size_t dst_offset = 0);
+
+  Tensor Reshape(const Tensor &in, const std::vector<size_t> &s);
+
+  Tensor Abs(const Tensor &t);
+  Tensor Exp(const Tensor &t);
+  Tensor Log(const Tensor &t);
+  Tensor ReLU(const Tensor &t);
+  Tensor Sigmoid(const Tensor &t);
+  Tensor Sign(const Tensor &t);
+  Tensor Sqrt(const Tensor &t);
+  Tensor Square(const Tensor &t);
+  Tensor Tanh(const Tensor &t);
+
+  Tensor Sum(const Tensor &t, int axis);
+  template <typename SType> SType Sum(const Tensor &t);
+  %template(floatSum) Sum<float>;
+  /* TODO(chonho-03) not implemented
+  %template(intSum) Sum<int>;
+  %template(charSum) Sum<char>;
+  %template(doubleSum) Sum<double>;
+  */
+
+  /* TODO(chonho-04) not implemented
+     need average of all elements ??? */
+  Tensor Average(const Tensor &t, int axis);
+  Tensor SoftMax(const Tensor &t, int axis = 0);
+
+  /* TODO(chonho-05) not implemented ???
+  Tensor Pow(const Tensor &base, Tensor exp);
+  template <typename DType>
+  Tensor Pow(const Tensor &t, DType x);
+  */
+
+
+  /* rename comparison operators */
+  %rename(LT_Tf) operator<(const Tensor &t, const float x);
+  %rename(LE_Tf) operator<=(const Tensor &t, const float x);
+  %rename(GT_Tf) operator>(const Tensor &t, const float x);
+  %rename(GE_Tf) operator>=(const Tensor &t, const float x);
+
+  template <typename DType>
+  Tensor operator<(const Tensor &t, const DType x);
+  %template(op) operator< <float>;
+  // --- other types
+
+  template <typename DType>
+  Tensor operator<=(const Tensor &t, const DType x);
+  %template(op) operator<= <float>;
+  // --- other types
+
+  template <typename DType>
+  Tensor operator>(const Tensor &t, const DType x);
+  %template(op) operator> <float>;
+  // --- other types
+
+  template <typename DType>
+  Tensor operator>=(const Tensor &t, const DType x);
+  %template(op) operator>= <float>;
+  // --- other types
+
+  /* TODO(chonho-06)
+  no need to include these;
+  in python, they can be replaced with comparison operators
+
+  template <typename DType>
+  void LT(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void LE(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void GT(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void GE(const Tensor &t, DType x, Tensor *ret);
+  */
+
+
+  /* rename operators */
+  %rename(Add_TT) operator+(const Tensor &lhs, const Tensor &rhs);
+  %rename(Sub_TT) operator-(const Tensor &lhs, const Tensor &rhs);
+  %rename(Mul_TT) operator*(const Tensor &lhs, const Tensor &rhs);
+  %rename(Div_TT) operator/(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator+(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator-(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator*(const Tensor &lhs, const Tensor &rhs);
+  Tensor operator/(const Tensor &lhs, const Tensor &rhs);
+
+  %rename(Add_Tf) operator+(const Tensor &t, float x);
+  template <typename DType>
+  Tensor operator+(const Tensor &t, DType x);
+  %template(op) operator+<float>;
+  // --- other types
+
+  %rename(Sub_Tf) operator-(const Tensor &t, float x);
+  template <typename DType>
+  Tensor operator-(const Tensor &t, DType x);
+  %template(op) operator-<float>;
+  // --- other types
+
+  %rename(Mul_Tf) operator*(const Tensor &t, float x);
+  template <typename DType>
+  Tensor operator*(const Tensor &t, DType x);
+  %template(op) operator*<float>;
+  // --- other types
+
+  %rename(Div_Tf) operator/(const Tensor &t, float x);
+  template <typename DType>
+  Tensor operator/(const Tensor &t, DType x);
+  %template(op) operator/<float>;
+  // --- other types
+
+  /* TODO(chonho-07)
+  no need to include these;
+  in python, they can be replaced with operators
+
+  void Add(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+  void Sub(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+  void EltwiseMult(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+  void Div(const Tensor &lhs, const Tensor &rhs, Tensor *ret);
+
+  template <typename DType>
+  void Add(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void Sub(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void EltwiseMult(const Tensor &t, DType x, Tensor *ret);
+  template <typename DType>
+  void Div(const Tensor &t, DType x, Tensor *ret);
+  */
+
+}
+
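
For reference, the %rename directives above fix the flat names that the
generated singa module exposes to Python; tensor.py (below) wraps them
roughly like this (t, a, b are assumed to be python Tensor instances):

    import singa

    # scalar comparison returning a 0/1 mask tensor (what Tensor.__lt__ calls)
    mask = singa.LT_Tf(t.singa_tensor, 3.45)
    # tensor-tensor and tensor-scalar addition (what Tensor.__add__ calls)
    s1 = singa.Add_TT(a.singa_tensor, b.singa_tensor)
    s2 = singa.Add_Tf(a.singa_tensor, 1.5)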

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/device.py
----------------------------------------------------------------------
diff --git a/src/python/device.py b/src/python/device.py
new file mode 100644
index 0000000..9a9787c
--- /dev/null
+++ b/src/python/device.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+
+# /************************************************************
+# *
+# * Licensed to the Apache Software Foundation (ASF) under one
+# * or more contributor license agreements.  See the NOTICE file
+# * distributed with this work for additional information
+# * regarding copyright ownership.  The ASF licenses this file
+# * to you under the Apache License, Version 2.0 (the
+# * "License"); you may not use this file except in compliance
+# * with the License.  You may obtain a copy of the License at
+# *
+# *   http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing,
+# * software distributed under the License is distributed on an
+# * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# * KIND, either express or implied.  See the License for the
+# * specific language governing permissions and limitations
+# * under the License.
+# *
+# *************************************************************/
+
+'''
+This script includes the Device class and its subclasses, for python users
+to call singa::Device and its methods
+'''
+import sys
+import os
+import numpy as np
+import singa
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
+
+
+class Device(object):
+    ''' Class and member functions for singa::Device
+    '''
+
+    def __init__(self, id=-1, num_executors=1, scheduler='sync', vm='gc-only',
+                 device='cpu'):
+        ''' id = (int)            // device ID
+            num_executors = (int) // # of executors (e.g., cuda streams)
+            scheduler = (string)  // identifier of scheduler type (default
+                                  // scheduler runs operations synchronously)
+            vm = (string)         // virtual memory type (default vm only
+                                  // provides garbage collection)
+            (TODO) max mem size to use (in MB)
+        '''
+        if device == 'gpu':
+            self.singa_device = singa.CudaGPU(id, num_executors, scheduler, vm)
+        else:
+            self.singa_device = singa.CppCPU(id, num_executors, scheduler, vm)
+
+        self.id = id
+        self.num_executors = num_executors
+        self.scheduler = scheduler
+        self.vm = vm
+
+    def set_rand_seed(self, seed):
+        self.singa_device.SetRandSeed(seed)
+
+    def get_host(self):
+        return self.singa_device.host()
+
+    def get_id(self):
+        return self.singa_device.id()
+
+
+class CppCPU(Device):
+
+    def __init__(self, id=-1, num_executors=1, scheduler='sync', vm='gc-only'):
+        super(CppCPU, self).__init__(id, num_executors, scheduler, vm)
+
+
+class CudaGPU(Device):
+
+    def __init__(self, id=0, num_executors=1, scheduler='sync', vm='gc-only'):
+        super(CudaGPU, self).__init__(id, num_executors, scheduler, vm, 'gpu')

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/generate_singa_wrapper.sh
----------------------------------------------------------------------
diff --git a/src/python/generate_singa_wrapper.sh b/src/python/generate_singa_wrapper.sh
new file mode 100755
index 0000000..037db91
--- /dev/null
+++ b/src/python/generate_singa_wrapper.sh
@@ -0,0 +1,43 @@
+#!/usr/bin/env bash
+#/**
+# * Licensed to the Apache Software Foundation (ASF) under one
+# * or more contributor license agreements.  See the NOTICE file
+# * distributed with this work for additional information
+# * regarding copyright ownership.  The ASF licenses this file
+# * to you under the Apache License, Version 2.0 (the
+# * "License"); you may not use this file except in compliance
+# * with the License.  You may obtain a copy of the License at
+# *
+# *     http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# */
+
+SINGA_ROOT=/home/chonho/incubator-singa
+SINGA_SRC=${SINGA_ROOT}/src
+SRC_CC=(${SINGA_SRC}/core/tensor/tensor.cc \
+        ${SINGA_SRC}/core/device/device.cc
+       )
+USR_LOCAL=/home/chonho/local
+
+# The following commands are only for developers adding new Python APIs.
+#swig -c++ -python -w509 -I../../include singa.i
+swig -c++ -python -I../../include singa.i
+
+g++ -fPIC ${SRC_CC[@]} singa_wrap.cxx -shared -o _singa.so \
+    -L${USR_LOCAL}/lib -lprotobuf -Wl,-rpath=${USR_LOCAL}/lib \
+    -L../../lib -lsinga_core -lsinga_model -lsinga_utils -Wl,-rpath=../../lib \
+    -std=c++11 \
+    -I../.. \
+    -I../../include \
+    -I${SINGA_SRC} \
+    -I${USR_LOCAL}/include \
+    -I${USR_LOCAL}/cudnn/include \
+    -I/usr/include/python2.7 \
+    -I/usr/local/cuda-7.0/include
+
+#python example.py

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/singa.i
----------------------------------------------------------------------
diff --git a/src/python/singa.i b/src/python/singa.i
new file mode 100644
index 0000000..8883404
--- /dev/null
+++ b/src/python/singa.i
@@ -0,0 +1,27 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module singa
+%include "core_tensor.i"
+%include "core_device.i"
+//%include "model_layer.i"

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/src/python/tensor.py
----------------------------------------------------------------------
diff --git a/src/python/tensor.py b/src/python/tensor.py
new file mode 100644
index 0000000..12e6cb4
--- /dev/null
+++ b/src/python/tensor.py
@@ -0,0 +1,365 @@
+#!/usr/bin/env python
+
+# /************************************************************
+# *
+# * Licensed to the Apache Software Foundation (ASF) under one
+# * or more contributor license agreements.  See the NOTICE file
+# * distributed with this work for additional information
+# * regarding copyright ownership.  The ASF licenses this file
+# * to you under the Apache License, Version 2.0 (the
+# * "License"); you may not use this file except in compliance
+# * with the License.  You may obtain a copy of the License at
+# *
+# *   http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing,
+# * software distributed under the License is distributed on an
+# * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# * KIND, either express or implied.  See the License for the
+# * specific language governing permissions and limitations
+# * under the License.
+# *
+# *************************************************************/
+
+'''
+This script includes the Tensor class and its methods, for python users
+to call singa::Tensor and its methods
+'''
+import sys
+import os
+import numpy as np
+import singa
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
+
+from core_pb2 import *
+
+
+class Tensor(object):
+    ''' Class and member functions for singa::Tensor
+    '''
+
+    def __init__(self, shape=None, device=None, dtype=kFloat32):
+        ''' shape = (tuple)
+        '''
+        if shape is None:
+            # call constructor of singa::Tensor
+            self.singa_tensor = singa.Tensor()
+            return
+        else:
+            assert type(shape) == tuple, 'shape should be tuple'
+            vs = _tuple_to_vector(shape)
+            if device is None:
+                self.singa_tensor = singa.Tensor(vs, dtype)
+            else:
+                self.singa_tensor = singa.Tensor(vs, device, dtype)
+            self.tuple_shape = shape
+            self.device = device
+            self.dtype = dtype
+
+    def toarray(self):
+        # TODO(chonho) - need to think of a more efficient way to convert
+        idx = self.singa_tensor.data_type()
+        if idx == kFloat32:
+            data_array = singa.floatArray_frompointer(
+                             self.singa_tensor.floatData())
+            dt = np.float32
+        elif idx == kFloat16:
+            print 'not implemented yet'
+            return
+            # data_array = singa.floatArray_frompointer(
+            #                  self.singa_tensor.floatData())
+            # dt = np.float16
+        elif idx == kInt:
+            data_array = singa.intArray_frompointer(
+                             self.singa_tensor.intData())
+            dt = np.int32
+        elif idx == kChar:
+            data_array = singa.charArray_frompointer(
+                             self.singa_tensor.charData())
+            dt = np.int8
+        elif idx == kDouble:
+            data_array = singa.doubleArray_frompointer(
+                             self.singa_tensor.doubleData())
+            dt = np.float64
+
+        data = [data_array[i] for i in range(self.singa_tensor.Size())]
+        data = np.array(data, dtype=dt).reshape(self.tuple_shape)
+        return data
+
+    def data_type(self):
+        return self.singa_tensor.data_type()
+
+    def shape(self, axis=None):
+        if axis is None:
+            return self.singa_tensor.shape()
+        else:
+            return self.singa_tensor.shape(axis)
+
+    def ndim(self):
+        return self.singa_tensor.nDim()
+
+    def is_transpose(self):
+        return self.singa_tensor.transpose()
+
+    def size(self):
+        return self.singa_tensor.Size()
+
+    def memsize(self):
+        return self.singa_tensor.MemSize()
+
+    def reshape(self, shape):
+        assert product(self.tuple_shape) == product(shape), \
+               'product of shape should be equal'
+        self.tuple_shape = shape
+        self.singa_tensor.Reshape(_tuple_to_vector(shape))
+
+    def reset_like(self, t):
+        self.singa_tensor.ResetLike(t.singa_tensor)
+
+    def as_type(self, dtype):
+        self.singa_tensor.AsType(dtype)
+
+    def to_device(self, device):
+        self.singa_tensor.ToDevice(device)
+
+    def to_host(self):
+        self.singa_tensor.ToHost()
+
+    def set_value(self, x):
+        self.singa_tensor.SetValue(x)
+
+    def copy_data(self, t):
+        self.singa_tensor.CopyData(t.singa_tensor)
+
+    def clone(self):
+        ''' deep copy;
+            calls singa::Tensor::Clone()
+        '''
+        return _call_singa_func(self.singa_tensor.Clone)
+
+    def transpose(self):
+        ''' shallow copy; negates the transpose field
+            calls singa::Tensor::T()
+        '''
+        return _call_singa_func(self.singa_tensor.T)
+
+    def copy(self):
+        ''' shallow copy;
+            calls the copy constructor of singa::Tensor
+        '''
+        return _call_singa_func(singa.Tensor, self.singa_tensor)
+
+    def deepcopy(self):
+        ''' deep copy;
+            calls singa::Tensor::Clone()
+        '''
+        return self.clone()
+
+    '''
+    python in-place operators (+=, -=, *=, /=) mapped to singa::Tensor
+    compound assignment operators
+    '''
+    def __iadd__(self, x):
+        if type(x) == Tensor:
+            self.singa_tensor += x.singa_tensor
+        else:
+            self.singa_tensor += x
+        return self
+
+    def __isub__(self, x):
+        if type(x) == Tensor:
+            self.singa_tensor -= x.singa_tensor
+        else:
+            self.singa_tensor -= x
+        return self
+
+    def __imul__(self, x):
+        if type(x) == Tensor:
+            self.singa_tensor *= x.singa_tensor
+        else:
+            self.singa_tensor *= x
+        return self
+
+    def __idiv__(self, x):
+        if type(x) == Tensor:
+            self.singa_tensor /= x.singa_tensor
+        else:
+            self.singa_tensor /= x
+        return self
+
+    '''
+    python operators (+, -, *, /, <, <=, >, >=) for singa binary operators
+    '''
+    def __add__(self, rhs):
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.Add_TT,
+                                    self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.Add_Tf,
+                                    self.singa_tensor, rhs)
+
+    def __sub__(self, rhs):
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.Sub_TT,
+                                    self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.Sub_Tf,
+                                    self.singa_tensor, rhs)
+
+    def __mul__(self, rhs):
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.Mul_TT,
+                                    self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.Mul_Tf,
+                                    self.singa_tensor, rhs)
+
+    def __div__(self, rhs):
+        if isinstance(rhs, Tensor):
+            return _call_singa_func(singa.Div_TT,
+                                    self.singa_tensor, rhs.singa_tensor)
+        else:
+            return _call_singa_func(singa.Div_Tf,
+                                    self.singa_tensor, rhs)
+
+    def __lt__(self, rhs):
+        return _call_singa_func(singa.LT_Tf, self.singa_tensor, rhs)
+
+    def __le__(self, rhs):
+        return _call_singa_func(singa.LE_Tf, self.singa_tensor, rhs)
+
+    def __gt__(self, rhs):
+        return _call_singa_func(singa.GT_Tf, self.singa_tensor, rhs)
+
+    def __ge__(self, rhs):
+        return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs)
+
+
+''' python functions for global functions in Tensor.h
+'''
+
+
+def product(shape):
+    return reduce(lambda x, y: x * y, shape)
+
+
+def sizeof(dtype):
+    return singa.SizeOf(dtype)
+
+
+def reshape(t, s):
+    return _call_singa_func(singa.Reshape, t.singa_tensor, s)
+
+
+def copy_data_to_from(dst, src, size, src_offset=0, dst_offset=0):
+    singa.CopyDataToFrom(dst.singa_tensor, src.singa_tensor, size,
+                         src_offset, dst_offset)
+
+
+def abs(t):
+    return _call_singa_func(singa.Abs, t.singa_tensor)
+
+
+def exp(t):
+    return _call_singa_func(singa.Exp, t.singa_tensor)
+
+
+def log(t):
+    return _call_singa_func(singa.Log, t.singa_tensor)
+
+
+def relu(t):
+    return _call_singa_func(singa.ReLU, t.singa_tensor)
+
+
+def sigmoid(t):
+    return _call_singa_func(singa.Sigmoid, t.singa_tensor)
+
+
+def square(t):
+    return _call_singa_func(singa.Square, t.singa_tensor)
+
+
+def tanh(t):
+    return _call_singa_func(singa.Tanh, t.singa_tensor)
+
+
+def sum(t, axis=None):
+    if axis is None:
+        return singa.floatSum(t.singa_tensor)
+    else:
+        return _call_singa_func(singa.Sum, t.singa_tensor, axis)
+
+
+def pow(t, x):
+    print 'not implemented yet'
+
+
+def average(t, axis=0):
+    return _call_singa_func(singa.Average, t.singa_tensor, axis)
+
+
+def softmax(t, axis=0):
+    return _call_singa_func(singa.SoftMax, t.singa_tensor, axis)
+
+
+def lt(t, x):
+    return t < x
+
+
+def le(t, x):
+    return t <= x
+
+
+def gt(t, x):
+    return t > x
+
+
+def ge(t, x):
+    return t >= x
+
+
+def add(lhs, rhs):
+    # call Tensor.__add__()
+    return lhs + rhs
+
+
+def sub(lhs, rhs):
+    # call Tensor.__sub__()
+    return lhs - rhs
+
+
+def eltwise_mult(lhs, rhs):
+    # call Tensor.__mul__()
+    return lhs * rhs
+
+
+def div(lhs, rhs):
+    # call Tensor.__div__()
+    return lhs / rhs
+
+
+''' private functions, internally used
+'''
+
+
+def _tuple_to_vector(tshape):
+    ''' this function converts a tuple to std::vector<size_t> (singa.Shape)
+    '''
+    vs = singa.Shape(len(tshape))
+    for i in range(len(tshape)):
+        vs[i] = tshape[i]
+    return vs
+
+
+def _call_singa_func(_singa_func, *args):
+    ''' this function calls a singa global function that returns a Tensor
+        and creates a new python Tensor instance
+        e.g., Tensor [singa_func](args...)
+    '''
+    new_t = Tensor()
+    new_t.singa_tensor = _singa_func(*args)
+    new_t.tuple_shape = new_t.singa_tensor.shape()
+    new_t.device = new_t.singa_tensor.device()
+    new_t.dtype = new_t.singa_tensor.data_type()
+    return new_t

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/test/python/example_test_device.py
----------------------------------------------------------------------
diff --git a/test/python/example_test_device.py b/test/python/example_test_device.py
new file mode 100644
index 0000000..b7fb28e
--- /dev/null
+++ b/test/python/example_test_device.py
@@ -0,0 +1,36 @@
+import sys, os
+
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../src/python'))
+from tensor import *
+from device import *
+
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../src'))
+from core_pb2 import *
+
+#---------------------------------------------------------
+# example usage
+#---------------------------------------------------------
+
+d1 = CudaGPU(123)
+print d1.singa_device
+print d1.get_host()
+print d1.get_id()
+print
+
+d2 = CppCPU(345)
+print d2.singa_device
+print d2.get_host()
+print d2.get_id()
+print
+
+s = (2, 3)
+t = Tensor(s, d2.get_host())
+print t.singa_tensor
+print t.device
+print
+
+d = Device(0)
+print d.singa_device
+print

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2a582df1/test/python/unittest_python.py
----------------------------------------------------------------------
diff --git a/test/python/unittest_python.py b/test/python/unittest_python.py
new file mode 100644
index 0000000..2b35d34
--- /dev/null
+++ b/test/python/unittest_python.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+
+#/************************************************************
+#*
+#* Licensed to the Apache Software Foundation (ASF) under one
+#* or more contributor license agreements.  See the NOTICE file
+#* distributed with this work for additional information
+#* regarding copyright ownership.  The ASF licenses this file
+#* to you under the Apache License, Version 2.0 (the
+#* "License"); you may not use this file except in compliance
+#* with the License.  You may obtain a copy of the License at
+#*
+#*   http://www.apache.org/licenses/LICENSE-2.0
+#*
+#* Unless required by applicable law or agreed to in writing,
+#* software distributed under the License is distributed on an
+#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#* KIND, either express or implied.  See the License for the
+#* specific language governing permissions and limitations
+#* under the License.
+#*
+#*************************************************************/
+
+import sys
+import os
+import math
+import unittest
+import numpy as np
+
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../src/python'))
+from tensor import *
+from device import *
+
+sys.path.append(os.path.join(os.path.dirname(__file__),
+                             '../../src'))
+from core_pb2 import *
+
+class TestTensorMethods(unittest.TestCase):
+
+    def setUp(self):
+        self.shape = (2, 3)
+        self.t = Tensor(self.shape)
+        self.s = Tensor(self.shape)
+
+    def test_tensor_fields(self):
+        t = self.t
+        shape = self.shape
+        self.assertTupleEqual(t.shape(), shape)
+        self.assertEqual(t.shape(0), shape[0])
+        self.assertEqual(t.shape(1), shape[1])
+        self.assertEqual(product(shape), 2*3)
+        self.assertEqual(t.ndim(), 2)
+        self.assertEqual(t.size(), 2*3)
+        self.assertEqual(t.memsize(), 2*3*sizeof(kFloat32))
+        self.assertFalse(t.is_transpose())
+
+    def test_unary_operators(self):
+        t = self.t
+        self.assertAlmostEqual(t.toarray()[0,0], 0.0)
+        t += 1.23
+        self.assertAlmostEqual(t.toarray()[0,0], 1.23)
+        t -= 0.23
+        self.assertAlmostEqual(t.toarray()[0,0], 1.23-0.23)
+        t *= 2.5
+        self.assertAlmostEqual(t.toarray()[0,0], (1.23-0.23)*2.5)
+        t /= 2
+        self.assertAlmostEqual(t.toarray()[0,0], (1.23-0.23)*2.5/2)
+
+    def test_binary_operators(self):
+        t = self.t
+        t += 3.2
+        s = self.s
+        s += 2.1
+        a = t + s
+        self.assertAlmostEqual(a.toarray()[0,0], 3.2+2.1, 5)
+        a = t - s
+        self.assertAlmostEqual(a.toarray()[0,0], 3.2-2.1, 5)
+        a = t * s
+        self.assertAlmostEqual(a.toarray()[0,0], 3.2*2.1, 5)
+        ''' not implemented yet
+        a = t / s
+        self.assertAlmostEqual(a.toarray()[0,0], 3.2/2.1, 5)
+        '''
+
+    def test_comparison_operators(self):
+        t = self.t
+        t += 3.45
+        a = t < 3.45
+        self.assertEqual(a.toarray()[0,0], 0)
+        a = t <= 3.45
+        self.assertEqual(a.toarray()[0,0], 1)
+        a = t > 3.45
+        self.assertEqual(a.toarray()[0,0], 0)
+        a = t >= 3.45
+        self.assertEqual(a.toarray()[0,0], 1)
+        a = lt(t, 3.45)
+        self.assertEqual(a.toarray()[0,0], 0)
+        a = le(t, 3.45)
+        self.assertEqual(a.toarray()[0,0], 1)
+        a = gt(t, 3.45)
+        self.assertEqual(a.toarray()[0,0], 0)
+        a = ge(t, 3.45)
+        self.assertEqual(a.toarray()[0,0], 1)
+
+
+    def test_tensor_copy(self):
+        t = Tensor((2,3))
+        t += 1.23
+        self.assertAlmostEqual(t.toarray()[0,0], 1.23)
+        tc = t.copy()
+        tdc = t.deepcopy()
+        self.assertAlmostEqual(tc.toarray()[0,0], 1.23)
+        self.assertAlmostEqual(tdc.toarray()[0,0], 1.23)
+        t += 1.23
+        self.assertAlmostEqual(t.toarray()[0,0], 2.46)
+        self.assertAlmostEqual(tc.toarray()[0,0], 2.46)
+        self.assertAlmostEqual(tdc.toarray()[0,0], 1.23)
+
+    def test_copy_data(self):
+        t = self.t
+        t += 1.23
+        s = self.s
+        s += 5.43
+        self.assertAlmostEqual(t.toarray()[0,0], 1.23)
+        copy_data_to_from(t, s, 2)
+        self.assertAlmostEqual(t.toarray()[0,0], 5.43, 5)
+        self.assertAlmostEqual(t.toarray()[0,1], 5.43, 5)
+        self.assertAlmostEqual(t.toarray()[0,2], 1.23)
+
+
+    def test_global_method(self):
+        t = self.t
+        t += 12.34
+        a = log(t)
+        self.assertAlmostEqual(a.toarray()[0,0], math.log(12.34))
+
+if __name__ == '__main__':
+    unittest.main()

