This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new fc077bf  add function comments for autograd
     new 1820075  Merge pull request #649 from joddiy/fix_autograd_doc
fc077bf is described below

commit fc077bfa5c492f0bd35dba2bcdeb49216fae7f04
Author: joddiy <[email protected]>
AuthorDate: Thu Apr 2 19:25:14 2020 +0800

    add function comments for autograd
---
 python/singa/autograd.py | 2211 ++++++++++++++++++++++++++++++++++++----------
 1 file changed, 1755 insertions(+), 456 deletions(-)

diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 17ce07e..4e7593f 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -362,6 +362,9 @@ class Dummy(Operation):
 
 
 class Mean(Operation):
+    """
+    Element-wise mean of each of the input CTensors.
+    """
 
     def __init__(self):
         super(Mean, self).__init__()
@@ -369,10 +372,9 @@ class Mean(Operation):
     def forward(self, *l):
         """
         Args:
-            l: a list of CTensor
-            element-wise mean operator
+            l (a list of CTensor): a list of CTensor for element-wise mean.
         Returns:
-            a new CTensor
+            a new CTensor.
         """
         if training:
             self.l = len(l)
@@ -386,18 +388,29 @@ class Mean(Operation):
     def backward(self, dy):
         """
         Args:
-            dy(CTensor): dL / dy
+            dy (CTensor): dL / dy.
         Returns:
-            a list of dx(CTensor)
+            a list of dx (CTensor).
         """
         return [singa.MultFloat(dy, 1 / self.l)] * self.l
 
 
 def mean(*l):
+    """
+    Element-wise mean of each of the input tensors.
+    Args:
+        l (a list of Tensor): the tensors to average element-wise.
+    Returns:
+        a new Tensor.
+    """
     return Mean()(*l)[0]
 
 
 class ReLU(Operation):
+    """
+    Relu means rectified linear function, i.e., y = max(0, x) is applied to the
+    CTensor elementwise.
+    """
 
     def __init__(self):
         super(ReLU, self).__init__()
@@ -405,9 +418,9 @@ class ReLU(Operation):
     def forward(self, x):
         """
         Args:
-            x(CTensor): input tensor
+            x (CTensor): input tensor.
         Returns:
-            a new CTensor whose element y = x if x >= 0; otherwise 0;
+            a new CTensor whose element y = x if x >= 0; otherwise 0.
         """
         if training:
             self.input = x
@@ -416,30 +429,37 @@ class ReLU(Operation):
     def backward(self, dy):
         """
         Args:
-            dy(CTensor): dL / dy
+            dy (CTensor): dL / dy.
         Returns:
-            dx(CTensor): dL / dx = dy if x >= 0; otherwise 0;
+            dx (CTensor): dL / dx = dy if x >= 0; otherwise 0.
         """
         return singa.ReLUBackward(dy, self.input)
 
 
 def relu(x):
+    """
+    Relu means rectified linear function, i.e., y = max(0, x) is applied to the
+    Tensor elementwise.
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        a new Tensor whose element y = x if x >= 0; otherwise 0.
+    """
     return ReLU()(x)[0]
 
 
 class Less(Operation):
+    """
+    Returns the tensor resulting from performing the less logical operation
+    elementwise on the input CTensors x and y.
+    """
 
     def __init__(self):
         super(Less, self).__init__()
 
     def forward(self, x, y):
-        """Do forward propgation.
-        Store the [x<y] if requires gradient.
-        Args:
-            x (CTensor): matrix
-            y (CTensor): matrix
-        Returns:
-            a CTensor for the result
+        """
+        Return x<y, where x and y are CTensor.
         """
         cur = singa.LTFloat(singa.__sub__(x, y), 0)
         if training:
@@ -449,18 +469,32 @@ class Less(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): data for the dL / dy, L is the loss.
+        Raises:
+            AssertionError: no backward function for this operator.
         """
         assert False, ('no backward function for less')
 
 
 def less(x, y):
+    """
+    Return x<y, where x and y are Tensor.
+    """
     return Less()(x, y)[0]
 
 
 class Clip(Operation):
+    """
+    Clip operator limits the given input within an interval. The interval 
+    is specified by the inputs 'min' and 'max'.
+    """
 
     def __init__(self, min, max):
+        """
+        Args:
+            min (float): min value, under which element is replaced by min.
+            max (float): max value, above which element is replaced by max.
+        """
         super(Clip, self).__init__()
         self.max = max
         self.min = min
@@ -468,9 +502,9 @@ class Clip(Operation):
     def forward(self, x):
         """
         Args:
-            x(CTensor): input tensor
+            x (CTensor): input tensor
         Returns:
-            np.clip(x,min,max)
+            a new CTensor with np.clip(x,min,max)
         """
         self.mask = singa.Tensor(list(x.shape()), x.device())
         self.mask.SetFloatValue(1.0)
@@ -492,49 +526,77 @@ class Clip(Operation):
         return x
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): dL / dy
+        Returns:
+            dx (CTensor): dL / dx
+        """
         return singa.__mul__(dy, self.mask)
 
 
 def clip(x, min=None, max=None):
+    """
+    Clip operator limits the given input within an interval. The interval 
+    is specified by the inputs 'min' and 'max'.
+    Args:
+        x (Tensor): input tensor
+        min (float): Minimum value, under which element is replaced by min.
+        max (float): Maximum value, above which element is replaced by max.
+    Returns:
+        a new Tensor with np.clip(x,min,max).
+    """
     return Clip(min, max)(x)[0]
 
 
 class Identity(Operation):
+    """
+    Init an identity operator.
+    """
 
     def __init__(self):
         super(Identity, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): input tensor.
+        Returns:
+            the same CTensor x.
+        """
         return x
 
     def backward(self, dy):
         """
         Args:
-            dy(CTensor): dL / dy
+            dy (CTensor): dL / dy.
         Returns:
-            dx(CTensor): dL / dx = dy;
+            dx (CTensor): dL / dx = dy.
         """
         return dy
 
 
 def identity(x):
+    """
+    Apply the identity operator to x.
+    Args:
+        x (Tensor): input tensor.
+    Returns:
+        the same Tensor as x.
+    """
     return Identity()(x)[0]
 
-
 class Matmul(Operation):
-    """For matrix multiplication"""
+    """
+    Init matrix multiplication operator.
+    """
 
     def __init__(self):
         super(Matmul, self).__init__()
 
     def forward(self, x, w):
-        """Do forward propgation.
-        Store the x(or w) if w(or x) requires gradient.
-        Args:
-            x (CTensor): matrix
-            w (CTensor): matrix
-        Returns:
-            a CTensor for the result
+        """
+        Return np.matmul(x,w), where x and w are CTensor.
         """
         if training:
             self.input = (x, w)
@@ -543,9 +605,9 @@ class Matmul(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): data for the dL / dy, L is the loss.
         Returns:
-            a tuple for (dx, dw)
+            a tuple for (dx, dw).
         """
         return (
             singa.Mult(dy, singa.DefaultTranspose(self.input[1])),
@@ -554,22 +616,24 @@ class Matmul(Operation):
 
 
 def matmul(x, w):
+    """
+    Return np.matmul(x,w), where x and w are Tensor.
+    """
     return Matmul()(x, w)[0]
 
 
 class Greater(Operation):
+    """
+    Returns the tensor resulting from performing the greater logical
+    operation elementwise on the input tensors x and y.
+    """
 
     def __init__(self):
         super(Greater, self).__init__()
 
     def forward(self, x, y):
-        """Do forward propgation.
-        Store the [x>y] if requires gradient.
-        Args:
-            x (CTensor): matrix
-            y (CTensor): matrix
-        Returns:
-            a CTensor for the result
+        """
+        Return x>y, where x and y are CTensor.
         """
         cur = singa.GTFloat(singa.__sub__(x, y), 0)
         if training:
@@ -579,12 +643,17 @@ class Greater(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): data for the dL / dy, L is the loss.
+        Raises:
+            AssertionError: no backward function for this operator.
         """
         assert False, ('no backward function for greater')
 
 
 def greater(x, y):
+    """
+    Return x>y, where x and y are Tensor.
+    """
     return Greater()(x, y)[0]
 
 
@@ -597,7 +666,7 @@ class AddBias(Operation):
         """
         To indicate the calculation axis, 0 for row, 1 for column.
         Args:
-            axis: 0 or 1, default is 0.
+            axis (int): 0 or 1, default is 0.
         """
         super(AddBias, self).__init__()
         self.axis = axis
@@ -605,8 +674,8 @@ class AddBias(Operation):
     def forward(self, x, b):
         """
         Args:
-            x: matrix.
-            b: bias to be added.
+            x (CTensor): matrix.
+            b (CTensor): bias to be added.
         Return:
             the result Tensor
         """
@@ -631,16 +700,42 @@ class AddBias(Operation):
 
 
 def add_bias(x, b, axis=0):
+    """
+    Add Bias to each row / column of the Tensor, depending on the axis arg.
+    Args:
+        x (Tensor): matrix.
+        b (Tensor): bias to be added.
+        axis (int): 0 or 1, default is 0.
+    Return:
+        the result Tensor
+    """
     return AddBias(axis)(x, b)[0]
 
-
 class Reshape(Operation):
+    """
+    Reshape the input tensor similar to np.reshape. 
+    """
 
     def __init__(self, shape):
+        """
+        Args:
+            shape (list of int): Specified shape for output. At most one
+                dimension of the new shape can be -1. In this case, the 
+                value is inferred from the size of the tensor and the 
+                remaining dimensions. A dimension could also be 0, 
+                in which case the actual dimension value is unchanged 
+                (i.e. taken from the input tensor).
+        """
         super(Reshape, self).__init__()
         self.shape = list(shape)
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): matrix.
+        Return:
+            the result CTensor
+        """
         self._shape = x.shape()
         shape = self.shape
         # handle the shape with 0
@@ -655,19 +750,48 @@ class Reshape(Operation):
         return singa.Reshape(x, self.cache)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): dL / dy
+        Returns:
+            dx (CTensor): dL / dx
+        """
         return singa.Reshape(dy, self._shape)
 
 
-def reshape(a, shape):
-    return Reshape(shape)(a)[0]
+def reshape(x, shape):
+    """
+    Reshape the input tensor similar to np.reshape.
+    Args:
+        x (Tensor): matrix.
+        shape (list of int): Specified shape for output. At most one
+            dimension of the new shape can be -1. In this case, the 
+            value is inferred from the size of the tensor and the 
+            remaining dimensions. A dimension could also be 0, 
+            in which case the actual dimension value is unchanged 
+            (i.e. taken from the input tensor).
+    Return:
+        the result Tensor
+    """
+    return Reshape(shape)(x)[0]
 
 
 class PRelu(Operation):
+    """
+    PRelu applies the function f(x) = slope * x for x < 0, 
+    f(x) = x for x >= 0 to the data tensor elementwise.
+    """
 
     def __init__(self):
         super(PRelu, self).__init__()
 
     def forward(self, x, slope):
+        """
+        Args:
+            x (CTensor): matrix.
+        Return:
+            the result CTensor
+        """
         mask0 = singa.LTFloat(x, 0.0)
         res = singa.__mul__(x, mask0)
         res = singa.__mul__(res, slope)
@@ -682,6 +806,12 @@ class PRelu(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): dL / dy
+        Returns:
+            dx (CTensor): dL / dx
+        """
         dx1mask = singa.GEFloat(self.input, 0.0)
         dx2 = singa.__mul__(self.mask0, self.slope)
         dx = singa.__add__(dx1mask, dx2)
@@ -697,15 +827,29 @@ class PRelu(Operation):
 
 
 def prelu(x, slope):
+    """
+    PRelu applies the function f(x) = slope * x for x < 0, 
+    f(x) = x for x >= 0 to the data tensor elementwise.
+    Args:
+        x (Tensor): matrix.
+    Return:
+        the result Tensor
+    """
     return PRelu()(x, slope)[0]
 
 
 class Add(Operation):
+    """
+    Performs element-wise binary addition.
+    """
 
     def __init__(self):
         super(Add, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return a+b, where a and b are CTensor.
+        """
         res = singa.__add__(a, b)
         if training:
             self.shape0 = list(a.shape())
@@ -714,6 +858,13 @@ class Add(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): dL / dy
+        Return:
+            a tuple for (dx0, dx1), dx0 is data for dL / da, dx1 is data
+            for dL / db.
+        """
         dx0, dx1 = dy, dy
         if (type(dy) == float) or self.shape0 == self.shape1:
             assert self.shape0 == self.shape1, ('should have same shape')
@@ -725,18 +876,27 @@ class Add(Operation):
 
 
 def add(a, b):
+    """
+    Return a+b, where a and b are Tensor.
+    """
     return Add()(a, b)[0]
 
 
 class Elu(Operation):
-
-    def __init__(self, alpha=1):
+    """
+    f(x) = alpha * (exp(x) - 1.) for x < 0, f(x) = x for x >= 0, is applied to
+    the tensor elementwise.
+    """
+    def __init__(self, alpha=1.):
+        """
+        Args:
+            alpha (float): Coefficient of ELU, default is 1.0
+        """
         super(Elu, self).__init__()
         self.alpha = alpha
 
     def forward(self, x):
-        """Do forward propgation.
-        Store the x if requires gradient.
+        """
         Args:
             x (CTensor): matrix
         Returns:
@@ -755,9 +915,9 @@ class Elu(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): dL / dy
         Returns:
-            a tuple for dx
+            dx (CTensor): dL / dx
         """
         dx1mask = singa.LTFloat(self.input, 0.0)
         dx = singa.MultFloat(singa.Exp(self.input), self.alpha)
@@ -771,22 +931,30 @@ class Elu(Operation):
 
 
 def elu(x, alpha=1):
+    """
+    f(x) = alpha * (exp(x) - 1.) for x < 0, f(x) = x for x >= 0, is applied to
+    the tensor elementwise.
+    Args:
+        x (Tensor): matrix
+        alpha (float): Coefficient of ELU, default is 1.0
+    Returns:
+        a Tensor for the result
+    """
     return Elu(alpha)(x)[0]
 
 
 class Equal(Operation):
-
+    """
+    Returns the tensor resulting from performing the equal logical operation
+    elementwise on the input tensors x and y.
+    """
     def __init__(self):
         super(Equal, self).__init__()
 
     def forward(self, x, y):
-        """Do forward propgation.
-       Store the x if requires gradient.
-       Args:
-           x (CTensor): matrix
-       Returns:
-           a CTensor for the result
-       """
+        """
+        Return x==y, where x and y are CTensor.
+        """
         m = singa.__sub__(x, y)
         cur = singa.__mul__(singa.GEFloat(m, 0), singa.LEFloat(m, 0))
         return cur
@@ -795,24 +963,37 @@ class Equal(Operation):
         """
         Args:
             dy (CTensor): data for the dL / dy, L is the loss
+        Raises:
+            AssertionError: no backward function for this operator
         """
         assert False, ('no backward function for equal')
 
 
 def equal(x, y):
+    """
+    Return x==y, where x and y are Tensor.
+    """
     return Equal()(x, y)[0]
 
 
 class SeLU(Operation):
+    """
+    y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0 
+    is applied to the tensor elementwise.
+    """
 
     def __init__(self, alpha=1.67326, gamma=1.0507):
+        """
+        Args:
+            alpha (float): Coefficient of SELU default to 1.67326
+            gamma (float): Coefficient of SELU default to 1.0507
+        """
         super(SeLU, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
 
     def forward(self, x):
-        """Do forward propgation.
-        Store the x if x requires gradient.
+        """
         Args:
             x (CTensor): matrix
         Returns:
@@ -833,9 +1014,9 @@ class SeLU(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): dL / dy
         Returns:
-            dx
+            dx (CTensor): dL / dx
         """
         dx1mask = singa.LEFloat(self.input, 0.0)
         dx1 = singa.MultFloat(singa.Exp(self.input), self.gamma * self.alpha)
@@ -850,6 +1031,16 @@ class SeLU(Operation):
 
 
 def selu(x, alpha=1.67326, gamma=1.0507):
+    """
+    y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0 
+    is applied to the tensor elementwise.
+    Args:
+        x (Tensor): matrix
+        alpha (float): Coefficient of SELU default to 1.67326
+        gamma (float): Coefficient of SELU default to 1.0507
+    Returns:
+        a Tensor for the result
+    """
     return SeLU(alpha, gamma)(x)[0]
 
 
@@ -860,15 +1051,19 @@ class SoftMax(Operation):
     """
 
     def __init__(self, axis=1):
+        """
+        Args:
+            axis (int): axis of softmax, default to 1
+        """
         super(SoftMax, self).__init__()
         self.axis = axis
 
     def forward(self, x):
         """
         Args:
-            x(data): the input 1d or 2d tensor
+            x (CTensor): the input 1d or 2d tensor
         Returns:
-            the result Tensor
+            the result CTensor
         """
         self.output = singa.SoftMax(x, self.axis)
         return self.output
@@ -876,24 +1071,41 @@ class SoftMax(Operation):
     def backward(self, dy):
         """
         Args:
-            dy (CTensor): data for the dL / dy, L is the loss
+            dy (CTensor): dL / dy
         Returns:
-            dx (Ctensor): data for the dL / dx, L is the loss,
-            x is the input of current Opertion
+            dx (CTensor): dL / dx
         """
         return singa.SoftMaxBackward(dy, self.axis, self.output)
 
 
 def softmax(x, axis=1):
+    """
+    Apply SoftMax for each row of the Tensor or each column of the Tensor
+    according to the parameter axis.
+    Args:
+        x (Tensor): the input 1d or 2d tensor
+        axis (int): axis of softmax, default to 1
+    Returns:
+        the result Tensor
+    """
     return SoftMax(axis)(x)[0]
 
 
 class Sum(Operation):
+    """
+    Element-wise sum of each of the input tensors
+    """
 
     def __init__(self):
         super(Sum, self).__init__()
 
     def forward(self, *l):
+        """
+        Args:
+            l (a list of CTensor): the CTensors to sum element-wise
+        Returns:
+            a CTensor for the result
+        """
         if training:
             self.l = len(l)
         assert (len(l) > 0)
@@ -904,10 +1116,23 @@ class Sum(Operation):
         return x
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): dL / dy
+        Returns:
+            dx (CTensor): dL / dx
+        """
         return [dy] * self.l
 
 
 def sum(*l):
+    """
+    Element-wise sum of each of the input tensors
+    Args:
+        l (a list of Tensor): the tensors to sum element-wise
+    Returns:
+        a Tensor for the result
+    """
     return Sum()(*l)[0]
 
 
@@ -1021,19 +1246,41 @@ def ctensor2numpy(x):
 
 
 class Flatten(Operation):
+    """
+    Flattens the input tensor into a 2D matrix. If input tensor has shape 
+    (d_0, d_1, ... d_n) then the output will have shape (d_0 X d_1 ... 
+    d_(axis-1), d_axis X d_(axis+1) ... X dn).
+    """
 
-    def __init__(self, start_axis=1):
+    def __init__(self, axis=1):
+        """
+        Args:
+            axis (int): Indicate up to which input dimensions (exclusive)
+                should be flattened to the outer dimension of the output. The
+                value for axis must be in the range [-r, r], where r is the
+                rank of the input tensor. Negative value means counting
+                dimensions from the back. When axis = 0, the shape of the
+                output tensor is (1, (d_0 X d_1 ... d_n)), where the shape
+                of the input tensor is (d_0, d_1, ... d_n). Default is 1.
+                For example, an input of shape (2, 3, 4) with axis = 1 is
+                flattened to shape (2, 12).
+        """
         super(Flatten, self).__init__()
-        # flatten all axis after (inclusive) start_axis
-        self.start_axis = start_axis
+        self.axis = axis
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): the input tensor
+        Returns:
+            the result CTensor
+        """
         self.shape = list(x.shape())
-        shape, axis = self.shape, self.start_axis
-        # the start_axis must be within this range (0, r-1)
+        shape, axis = self.shape, self.axis
+        # the axis must be within this range (0, r-1)
         assert axis <= len(
             shape
-        ) - 1 or axis >= 0, "the start_axis must be within (0, %d-1)" % len(
+        ) - 1 or axis >= 0, "the axis must be within (0, %d-1)" % len(
             shape)
         # calculate the new shape
         new_shape = (1, int(np.prod(shape))) if axis == 0 else (
@@ -1043,12 +1290,34 @@ class Flatten(Operation):
         return y
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): data for the dL / dy, L is the loss
+        Returns:
+            dx (CTensor): data for the dL / dx, L is the loss.
+        """
         dx = singa.Reshape(dy, self.shape)
         return dx
 
 
-def flatten(x):
-    return Flatten()(x)[0]
+def flatten(x, axis=1):
+    """
+    Flattens the input tensor into a 2D matrix. If input tensor has shape 
+    (d_0, d_1, ... d_n) then the output will have shape (d_0 X d_1 ... 
+    d_(axis-1), d_axis X d_(axis+1) ... X dn).
+    Args:
+        x (Tensor): the input tensor
+        axis (int): Indicate up to which input dimensions (exclusive) 
+            should be flattened to the outer dimension of the output. The 
+            value for axis must be in the range [-r, r], where r is the 
+            rank of the input tensor. Negative value means counting 
+            dimensions from the back. When axis = 0, the shape of the 
+            output tensor is (1, (d_0 X d_1 ... d_n)), where the shape
+            of the input tensor is (d_0, d_1, ... d_n).
+    Returns:
+        the result Tensor
+    """
+    return Flatten(axis)(x)[0]
 
 
 class Layer(object):
@@ -1112,8 +1381,18 @@ class Layer(object):
 
 
 class Linear(Layer):
+    """
+    Generate a Linear operator
+    """
 
     def __init__(self, in_features, out_features, bias=True):
+        """
+        Args:
+            in_features (int): the number of input features
+            out_features (int): the number of output features; the weight
+                has shape (in_features, out_features)
+            bias (bool): whether to use a bias, default is True
+        """
         w_shape = (in_features, out_features)
         b_shape = (out_features,)
         self.bias = bias
@@ -1156,12 +1435,31 @@ class Linear(Layer):
 
 
 class Concat(Operation):
+    """
+    Concatenate a list of tensors into a single tensor. All input tensors must 
+    have the same shape, except for the dimension size of the axis to 
+    concatenate on.
+    """
 
     def __init__(self, axis=0):
+        """
+        Args:
+            axis (int): Which axis to concat on, default is 0. A negative
+                value means counting dimensions from the back. Accepted
+                range is [-r, r-1] where r = rank(inputs). All input
+                tensors must have the same shape, except for the dimension
+                size of this axis.
+        """
         super(Concat, self).__init__()
         self.axis = axis
 
     def forward(self, *xs):
+        """
+        Args:
+            xs (a list of CTensor): List of tensors for concatenation
+        Returns:
+            a CTensor for the result
+        """
         if training:
             offset = 0
             self.slice_point = []
@@ -1172,6 +1470,12 @@ class Concat(Operation):
         return singa.ConcatOn(x, self.axis)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): data for the dL / dy, L is the loss
+        Returns:
+            dxs (a tuple of CTensor): data for the dL / dxs, L is the loss.
+        """
         assert hasattr(
             self, "slice_point"), "Please set training as True before do BP. "
         assert self.slice_point[-1] == dy.shape()[self.axis], "Shape mismatch."
@@ -1184,20 +1488,34 @@ class Concat(Operation):
 
 
 def cat(xs, axis=0):
-    # xs is a tuple of multiple Tensors
+    """
+    Concatenate a list of tensors into a single tensor. All input tensors must 
+    have the same shape, except for the dimension size of the axis to 
+    concatenate on.
+    Args:
+        xs (a list of Tensor): List of tensors for concatenation
+        axis (int): Which axis to concat on. A negative value means 
+            counting dimensions from the back. Accepted range is [-r, r-1] 
+            where r = rank(inputs).
+    Returns:
+        a Tensor for the result
+    """
     return Concat(axis)(*xs)[0]
 
 
 class _Conv2d(Operation):
+    """
+    Init a conv 2d operator
+    """
 
     def __init__(self, handle, odd_padding=(0, 0, 0, 0)):
         """
-        Init a conv 2d operator
         Args:
-            handle: ConvHandle for cpu or CudnnConvHandle for gpu
-        Args:
-            odd_padding:tuple of four bins, the odd paddding is the value that cannot be handled by the tuple padding (w, h) mode
-            so we need to firstly handle the input, then use the nomal padding method.
+            handle (object): ConvHandle for cpu or CudnnConvHandle for gpu
+            odd_padding (tuple of four ints): the odd padding is the value
+                that cannot be handled by the tuple padding (w, h) mode so
+                we need to firstly handle the input, then use the normal padding
+                method.
         """
         super(_Conv2d, self).__init__()
         self.handle = handle
@@ -1207,13 +1525,10 @@ class _Conv2d(Operation):
 
     def forward(self, x, W, b=None):
         """
-        Do forward of conv
-        Args:
-            x: CTensor, input
         Args:
-            W: CTensor, weight
-        Args:
-            b: CTensor, bias
+            x (CTensor): input
+            W (CTensor): weight
+            b (CTensor): bias
         Returns:
             CTensor 
         """
@@ -1243,11 +1558,10 @@ class _Conv2d(Operation):
 
     def backward(self, dy):
         """
-        Do backward of conv
         Args:
-            dy: CTensor, gradient
+            dy (CTensor): dL / dy
         Returns:
-            CTensor 
+            dx (CTensor): dL / dx
         """
         assert training is True and hasattr(
             self, "inputs"), "Please set training as True before do BP. "
@@ -1282,16 +1596,14 @@ def conv2d(handle, x, W, b=None, odd_padding=(0, 0, 0, 0)):
     """
     Conv 2d operator
     Args:
-        handle: ConvHandle for cpu or CudnnConvHandle for gpu
-    Args:
-        x: CTensor, input
-    Args:
-        W: CTensor, weight
-    Args:
-        b: CTensor, bias
-    Args:
-        odd_padding:tuple of four bins, the odd paddding is the value that cannot be handled by the tuple padding (w, h) mode
-        so we need to firstly handle the input, then use the nomal padding method.
+        handle (object): ConvHandle for cpu or CudnnConvHandle for gpu
+        x (Tensor): input
+        W (Tensor): weight
+        b (Tensor): bias, or None for no bias
+        odd_padding (tuple of four ints): the odd padding is the value
+            that cannot be handled by the tuple padding (w, h) mode so
+            we need to firstly handle the input, then use the normal padding
+            method.
     """
     if b is None:
         return _Conv2d(handle, odd_padding)(x, W)[0]
@@ -1300,6 +1612,9 @@ def conv2d(handle, x, W, b=None, odd_padding=(0, 0, 0, 0)):
 
 
 class Conv2d(Layer):
+    """
+    Generate a Conv 2d operator
+    """
 
     def __init__(self,
                  in_channels,
@@ -1313,29 +1628,26 @@ class Conv2d(Layer):
                  pad_mode="NOTSET",
                  **kwargs):
         """
-        Generate a Conv 2d operator
-        Args:
-            in_channels: int, the channel of input
-        Args:
-            out_channels: int, the channel of output, also is the number of filters
-        Args:
-            kernel_size: int or tuple, kernel size for two direction of each axis. For example, (2, 3), the first 2 means will add 2 at the beginning and also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, int)
         Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
-        Args:
-            padding: int, tuple, list or None, padding, the logic is the same as kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed automatically.
-        Args:
-            dilation: int, only support 1
-        Args:
-            group: int
-        Args:
-            bias: bool
-        Args:
-            pad_mode: string, can be NOTSET, SAME_UPPER, or SAME_LOWER, where default value is NOTSET, which means explicit padding is used.
-            SAME_UPPER or SAME_LOWER mean pad the input so that the output spatial size match the input.
-            In case of odd number add the extra padding at the end for SAME_UPPER and at the beginning for SAME_LOWER.
+            in_channels (int): the channel of input
+            out_channels (int): the channel of output, also is the number of filters
+            kernel_size (int or tuple): kernel size for two direction of each
+                axis. For example, (2, 3), the first 2 means will add 2 at the
+                beginning and also 2 at the end for its axis. If an int is
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel size.
+            padding (int, tuple, list or None): padding, the logic is the same
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" or
+                "SAME_LOWER" mode, you can set padding as None, and the padding
+                will be computed automatically.
+            dilation (int): only 1 is supported
+            group (int): group
+            bias (bool): bias
+            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where
+                default value is NOTSET, which means explicit padding is used.
+                SAME_UPPER or SAME_LOWER mean pad the input so that the output
+                spatial size matches the input. In case of odd number add the
+                extra padding at the end for SAME_UPPER and the beginning for SAME_LOWER.
         """
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -1490,6 +1802,9 @@ class Conv2d(Layer):
 
 
 class SeparableConv2d(Layer):
+    """
+    Generate a separable Conv 2d operator
+    """
 
     def __init__(
         self,
@@ -1500,6 +1815,21 @@ class SeparableConv2d(Layer):
         padding=0,
         bias=False,
     ):
+        """
+        Args:
+            in_channels (int): the channel of input
+            out_channels (int): the channel of output, also is the number of 
filters
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            bias (bool): bias
+        """
         self.depthwise_conv = Conv2d(
             in_channels,
             in_channels,
@@ -1519,8 +1849,17 @@ class SeparableConv2d(Layer):
 
 
 class BatchNorm2d(Layer):
+    """
+    Generate a BatchNorm 2d operator
+    """
 
     def __init__(self, num_features, momentum=0.9):
+        """
+        Args:
+            num_features (int): the channel of input
+            momentum (float): Factor used in computing the running mean and 
+                variance.
+        """
         self.channels = num_features
         self.momentum = momentum
 
@@ -1589,14 +1928,34 @@ class BatchNorm2d(Layer):
 
 
 class _BatchNorm2d(Operation):
+    """
+    Carries out batch normalization as described in the paper 
+    https://arxiv.org/abs/1502.03167. 
+    """
 
     def __init__(self, handle, running_mean, running_var, name=None):
+        """
+        Args:
+            handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle 
+                for gpu
+            running_mean (Tensor): the running mean
+            running_var (Tensor): the running variance
+            name (string): the name assigned to this operator
+        """
         super(_BatchNorm2d, self).__init__(name)
         self.handle = handle
         self.running_mean = running_mean.data
         self.running_var = running_var.data
 
     def forward(self, x, scale, bias):
+        """
+        Args:
+            x (CTensor): the input tensor
+            scale (CTensor): the scale tensor
+            bias (CTensor): the bias tensor
+        Returns:
+            the result CTensor
+        """
         if training:
             if (type(self.handle) == singa.BatchNormHandle):
                 y, mean, var = singa.CpuBatchNormForwardTraining(
@@ -1634,6 +1993,14 @@ class _BatchNorm2d(Operation):
         return y
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): data for the dL / dy, L is the loss
+        Returns:
+            dx (CTensor): data for the dL / dx, L is the loss
+            ds (CTensor): data for the dL / ds, L is the loss
+            db (CTensor): data for the dL / db, L is the loss
+        """
         assert training is True and hasattr(
             self, "cache"), "Please set training as True before do BP. "
 
@@ -1650,19 +2017,37 @@ class _BatchNorm2d(Operation):
 
 
 def batchnorm_2d(handle, x, scale, bias, running_mean, running_var):
+    """
+    Carries out batch normalization as described in the paper 
+    https://arxiv.org/abs/1502.03167. 
+    Args:
+        handle (object): BatchNormHandle for cpu and CudnnBatchNormHandle 
+            for gpu
+        x (Tensor): the input tensor
+        scale (Tensor): the scale tensor
+        bias (Tensor): the bias tensor
+        running_mean (Tensor): the running mean
+        running_var (Tensor): the running variance
+    Returns:
+        the result Tensor
+    """
     return _BatchNorm2d(handle, running_mean, running_var)(x, scale, bias)[0]
 
 
 class _Pooling2d(Operation):
+    """
+    Init a pool 2d operator
+    """
 
     def __init__(self, handle, odd_padding=(0, 0, 0, 0)):
         """
-        Init a pool 2d operator
-        Args:
-            handle: PoolingHandle for cpu or CudnnPoolingHandle for gpu
         Args:
-            odd_padding:tuple of four bins, the odd paddding is the value that 
cannot be handled by the tuple padding (w, h) mode
-            so we need to firstly handle the input, then use the nomal padding 
method.
+            handle (object): PoolingHandle for cpu or CudnnPoolingHandle for 
+                gpu
+            odd_padding (tuple of four int): the odd padding is the value 
+                that cannot be handled by the tuple padding (w, h) mode so 
+                it needs to firstly handle the input, then use the normal 
+                padding method.
         """
         super(_Pooling2d, self).__init__()
         self.handle = handle
@@ -1671,6 +2056,12 @@ class _Pooling2d(Operation):
             self.re_new_handle = True
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): the input tensor
+        Returns:
+            the result CTensor
+        """
         assert x.nDim() == 4, "The dimensions of input should be 4D."
         if self.odd_padding != (0, 0, 0, 0):
             x = utils.handle_odd_pad_fwd(x, self.odd_padding)
@@ -1688,6 +2079,12 @@ class _Pooling2d(Operation):
         return y
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): data for the dL / dy, L is the loss
+        Returns:
+            dx (CTensor): data for the dL / dx, L is the loss
+        """
         if (type(self.handle) != singa.PoolingHandle):
             dx = singa.GpuPoolingBackward(self.handle, dy, self.cache[0],
                                           self.cache[1])
@@ -1704,17 +2101,23 @@ def pooling_2d(handle, x, odd_padding=(0, 0, 0, 0)):
     """
     Pooling 2d operator
     Args:
-        handle: ConvHandle for cpu or CudnnConvHandle for gpu
-    Args:
-        x: CTensor, input
-    Args:
-        odd_padding:tuple of four bins, the odd paddding is the value that 
cannot be handled by the tuple padding (w, h) mode
-        so we need to firstly handle the input, then use the nomal padding 
method.
+        handle (object): PoolingHandle for cpu or CudnnPoolingHandle for 
+            gpu
+        x (Tensor): input
+        odd_padding (tuple of four int): the odd padding is the value 
+            that cannot be handled by the tuple padding (w, h) mode so 
+            it needs to firstly handle the input, then use the normal 
+            padding method.
+    Returns:
+        the result Tensor
     """
     return _Pooling2d(handle, odd_padding)(x)[0]
 
 
 class Pooling2d(Layer):
+    """
+    Generate a Pooling 2d operator
+    """
 
     def __init__(self,
                  kernel_size,
@@ -1723,21 +2126,22 @@ class Pooling2d(Layer):
                  is_max=True,
                  pad_mode="NOTSET"):
         """
-        Generate a Pooling 2d operator
         Args:
-            kernel_size: int or tuple, kernel size for two direction of each 
axis. For example, (2, 3), the first 2 means will add 2 at the beginning and 
also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, 
int)
-        Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
-        Args:
-            padding: int or tuple or None, padding, the logic is the same as 
kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed 
automatically.
-        Args:
-            is_max: bool, is max pooling or avg pooling
-        Args:
-            pad_mode: string, can be NOTSET, SAME_UPPER, or SAME_LOWER, where 
default value is NOTSET, which means explicit padding is used.
-            SAME_UPPER or SAME_LOWER mean pad the input so that the output 
spatial size match the input.
-            In case of odd number add the extra padding at the end for 
SAME_UPPER and at the beginning for SAME_LOWER.
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            is_max (bool): is max pooling or avg pooling
+            pad_mode (string): can be NOTSET, SAME_UPPER, or SAME_LOWER, where 
+                default value is NOTSET, which means explicit padding is used. 
+                SAME_UPPER or SAME_LOWER mean pad the input so that the output 
+                spatial size match the input. In case of odd number add the 
extra 
+                padding at the end for SAME_UPPER and at the beginning for 
SAME_LOWER.
         """
         if isinstance(kernel_size, int):
             self.kernel_size = (kernel_size, kernel_size)
@@ -1837,6 +2241,9 @@ class Pooling2d(Layer):
 
 
 class MaxPool2d(Pooling2d):
+    """
+    Generate a Max Pooling 2d operator
+    """
 
     def __init__(self,
                  kernel_size,
@@ -1844,19 +2251,20 @@ class MaxPool2d(Pooling2d):
                  padding=0,
                  odd_padding=(0, 0, 0, 0)):
         """
-        Generate a Max Pooling 2d operator
         Args:
-            kernel_size: int or tuple, kernel size for two direction of each 
axis. For example, (2, 3), the first 2 means will add 2 at the beginning and 
also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, 
int)
-        Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
-        Args:
-            padding: int or tuple or None, padding, the logic is the same as 
kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed 
automatically.
-        Args:
-            pad_mode: string, can be NOTSET, SAME_UPPER, or SAME_LOWER, where 
default value is NOTSET, which means explicit padding is used.
-            SAME_UPPER or SAME_LOWER mean pad the input so that the output 
spatial size match the input.
-            In case of odd number add the extra padding at the end for 
SAME_UPPER and at the beginning for SAME_LOWER.
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            odd_padding (tuple of four int): the odd padding is the value 
+                that cannot be handled by the tuple padding (w, h) mode so 
+                it needs to firstly handle the input, then use the normal 
+                padding method.
         """
         super(MaxPool2d, self).__init__(kernel_size, stride, padding, True,
                                         odd_padding)
@@ -1870,24 +2278,29 @@ class AvgPool2d(Pooling2d):
                  padding=0,
                  odd_padding=(0, 0, 0, 0)):
         """
-        Generate a Avg Pooling 2d operator
-        Args:
-            kernel_size: int or tuple, kernel size for two direction of each 
axis. For example, (2, 3), the first 2 means will add 2 at the beginning and 
also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, 
int)
-        Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
         Args:
-            padding: int or tuple or None, padding, the logic is the same as 
kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed 
automatically.
-        Args:
-            odd_padding:tuple of four bins, the odd paddding is the value that 
cannot be handled by the tuple padding (w, h) mode
-            so we need to firstly handle the input, then use the nomal padding 
method.
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            odd_padding (tuple of four int): the odd padding is the value 
+                that cannot be handled by the tuple padding (w, h) mode so 
+                it needs to firstly handle the input, then use the normal 
+                padding method.
         """
         super(AvgPool2d, self).__init__(kernel_size, stride, padding, False,
                                         odd_padding)
 
 
 class MaxPool1d(Pooling2d):
+    """
+    Generate a Max Pooling 1d operator
+    """
 
     def __init__(self,
                  kernel_size,
@@ -1895,18 +2308,20 @@ class MaxPool1d(Pooling2d):
                  padding=0,
                  odd_padding=(0, 0, 0, 0)):
         """
-        Generate a Max Pooling 1d operator
-        Args:
-            kernel_size: int or tuple, kernel size for two direction of each 
axis. For example, (2, 3), the first 2 means will add 2 at the beginning and 
also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, 
int)
-        Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
-        Args:
-            padding: int or tuple or None, padding, the logic is the same as 
kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed 
automatically.
         Args:
-            odd_padding:tuple of four bins, the odd paddding is the value that 
cannot be handled by the tuple padding (w, h) mode
-            so we need to firstly handle the input, then use the nomal padding 
method.
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            odd_padding (tuple of four int): the odd padding is the value 
+                that cannot be handled by the tuple padding (w, h) mode so 
+                it needs to firstly handle the input, then use the normal 
+                padding method.
         """
         if stride is None:
             stride = kernel_size
@@ -1915,6 +2330,9 @@ class MaxPool1d(Pooling2d):
 
 
 class AvgPool1d(Pooling2d):
+    """
+    Generate an Avg Pooling 1d operator
+    """
 
     def __init__(self,
                  kernel_size,
@@ -1922,18 +2340,20 @@ class AvgPool1d(Pooling2d):
                  padding=0,
                  odd_padding=(0, 0, 0, 0)):
         """
-        Generate a Avg Pooling 1d operator
-        Args:
-            kernel_size: int or tuple, kernel size for two direction of each 
axis. For example, (2, 3), the first 2 means will add 2 at the beginning and 
also 2 at the end for its axis.
-            and if a int is accepted, the kernel size will be inited as (int, 
int)
         Args:
-            stride: int or tuple, stride, the logic is the same as kernel size.
-        Args:
-            padding: int or tuple or None, padding, the logic is the same as 
kernel size. However, if you set pad_mode as "SAME_UPPER" or "SAME_LOWER" mode, 
-            you can set padding as None, and the padding will be computed 
automatically.
-        Args:
-            odd_padding:tuple of four bins, the odd paddding is the value that 
cannot be handled by the tuple padding (w, h) mode
-            so we need to firstly handle the input, then use the nomal padding 
method.
+            kernel_size (int or tuple): kernel size for two direction of each 
+                axis. For example, (2, 3), the first 2 means will add 2 at the 
+                beginning and also 2 at the end for its axis. If an int is 
+                given, the kernel size will be initialized as (int, int).
+            stride (int or tuple): stride, the logic is the same as kernel 
size.
+            padding (int, tuple, list or None): padding, the logic is the same 
+                as kernel size. However, if you set pad_mode as "SAME_UPPER" 
or 
+                "SAME_LOWER" mode, you can set padding as None, and the 
padding 
+                will be computed automatically.
+            odd_padding (tuple of four int): the odd padding is the value 
+                that cannot be handled by the tuple padding (w, h) mode so 
+                it needs to firstly handle the input, then use the normal 
+                padding method.
         """
         if stride is None:
             stride = kernel_size
@@ -1942,17 +2362,32 @@ class AvgPool1d(Pooling2d):
 
 
 class Tanh(Operation):
+    """
+    Calculates the hyperbolic tangent of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(Tanh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         out = singa.Tanh(x)
         if training:
             self.cache = (out,)
         return out
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.__mul__(self.cache[0], self.cache[0])
         dx = singa.MultFloat(dx, -1.0)
         dx = singa.AddFloat(dx, 1.0)
@@ -1961,20 +2396,42 @@ class Tanh(Operation):
 
 
 def tanh(x):
+    """
+    Calculates the hyperbolic tangent of the given input tensor element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Tanh()(x)[0]
 
 
 class Cos(Operation):
+    """
+    Calculates the cosine of the given input tensor, element-wise.
+    """
 
     def __init__(self):
         super(Cos, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Cos(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Sin(self.input)
         dx = singa.MultFloat(dx, -1.0)
         dx *= dy
@@ -1982,40 +2439,86 @@ class Cos(Operation):
 
 
 def cos(x):
+    """
+    Calculates the cosine of the given input tensor, element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """    
+
     return Cos()(x)[0]
 
 
 class Cosh(Operation):
+    """
+    Calculates the hyperbolic cosine of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(Cosh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Cosh(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Sinh(self.input)
         dx *= dy
         return dx
 
 
 def cosh(x):
+    """
+    Calculates the hyperbolic cosine of the given input tensor element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Cosh()(x)[0]
 
 
 class Acos(Operation):
+    """
+    Calculates the arccosine (inverse of cosine) of the given input tensor, 
+    element-wise.
+    """
 
     def __init__(self):
         super(Acos, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Acos(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Square(self.input)
         dx = singa.MultFloat(dx, -1.0)
         dx = singa.AddFloat(dx, 1.0)
@@ -2026,20 +2529,43 @@ class Acos(Operation):
 
 
 def acos(x):
+    """
+    Calculates the arccosine (inverse of cosine) of the given input tensor, 
+    element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Acos()(x)[0]
 
 
 class Acosh(Operation):
+    """
+    Calculates the hyperbolic arccosine of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(Acosh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Acosh(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.SubFloat(self.input, 1.0)
         dx = singa.Sqrt(dx)
         temp = singa.AddFloat(self.input, 1.0)
@@ -2051,60 +2577,126 @@ class Acosh(Operation):
 
 
 def acosh(x):
+    """
+    Calculates the hyperbolic arccosine of the given input tensor element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Acosh()(x)[0]
 
 
 class Sin(Operation):
+    """
+    Calculates the sine of the given input tensor, element-wise.
+    """
 
     def __init__(self):
         super(Sin, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Sin(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Cos(self.input)
         dx *= dy
         return dx
 
 
 def sin(x):
+    """
+    Calculates the sine of the given input tensor, element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Sin()(x)[0]
 
 
 class Sinh(Operation):
+    """
+    Calculates the hyperbolic sine of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(Sinh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Sinh(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Cosh(self.input)
         dx *= dy
         return dx
 
 
 def sinh(x):
+    """
+    Calculates the hyperbolic sine of the given input tensor element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Sinh()(x)[0]
 
 
 class Asin(Operation):
+    """
+    Calculates the arcsine (inverse of sine) of the given input tensor, 
element-wise.
+    """
 
     def __init__(self):
         super(Asin, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Asin(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Square(self.input)
         dx = singa.MultFloat(dx, -1.0)
         dx = singa.AddFloat(dx, 1.0)
@@ -2114,20 +2706,43 @@ class Asin(Operation):
 
 
 def asin(x):
+    """
+    Calculates the arcsine (inverse of sine) of the given input tensor, 
element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """    
+
     return Asin()(x)[0]
 
 
 class Asinh(Operation):
+    """
+    Calculates the hyperbolic arcsine of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(Asinh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Asinh(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """             
         dx = singa.Square(self.input)
         dx = singa.AddFloat(dx, 1.0)
         dx = singa.PowFloat(dx, -0.5)
@@ -2136,20 +2751,42 @@ class Asinh(Operation):
 
 
 def asinh(x):
+    """
+    Calculates the hyperbolic arcsine of the given input tensor element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Asinh()(x)[0]
 
 
 class Tan(Operation):
+    """
+    Calculates the tangent of the given input tensor, element-wise.
+    """
 
     def __init__(self):
         super(Tan, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Tan(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Cos(self.input)
         dx = singa.Square(dx)
         dx = singa.PowFloat(dx, -1.0)
@@ -2158,20 +2795,42 @@ class Tan(Operation):
 
 
 def tan(x):
+    """
+    Calculates the tangent of the given input tensor, element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Tan()(x)[0]
 
 
 class Atan(Operation):
+    """
+    Calculates the arctangent (inverse of tangent) of the given input tensor, 
element-wise.
+    """
 
     def __init__(self):
         super(Atan, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Atan(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Square(self.input)
         dx = singa.AddFloat(dx, 1.0)
         dx = singa.PowFloat(dx, -1.0)
@@ -2180,20 +2839,42 @@ class Atan(Operation):
 
 
 def atan(x):
+    """
+    Calculates the arctangent (inverse of tangent) of the given input tensor, 
element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """    
     return Atan()(x)[0]
 
 
 class Atanh(Operation):
+    """
+    Calculates the hyperbolic arctangent of the given input tensor 
element-wise.
+    """
 
     def __init__(self):
         super(Atanh, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         return singa.Atanh(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Square(self.input)
         dx = singa.MultFloat(dx, -1.0)
         dx = singa.AddFloat(dx, 1.0)
@@ -2203,21 +2884,43 @@ class Atanh(Operation):
 
 
 def atanh(x):
+    """
+    Calculates the hyperbolic arctangent of the given input tensor 
element-wise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """   
     return Atanh()(x)[0]
 
 
 class Sigmoid(Operation):
+    """
+    y = 1 / (1 + exp(-x)), is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Sigmoid, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         out = singa.Sigmoid(x)
         if training:
             self.cache = (out,)
         return out
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.MultFloat(self.cache[0], -1.0)
         dx = singa.AddFloat(dx, 1.0)
         dx = singa.__mul__(self.cache[0], dx)
@@ -2226,15 +2929,29 @@ class Sigmoid(Operation):
 
 
 def sigmoid(x):
+    """
+    y = 1 / (1 + exp(-x)), is applied to the tensor elementwise.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """   
     return Sigmoid()(x)[0]
 
 
 class Mul(Operation):
+    """
+    Performs element-wise binary multiplication (with Numpy-style broadcasting 
+    support).        
+    """   
 
     def __init__(self):
         super(Mul, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return np.multiply(a,b), where a and b are CTensor.
+        """
         # todo we cannot support mul op for int tensors
         _a, _b = a, b
         dtype0 = _a.data_type()
@@ -2254,6 +2971,13 @@ class Mul(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a tuple for (da, db), da is data for dL / da, db is data
+                for dL / db.
+        """
         dx0 = singa.__mul__(dy, self.input[1])
         dx1 = singa.__mul__(dy, self.input[0])
         if (type(dy) == float) or self.shape0 == self.shape1:
@@ -2265,9 +2989,23 @@ class Mul(Operation):
         return dx0, dx1
 
 
+def mul(x, y):
+    """
+    Return np.multiply(x,y), where x and y are Tensor.
+    """
+    return Mul()(x, y)[0]
+
+
 class Unsqueeze(Operation):
+    """
+    Insert single-dimensional entries to the shape of an input tensor (data). 
+    """
 
     def __init__(self, axis):
+        """
+        Args:
+            axis (list of int): the dimensions to be inserted.
+        """
         super(Unsqueeze, self).__init__()
         if (type(axis) is int):
             self.axis = list(axis)
@@ -2275,6 +3013,12 @@ class Unsqueeze(Operation):
             self.axis = axis
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         self.cache = x.shape()
         cur = list(self.cache)
         # todo, need optimize after we have scalar tensor
@@ -2285,28 +3029,57 @@ class Unsqueeze(Operation):
         return singa.Reshape(x, cur)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         return singa.Reshape(dy, self.cache)
 
 
 def unsqueeze(x, axis=-1):
+    """
+    Insert single-dimensional entries to the shape of an input tensor (data). 
+    Args:
+        x (Tensor): Input tensor
+        axis (list of int): the dimensions to be inserted.
+    Returns: 
+        Tensor, the output
+    """  
     return Unsqueeze(axis)(x)[0]
 
 
-def mul(x, y):
-    # do pointwise multiplication
-    return Mul()(x, y)[0]
-
-
 class Transpose(Operation):
+    """
+    Transpose the input tensor similar to numpy.transpose. 
+    """
 
     def __init__(self, perm):
+        """
+        Args:
+            perm (list of ints): A permutation of the dimensions; the axes 
+                of the input are permuted according to the values given.
+        """
         super(Transpose, self).__init__()
         self.perm = list(perm)
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         return singa.Transpose(x, self.perm)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         cur = []
         for i in range(len(self.perm)):
             cur += [self.perm.index(i)]
@@ -2314,6 +3087,15 @@ class Transpose(Operation):
 
 
 def transpose(x, shape):
+    """
+    Transpose the input tensor similar to numpy.transpose. 
+    Args:
+        x (Tensor): Input tensor
+        shape (list of ints): A permutation of the dimensions (passed to 
+            Transpose as perm); the axes are permuted according to it.
+    Returns: 
+        Tensor, the output
+    """
     return Transpose(shape)(x)[0]
 
 
@@ -2346,6 +3128,9 @@ class RNN_Base(Layer):
 
 
 class RNN(RNN_Base):
+    """
+    Generate an RNN operator
+    """
 
     def __init__(
         self,
@@ -2358,6 +3143,22 @@ class RNN(RNN_Base):
         dropout=0,
         bidirectional=False,
     ):
+        """
+        Args:
+            input_size (int):  The number of expected features in the input x
+            hidden_size (int): The number of features in the hidden state h
+            num_layers (int):  Number of recurrent layers. Default: 1
+            nonlinearity (string): The non-linearity to use. Default: 'tanh'
+            bias (bool):  If False, then the layer does not use bias weights. 
+                Default: True
+            batch_first (bool):  If True, then the input and output tensors 
+                are provided as (batch, seq, feature). Default: False
+            dropout (float): If non-zero, introduces a Dropout layer on the 
+                outputs of each RNN layer except the last layer, with dropout 
+                probability equal to dropout. Default: 0
+            bidirectional (bool): If True, becomes a bidirectional RNN. 
+                Default: False
+        """
         self.nonlinearity = nonlinearity
 
         Wx_shape = (input_size, hidden_size)
@@ -2407,6 +3208,9 @@ class RNN(RNN_Base):
 
 
 class LSTM(RNN_Base):
+    """
+    Generate an LSTM operator
+    """
 
     def __init__(
         self,
@@ -2419,6 +3223,22 @@ class LSTM(RNN_Base):
         dropout=0,
         bidirectional=False,
     ):
+        """
+        Args:
+            input_size (int):  The number of expected features in the input x
+            hidden_size (int): The number of features in the hidden state h
+            num_layers (int):  Number of recurrent layers. Default: 1
+            nonlinearity (string): The non-linearity to use. Default: 'tanh'
+            bias (bool):  If False, then the layer does not use bias weights. 
+                Default: True
+            batch_first (bool):  If True, then the input and output tensors 
+                are provided as (batch, seq, feature). Default: False
+            dropout (float): If non-zero, introduces a Dropout layer on the 
+                outputs of each RNN layer except the last layer, with dropout 
+                probability equal to dropout. Default: 0
+            bidirectional (bool): If True, becomes a bidirectional RNN. 
+                Default: False
+        """
         self.nonlinearity = nonlinearity
 
         Wx_shape = (input_size, hidden_size)
@@ -2511,46 +3331,89 @@ class LSTM(RNN_Base):
 
 
 class Abs(Operation):
+    """
+    y = abs(x), is applied to the tensor elementwise.
+    """
 
     def forward(self, a):
+        """
+        Return abs(a), where a is CTensor.
+        """
         if training:
             self.input = a
         return singa.Abs(a)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """     
         dx = singa.Sign(self.input)
         dx *= dy
         return dx
 
 
 def abs(a):
+    """
+    Return abs(a), where a is Tensor.
+    """
     return Abs()(a)[0]
 
 
 class Exp(Operation):
+    """
+    y = exp(x), is applied to the tensor elementwise.
+    """
 
     def forward(self, a):
+        """
+        Return exp(a), where a is CTensor.
+        """
         if training:
             self.input = a
         return singa.Exp(a)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.Exp(self.input)
         dx *= dy
         return dx
 
 
 def exp(a):
+    """
+    Return exp(a), where a is Tensor.
+    """
     return Exp()(a)[0]
 
 
 class LeakyRelu(Operation):
+    """
+    f(x) = a * x for x < 0, f(x) = x for x >= 0, is applied to the tensor elementwise.
+    """
 
     def __init__(self, a):
+        """
+        Args:
+            a (float): Coefficient of leakage.
+        """
         super(LeakyRelu, self).__init__()
         self.a = a
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = x
         x1 = singa.LTFloat(x, 0.0)
@@ -2561,6 +3424,12 @@ class LeakyRelu(Operation):
         return x1
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         # TODO(wangwei) check the correctness
         dx1 = singa.GTFloat(self.input, 0.0)
         dx2 = singa.LTFloat(self.input, 0.0)
@@ -2571,34 +3440,73 @@ class LeakyRelu(Operation):
 
 
 def leakyrelu(x, a=0.01):
+    """
+    f(x) = a * x for x < 0, f(x) = x for x >= 0 is applied to the tensor 
+    elementwise.
+    Args:
+        x (Tensor): Input tensor
+        a (float): Coefficient of leakage, default to 0.01.
+    Returns: 
+        Tensor, the output
+    """
     return LeakyRelu(a)(x)[0]
 
 
 class Sign(Operation):
+    """
+    Calculate the sign of the given input tensor element-wise. If input > 0, 
+    output 1. if input < 0, output -1. if input == 0, output 0.
+    """
 
     def __init__(self):
         super(Sign, self).__init__()
 
     def forward(self, a):
+        """
+        Args:
+            a (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         if training:
             self.input = a
         return singa.Sign(a)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.MultFloat(dy, 0.0)
         return dx
 
 
 def sign(a):
+    """
+    Calculate the sign of the given input tensor element-wise. If input > 0, 
+    output 1. if input < 0, output -1. if input == 0, output 0.
+    Args:
+        a (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Sign()(a)[0]
 
 
 class Pow(Operation):
+    """
+    f(x) = a^b, is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Pow, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return a^b, where a and b are CTensor.
+        """
         res = singa.Pow(a, b)
         if training:
             self.input = (a, b)
@@ -2608,6 +3516,13 @@ class Pow(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a tuple for (da, db), da is data for dL / da, db is data
+                for dL / db.
+        """
         da1 = singa.__mul__(
             self.input[1],
             singa.Pow(self.input[0], singa.SubFloat(self.input[1], 1.0)))
@@ -2625,15 +3540,24 @@ class Pow(Operation):
 
 
 def pow(a, b):
+    """
+    Return a^b, where a and b are Tensor.
+    """
     return Pow()(a, b)[0]
 
 
 class SoftSign(Operation):
+    """
+    Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise.
+    """
 
     def __init__(self):
         super(SoftSign, self).__init__()
 
     def forward(self, x):
+        """
+        Return (x/(1+|x|)), where x is CTensor.
+        """
         # y = x / (1 + np.abs(x))
         if training:
             self.input = x
@@ -2643,6 +3567,12 @@ class SoftSign(Operation):
         return y
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.AddFloat(singa.Abs(self.input), 1.0)
         dx = singa.PowFloat(singa.Square(dx), -1.0)
         dx = singa.__mul__(dy, dx)
@@ -2650,20 +3580,35 @@ class SoftSign(Operation):
 
 
 def softsign(x):
+    """
+    Return (x/(1+|x|)), where x is Tensor.
+    """
     return SoftSign()(x)[0]
 
 
 class Sqrt(Operation):
+    """
+    y = x^0.5, is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Sqrt, self).__init__()
 
     def forward(self, x):
+        """
+        Return x^0.5, where x is CTensor.
+        """
         if training:
             self.input = x
         return singa.Sqrt(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.PowFloat(self.input, -0.5)
         dx = singa.MultFloat(dx, 0.5)
         dx = singa.__mul__(dy, dx)
@@ -2671,15 +3616,24 @@ class Sqrt(Operation):
 
 
 def sqrt(x):
+    """
+    Return x^0.5, where x is Tensor.
+    """
     return Sqrt()(x)[0]
 
 
 class SoftPlus(Operation):
+    """
+    y = ln(exp(x) + 1) is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(SoftPlus, self).__init__()
 
     def forward(self, x):
+        """
+        Return ln(exp(x) + 1), where x is CTensor.
+        """
         #f(x) = ln(exp(x) + 1)
         if training:
             self.input = x
@@ -2688,6 +3642,12 @@ class SoftPlus(Operation):
         return y
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.Exp(singa.MultFloat(self.input, -1.0))
         dx = singa.PowFloat(singa.AddFloat(dx, 1.0), -1.0)
         dx = singa.__mul__(dy, dx)
@@ -2695,15 +3655,25 @@ class SoftPlus(Operation):
 
 
 def softplus(x):
+    """
+    Return ln(exp(x) + 1), where x is Tensor.
+    """
     return SoftPlus()(x)[0]
 
 
 class Sub(Operation):
+    """
+    Performs element-wise binary subtraction (with Numpy-style broadcasting 
+    support).
+    """
 
     def __init__(self):
         super(Sub, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return a-b, where a and b are CTensor.
+        """
         res = singa.__sub__(a, b)
         if training:
             self.shape0 = list(a.shape())
@@ -2712,6 +3682,13 @@ class Sub(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a tuple for (da, db), da is data for dL / da, db is data
+                for dL / db.
+        """
         dx0 = dy
         dx1 = singa.MultFloat(dy, -1.0)
         if (type(dy) == float) or self.shape0 == self.shape1:
@@ -2724,17 +3701,32 @@ class Sub(Operation):
 
 
 def sub(a, b):
+    """
+    Return a-b, where a and b are Tensor.
+    """
     return Sub()(a, b)[0]
 
 
 # optimize min to support multi inputs
 class Min(Operation):
+    """
+    Element-wise min of each of the input tensors (with Numpy-style 
+    broadcasting support).
+    """
 
     def __init__(self):
         super(Min, self).__init__()
         self.masks = []
 
     def _min(self, a, b):
+        """
+        Args:
+            a (CTensor): First operand
+            b (CTensor): Second operand
+        Returns: 
+            CTensor, the output
+            tuple of CTensor, mask tensor
+        """
         m = singa.__sub__(a, b)
         mask0 = singa.LEFloat(m, 0)
         mask1 = singa.GTFloat(m, 0)
@@ -2742,6 +3734,12 @@ class Min(Operation):
         return res, (mask0, mask1)
 
     def forward(self, *x):
+        """
+        Args:
+            *x (a list of CTensor): List of tensors for min.
+        Returns: 
+            CTensor, the output
+        """    
         assert (len(x) > 0)
         self.l = len(x)
         if len(x) == 1:
@@ -2756,6 +3754,12 @@ class Min(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a tuple for (*dx), dx is data for dL / dx.
+        """
         if self.l == 1:
             return self.masks[0][0]
         else:
@@ -2773,39 +3777,69 @@ class Min(Operation):
 
 
 def min(*l):
+    """
+    Element-wise min of each of the input tensors (with Numpy-style 
+    broadcasting support).
+    Args:
+        *l (a list of Tensor): List of tensors for min.
+    Returns: 
+        Tensor, the output
+    """
     return Min()(*l)[0]
 
 
 class Log(Operation):
+    """
+    y = log(x), is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Log, self).__init__()
 
     def forward(self, x):
+        """
+        Return log(x), where x is CTensor.
+        """
         if training:
             self.input = x
         return singa.Log(x)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         dx = singa.PowFloat(self.input, -1)
         dx = singa.__mul__(dy, dx)
         return dx
 
 
 def log(x):
+    """
+    Return log(x), where x is Tensor.
+    """
     return Log()(x)[0]
 
 
 class HardSigmoid(Operation):
+    """
+    y = max(0, min(1, alpha * x + gamma)), is applied to the tensor elementwise.
+    """
 
     def __init__(self, alpha=0.2, gamma=0.5):
+        """
+        Args:
+            alpha (float): Value of alpha, the slope. Default: 0.2.
+            gamma (float): Value of gamma, the offset added to alpha * x. Default: 0.5.
+        """
         super(HardSigmoid, self).__init__()
         self.alpha = alpha
         self.gamma = gamma
 
     def forward(self, x):
-        """Do forward propgation.
-        #y = max(0, min(1, alpha * x + gamma))
+        """
         Args:
             x (CTensor): matrix
         Returns:
@@ -2823,6 +3857,12 @@ class HardSigmoid(Operation):
         return singa.ReLU(ans)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         mask0 = singa.GTFloat(self.cache, 0.0)
         mask1 = singa.LTFloat(self.cache, 1.0)
         mask = singa.__mul__(mask0, mask1)
@@ -2830,16 +3870,43 @@ class HardSigmoid(Operation):
 
 
 def hardsigmoid(x, alpha=0.2, gamma=0.5):
+    """
+    y = max(0, min(1, alpha * x + gamma)), is applied to the tensor elementwise.
+    Args:
+        x (Tensor): Input tensor
+        alpha (float): Value of alpha, the slope. Default: 0.2.
+        gamma (float): Value of gamma, the offset added to alpha * x. Default: 0.5.
+    Returns:
+        a Tensor for the result
+    """
     return HardSigmoid(alpha, gamma)(x)[0]
 
 
 class Squeeze(Operation):
+    """
+    Remove single-dimensional entries from the shape of a tensor. Takes a 
+    parameter axis with a list of axes to squeeze. If axis is empty (default), 
+    all the single dimensions will be removed from the shape. If an axis is 
+    selected with shape entry not equal to one, an error is raised.
+    """
 
     def __init__(self, axis=[]):
+        """
+        Args:
+            axis (list of ints): List of integers indicating the dimensions 
+                to squeeze. Negative value means counting dimensions from 
+                the back. Accepted range is [-r, r-1] where r = rank(data).
+        """
         super(Squeeze, self).__init__()
         self.axis = axis
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         self.cache = x.shape()
         newshape = []
         if (self.axis == []):
@@ -2860,19 +3927,44 @@ class Squeeze(Operation):
         return singa.Reshape(x, newshape)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         return singa.Reshape(dy, self.cache)
 
 
 def squeeze(x, axis=[]):
+    """
+    Remove single-dimensional entries from the shape of a tensor. Takes a 
+    parameter axis with a list of axes to squeeze. If axis is empty (default), 
+    all the single dimensions will be removed from the shape. If an axis is 
+    selected with shape entry not equal to one, an error is raised.
+    Args:
+        x (Tensor): Input tensor
+        axis (list of ints): List of integers indicating the dimensions 
+            to squeeze. Negative value means counting dimensions from 
+            the back. Accepted range is [-r, r-1] where r = rank(data).
+    Returns: 
+        Tensor, the output
+    """
     return Squeeze(axis)(x)[0]
 
 
 class Div(Operation):
+    """
+    Performs element-wise binary division (with Numpy-style broadcasting support).
+    """
 
     def __init__(self):
         super(Div, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return np.divide(a,b), where a and b are CTensor.
+        """
         res = singa.__mul__(a, singa.PowFloat(b, -1.0))
         # res = singa.__div__(a, b)
         if training:
@@ -2884,6 +3976,13 @@ class Div(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a CTensor tuple for (da, db), da is data for dL / da, db is data
+                for dL / db.
+        """
         #dy/dx_0 = b^(-1)
         #dy/dx_1 = (-a)*b^(-2)
         dx0 = singa.__mul__(dy, self.input[1])
@@ -2899,36 +3998,73 @@ class Div(Operation):
 
 
 def div(a, b):
+    """
+    Return np.divide(a,b), where a and b are Tensor.
+    """
     return Div()(a, b)[0]
 
 
 class Shape(Operation):
+    """
+    Takes a tensor as input and outputs a tensor containing the shape of the 
+    input tensor.
+    """
 
     def __init__(self):
         super(Shape, self).__init__()
 
     def forward(self, x):
+        """
+        Args:
+            x (CTensor): Input tensor
+        Returns: 
+            CTensor, the output
+        """
         cur = list(x.shape())
         cur = tensor.from_numpy(np.array(cur))
         cur.to_device(x.device())
         return cur.data
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        """
         return list(dy.shape())
 
 
 def shape(x):
+    """
+    Takes a tensor as input and outputs a tensor containing the shape of the 
+    input tensor.
+    Args:
+        x (Tensor): Input tensor
+    Returns: 
+        Tensor, the output
+    """
     return Shape()(x)[0]
 
 
 # optimize max to support multi inputs
 class Max(Operation):
+    """
+    Element-wise max of each of the input tensors (with Numpy-style 
+    broadcasting support). 
+    """
 
     def __init__(self):
         super(Max, self).__init__()
         self.masks = []
 
     def _max(self, a, b):
+        """
+        Args:
+            a (CTensor): First operand
+            b (CTensor): Second operand
+        Returns: 
+            CTensor, the output
+            tuple of CTensor, mask tensor
+        """    
         m = singa.__sub__(a, b)
         mask0 = singa.GEFloat(m, 0)
         mask1 = singa.LTFloat(m, 0)
@@ -2936,6 +4072,12 @@ class Max(Operation):
         return res, (mask0, mask1)
 
     def forward(self, *x):
+        """
+        Args:
+            *x (a list of CTensor): List of tensors for max.
+        Returns: 
+            CTensor, the output
+        """    
         assert (len(x) > 0)
         self.l = len(x)
         if len(x) == 1:
@@ -2950,6 +4092,12 @@ class Max(Operation):
         return res
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            a tuple for (*dx), dx is data for dL / dx.
+        """
         if self.l == 1:
             return self.masks[0][0]
         else:
@@ -2967,34 +4115,62 @@ class Max(Operation):
 
 
 def max(*l):
+    """
+    Element-wise max of each of the input tensors (with Numpy-style broadcasting support). 
+    Args:
+        *l (a list of Tensor): List of tensors for max.
+    Returns: 
+        Tensor, the output
+    """
     return Max()(*l)[0]
 
 
 class And(Operation):
+    """
+    Returns the tensor resulting from performing the and logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
+    """
 
     def __init__(self):
         super(And, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return np.logical_and(a,b), where a and b are CTensor.
+        """
         m = singa.__mul__(a, b)
         cur = singa.PowFloat(singa.Sign(m), 2)
 
         return cur
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Raises:
+            AssertionError: no backward function for this operator
+        """
         assert False, ('no gradient for backward function')
 
 
 def _and(a, b):
+    """
+    Return np.logical_and(a,b), where a and b are Tensor.
+    """
     return And()(a, b)[0]
 
 
 class Or(Operation):
+    """
+    Returns the tensor resulting from performing the or logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
+    """
 
     def __init__(self):
         super(Or, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return np.logical_or(a,b), where a and b are CTensor.
+        """
         m = singa.__add__(singa.PowFloat(singa.Sign(a), 2.0),
                           singa.PowFloat(singa.Sign(b), 2.0))
         cur = singa.Sign(m)
@@ -3002,19 +4178,34 @@ class Or(Operation):
         return cur
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Raises:
+            AssertionError: no backward function for this operator
+        """
         assert False, ('no gradient for backward function')
 
 
 def _or(a, b):
+    """
+    Return np.logical_or(a,b), where a and b are Tensor.
+    """
     return Or()(a, b)[0]
 
 
 class Not(Operation):
+    """
+    Returns the negation of the input tensor element-wise.
+    """
 
     def __init__(self):
         super(Not, self).__init__()
 
     def forward(self, x):
+        """
+        Return np.logical_not(x), where x is CTensor.
+        """
         mask0 = singa.GEFloat(x, 0)
         mask1 = singa.LEFloat(x, 0)
         cur = singa.__mul__(mask0, mask1)
@@ -3022,19 +4213,34 @@ class Not(Operation):
         return cur
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Raises:
+            AssertionError: no backward function for this operator
+        """
         assert False, ('no gradient for backward function')
 
 
 def _not(x):
+    """
+    Return np.logical_not(x), where x is Tensor.
+    """
     return Not()(x)[0]
 
 
 class Xor(Operation):
+    """
+    Performs the xor logical operation elementwise on the input tensors A and B (with Numpy-style broadcasting support).
+    """
 
     def __init__(self):
         super(Xor, self).__init__()
 
     def forward(self, a, b):
+        """
+        Return np.logical_xor(a,b), where a and b are CTensor.
+        """
         m = singa.__sub__(singa.PowFloat(singa.Sign(a), 2.0),
                           singa.PowFloat(singa.Sign(b), 2.0))
         cur = singa.PowFloat(singa.Sign(m), 2.0)
@@ -3042,36 +4248,66 @@ class Xor(Operation):
         return cur
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Raises:
+            AssertionError: no backward function for this operator
+        """
         assert False, ('no gradient for backward function')
 
 
 def _xor(a, b):
+    """
+    Return np.logical_xor(a,b), where a and b are Tensor.
+    """
     return Xor()(a, b)[0]
 
 
 class Negative(Operation):
+    """
+    y = -x, is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Negative, self).__init__()
 
     def forward(self, x):
+        """
+        Return -x, where x is CTensor.
+        """
         #y=-x
         return singa.MultFloat(x, -1)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         return singa.MultFloat(dy, -1)
 
 
 def negative(x):
+    """
+    Return -x, where x is Tensor.
+    """
     return Negative()(x)[0]
 
 
 class Reciprocal(Operation):
+    """
+    y = 1/x, is applied to the tensor elementwise.
+    """
 
     def __init__(self):
         super(Reciprocal, self).__init__()
 
     def forward(self, x):
+        """
+        Return 1/x, where x is CTensor.
+        """
         #y=1/x elementwise
         if training:
             self.input = x
@@ -3079,33 +4315,44 @@ class Reciprocal(Operation):
         return singa.PowFloat(x, -1)
 
     def backward(self, dy):
+        """
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
+        Returns: 
+            CTensor, the gradient over input
+        """
         #dy/dx = -1/x**2
         dx = singa.MultFloat(singa.PowFloat(self.input, -2), -1)
         return singa.__mul__(dy, dx)
 
 
 def reciprocal(x):
+    """
+    Return 1/x, where x is Tensor.
+    """
     return Reciprocal()(x)[0]
 
 
 class Gemm(Operation):
+    """
+    Init a General Matrix multiplication(Gemm) operator. Compute Y = alpha * 
+    A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input 
+    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to 
+    shape (M, N), and output tensor Y has shape (M, N).
+    A' = transpose(A) if transA else A
+    B' = transpose(B) if transB else B
+    """
 
     def __init__(self, alpha=1.0, beta=1.0, transA=0, transB=0):
         """
-        init a General Matrix multiplication(Gemm) operator
-        Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
-        A' = transpose(A) if transA else A
-        B' = transpose(B) if transB else B
-        Args:alpha: 
-            float, Scalar multiplier for the product of input tensors A * B.
-        Args:beta: 
-            float, Scalar multiplier for input tensor C.
-        Args:transA: 
-            int, Whether A should be transposed
-        Args:transB: 
-            int, Whether B should be transposed
+        Args:
+            alpha (float): Scalar multiplier for the product of input tensors 
+                A * B.
+            beta (float): Scalar multiplier for input tensor C.
+            transA (int): Whether A should be transposed
+            transB (int): Whether B should be transposed
         Returns: 
-            tensor, the output
+            CTensor, the output
         """
         super(Gemm, self).__init__()
         self.alpha = alpha
@@ -3116,12 +4363,14 @@ class Gemm(Operation):
     def forward(self, A, B, C=None):
         """
         forward propogation of Gemm
-        Args:A: 
-            tensor, The shape of A should be (M, K) if transA is 0, or (K, M) 
if transA is non-zero.
-        Args:B: 
-            tensor, The shape of B should be (K, N) if transB is 0, or (N, K) 
if transB is non-zero.
-        Args:C: 
-            tensor(optional), Optional input tensor C. If not specified, the 
computation is done as if C is a scalar 0. The shape of C should be 
unidirectional broadcastable to (M, N).
+        Args:
+            A (CTensor): The shape of A should be (M, K) if transA is 0, or 
+                (K, M) if transA is non-zero.
+            B (CTensor): The shape of B should be (K, N) if transB is 0, or 
+                (N, K) if transB is non-zero.
+            C (CTensor, optional): input tensor C. If not specified, the
+                computation is done as if C is a scalar 0. The shape of C
+                should be unidirectional broadcastable to (M, N).
         Returns: 
             tensor, the output
         """
@@ -3137,12 +4386,12 @@ class Gemm(Operation):
     def backward(self, dy):
         """
         backward propogation of Gemm
-        Args:dy: 
-            tensor, The shape of A should be (M, K) if transA is 0, or (K, M) 
if transA is non-zero.
+        Args:
+            dy (CTensor): the gradient tensor over the output Y, of shape (M, N)
         Returns: 
-            tensor, the gradient over A
-            tensor, the gradient over B
-            tensor(optional), the gradient over C
+            CTensor, the gradient over A
+            CTensor, the gradient over B
+            CTensor(optional), the gradient over C
         """
         _A, _B, C = self.inputs
         # y = alpha * A  * B  => da = alpha * dy * BT
@@ -3172,39 +4421,42 @@ class Gemm(Operation):
 
 def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
     """
-    init a General Matrix multiplication(Gemm) operator
-    Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, 
K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is 
broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+    Init a General Matrix multiplication(Gemm) operator. Compute Y = alpha * 
+    A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input 
+    tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to 
+    shape (M, N), and output tensor Y has shape (M, N).
     A' = transpose(A) if transA else A
     B' = transpose(B) if transB else B
-    Args:A: 
-        tensor, The shape of A should be (M, K) if transA is 0, or (K, M) if 
transA is non-zero.
-    Args:B: 
-        tensor, The shape of B should be (K, N) if transB is 0, or (N, K) if 
transB is non-zero.
-    Args:C: 
-        tensor(optional), Optional input tensor C. If not specified, the 
computation is done as if C is a scalar 0. The shape of C should be 
unidirectional broadcastable to (M, N).
-    Args:alpha: 
-        float, Scalar multiplier for the product of input tensors A * B.
-    Args:beta: 
-        float, Scalar multiplier for input tensor C.
-    Args:transA: 
-        int, Whether A should be transposed
-    Args:transB: 
-        int, Whether B should be transposed
+    Args:
+        A (Tensor): The shape of A should be (M, K) if transA is 0, or 
+            (K, M) if transA is non-zero.
+        B (Tensor): The shape of B should be (K, N) if transB is 0, or 
+            (N, K) if transB is non-zero.
+        C (Tensor, optional): input tensor C. If not specified, the
+            computation is done as if C is a scalar 0. The shape of C
+            should be unidirectional broadcastable to (M, N).
+        alpha (float): Scalar multiplier for the product of input tensors A * B.
+        beta (float): Scalar multiplier for input tensor C.
+        transA (int): Whether A should be transposed
+        transB (int): Whether B should be transposed
     Returns: 
-        tensor, the output
+        Tensor, the output
     """
     return Gemm(alpha, beta, transA, transB)(A, B, C)[0]
 
 
 class GlobalAveragePool(Operation):
+    """
+    Init a GlobalAveragePool operator
+    """
 
     def __init__(self, data_format='channels_first'):
         """
-        init a GlobalAveragePool operator
-        Args:data_format: 
-            A string, we support two formats: channels_last and 
channels_first, default is channels_first.
-            channels_first means the format of input is (N x C x H x W)
-            channels_last means the format of input is (N x H x W x C)
+        Args:
+            data_format (string): A string, we support two formats: 
+                channels_last and channels_first, default is channels_first.
+                channels_first means the format of input is (N x C x H x W)
+                channels_last means the format of input is (N x H x W x C)
         """
         super(GlobalAveragePool, self).__init__()
         self.data_format = data_format
@@ -3212,10 +4464,10 @@ class GlobalAveragePool(Operation):
     def forward(self, x):
         """
         forward propogation of GlobalAveragePool
-        Args:x: 
-            the input tensor
+        Args:
+            x (CTensor): the input tensor
         Returns: 
-            tensor, the output
+            CTensor, the output
         """
         if training:
             self.mask = singa.Tensor(x.shape(), x.device())
@@ -3244,10 +4496,10 @@ class GlobalAveragePool(Operation):
     def backward(self, dy):
         """
         backward propogation of GlobalAveragePool
-        Args:dy: 
-            the gradient tensor from upper operations
+        Args:
+            dy (CTensor): the gradient tensor from upper operations
         Returns: 
-            tensor, the gradient over input
+            CTensor, the gradient over input
         """
         self.mask.SetFloatValue(self.shape_divisor)
         return singa.__mul__(self.mask, dy)
@@ -3256,26 +4508,29 @@ class GlobalAveragePool(Operation):
 def globalaveragepool(x, data_format='channels_first'):
     """
     GlobalAveragePool operator
-    Args:x
-        the input tensor
-    Args:data_format: 
-        A string, we support two formats: channels_last and channels_first, 
default is channels_first.
-        channels_first means the format of input is (N x C x H x W)
-        channels_last means the format of input is (N x H x W x C)
+    Args:
+        x (Tensor): the input tensor
+        data_format (string): A string, we support two formats: 
+            channels_last and channels_first, default is channels_first.
+            channels_first means the format of input is (N x C x H x W)
+            channels_last means the format of input is (N x H x W x C)
     Returns: 
-        tensor, the output
+        Tensor, the output
     """
     return GlobalAveragePool(data_format)(x)[0]
 
 
 class ConstantOfShape(Operation):
+    """
+    Init a ConstantOfShape, generate a tensor with given value and shape.
+    """
 
-    def __init__(self, value=0):
+    def __init__(self, value=0.):
         """
-        Init a ConstantOfShape, generate a tensor with given value and shape.
         Args:
-            value: (Optional) The value of the output elements. Should be a 
one-element value. If not specified, 
-            it defaults to 0 and datatype float32
+            value (float): (Optional) The value of the output elements. Should 
+                be a one-element value. If not specified, it defaults to 0 and 
+                datatype float32
         """
         super(ConstantOfShape, self).__init__()
         self.value = value
@@ -3284,10 +4539,13 @@ class ConstantOfShape(Operation):
         """
         forward of ConstantOfShape
         Args:
-            x: CTensor, 1D tensor. The shape of the expected output tensor. 
All values must be >= 0.
+            x (CTensor): 1D tensor, the shape of the expected output tensor.
+                All values must be >= 0.
         Returns:
-            the output CTensor. If attribute 'value' is specified, the value 
and datatype of the output tensor is taken from 'value'. 
-            If attribute 'value' is not specified, the value in the output 
defaults to 0, and the datatype defaults to float32.
+            the output CTensor. If attribute 'value' is specified, the value 
+                and datatype of the output tensor is taken from 'value'. If 
+                attribute 'value' is not specified, the value in the output 
+                defaults to 0, and the datatype defaults to float32.
         """
         x_shape = tensor.to_numpy(tensor.from_raw_tensor(x)).astype(
             np.int64).tolist()
@@ -3300,7 +4558,9 @@ class ConstantOfShape(Operation):
         """
         backward of ConstantOfShape
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
+        Raises:
+            AssertionError: no backward function for this operator
         """
         assert False, ('no gradient for backward function')
 
@@ -3309,25 +4569,30 @@ def constant_of_shape(x, value=0):
     """
     Init a ConstantOfShape, generate a tensor with given value and shape.
     Args:
-        x: CTensor, 1D tensor. The shape of the expected output tensor. All 
values must be >= 0.
-    Args:
-        value: (Optional) The value of the output elements. Should be a 
one-element tensor. If not specified, 
-        it defaults to a tensor of value 0 and datatype float32
+        x (Tensor): 1D tensor, the shape of the expected output tensor.
+            All values must be >= 0.
+        value (float): (Optional) The value of the output elements. Should 
+            be a one-element value. If not specified, it defaults to 0 and 
+            datatype float32
     Returns:
-            the output CTensor. If attribute 'value' is specified, the value 
and datatype of the output tensor is taken from 'value'. 
-            If attribute 'value' is not specified, the value in the output 
defaults to 0, and the datatype defaults to float32.
+        the output Tensor. If attribute 'value' is specified, the value 
+            and datatype of the output tensor is taken from 'value'. If 
+            attribute 'value' is not specified, the value in the output 
+            defaults to 0, and the datatype defaults to float32.
     """
     return ConstantOfShape(value)(x)[0]
 
 
 class Dropout(Operation):
+    """
+    Init a Dropout, which scales the masked input data by the following equation:
+    output = scale * data * mask, scale = 1. / (1. - ratio).
+    """
 
     def __init__(self, ratio=0.5):
         """
-        Init a Dropout, which scales the masked input data by the following 
equation:
-        output = scale * data * mask, scale = 1. / (1. - ratio).
         Args:
-            ratio: float, he ratio of random dropout, with value in [0, 1).
+            ratio (float): the ratio of random dropout, with value in [0, 1).
         """
         super(Dropout, self).__init__()
         self.ratio = ratio
@@ -3336,7 +4601,7 @@ class Dropout(Operation):
         """
         forward of Dropout
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3351,7 +4616,7 @@ class Dropout(Operation):
         """
         backward of Dropout
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3362,28 +4627,31 @@ class Dropout(Operation):
 
 def dropout(x, ratio=0.5):
     """
-    Init a Dropout, which scales the masked input data by the following 
equation:
-    output = scale * data * mask, scale = 1. / (1. - ratio).
-    Args:
-        x: CTensor, input tensor.
+    Init a Dropout, which scales the masked input data by the following 
+    equation: output = scale * data * mask, scale = 1. / (1. - ratio).
     Args:
-        ratio: float, he ratio of random dropout, with value in [0, 1).
+        x (Tensor): input tensor.
+        ratio (float): the ratio of random dropout, with value in [0, 1).
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Dropout(ratio)(x)[0]
 
 
 class ReduceSum(Operation):
+    """
+    Init a ReduceSum, which computes the sum of the input tensor's elements
+    along the provided axes.
+    """
 
     def __init__(self, axes=None, keepdims=1):
         """
-        Init a ReduceSum, computes the sum of the input tensor's element along 
the provided axes.
         Args:
-            axes: list of ints, A list of integers, along which to reduce. 
Accepted range is [-r, r-1] where r = rank(data).
-            The default is None, which reduces over all the dimensions of the 
input tensor.
-        Args:
-            keepdims: int, Keep the reduced dimension or not, default 1 mean 
keep reduced dimension.
+            axes (list of int): A list of integers, along which to reduce.
+                Accepted range is [-r, r-1] where r = rank(data). The default
+                is None, which reduces over all the dimensions of the input tensor.
+            keepdims (int): Keep the reduced dimension or not, default 1 means
+                keep the reduced dimension.
         """
         super(ReduceSum, self).__init__()
         self.axes = axes
@@ -3393,7 +4661,7 @@ class ReduceSum(Operation):
         """
         forward of ReduceSum
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3418,7 +4686,7 @@ class ReduceSum(Operation):
         """
         backward of ReduceSum
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3433,30 +4701,35 @@ class ReduceSum(Operation):
 
 def reduce_sum(x, axes=None, keepdims=1):
     """
-    Init a ReduceSum, computes the sum of the input tensor's element along the 
provided axes.
-    Args:
-        x: CTensor, input tensor.
+    Init a ReduceSum, which computes the sum of the input tensor's elements
+    along the provided axes.
     Args:
-        axes: list of ints, A list of integers, along which to reduce. 
Accepted range is [-r, r-1] where r = rank(data).
-        The default is None, which reduces over all the dimensions of the 
input tensor.
-    Args:
-        keepdims: int, Keep the reduced dimension or not, default 1 mean keep 
reduced dimension.
+        x (Tensor): input tensor.
+        axes (list of int): A list of integers, along which to reduce. 
+            Accepted range is [-r, r-1] where r = rank(data). The default 
+            is None, which reduces over all the dimensions of the input tensor.
+        keepdims (int): Keep the reduced dimension or not, default 1 means
+            keep the reduced dimension.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return ReduceSum(axes, keepdims)(x)[0]
 
 
 class ReduceMean(Operation):
+    """
+    Init a ReduceMean, which computes the mean of the input tensor's elements
+    along the provided axes.
+    """
 
     def __init__(self, axes=None, keepdims=1):
         """
-        Init a ReduceMean, computes the mean of the input tensor's element 
along the provided axes.
-        Args:
-            axes: list of ints, A list of integers, along which to reduce. 
Accepted range is [-r, r-1] where r = rank(data).
-            The default is None, which reduces over all the dimensions of the 
input tensor.
         Args:
-            keepdims: int, Keep the reduced dimension or not, default 1 mean 
keep reduced dimension.
+            axes (list of int): A list of integers, along which to reduce.
+                Accepted range is [-r, r-1] where r = rank(data). The default
+                is None, which reduces over all the dimensions of the input tensor.
+            keepdims (int): Keep the reduced dimension or not, default 1 means
+                keep the reduced dimension.
         """
         super(ReduceMean, self).__init__()
         self.axes = axes
@@ -3466,7 +4739,7 @@ class ReduceMean(Operation):
         """
         forward of ReduceMean
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3493,7 +4766,7 @@ class ReduceMean(Operation):
         """
         backward of ReduceMean
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3507,36 +4780,38 @@ class ReduceMean(Operation):
 
 def reduce_mean(x, axes=None, keepdims=1):
     """
-    Init a ReduceMean, computes the mean of the input tensor's element along 
the provided axes.
-    Args:
-        x: CTensor, input tensor.
-    Args:
-        axes: list of ints, A list of integers, along which to reduce. 
Accepted range is [-r, r-1] where r = rank(data).
-        The default is None, which reduces over all the dimensions of the 
input tensor.
+    Init a ReduceMean, which computes the mean of the input tensor's elements
+    along the provided axes.
     Args:
-        keepdims: int, Keep the reduced dimension or not, default 1 mean keep 
reduced dimension.
+        x (Tensor): input tensor.
+        axes (list of int): A list of integers, along which to reduce. 
+            Accepted range is [-r, r-1] where r = rank(data). The default 
+            is None, which reduces over all the dimensions of the input tensor.
+        keepdims (int): Keep the reduced dimension or not, default 1 means
+            keep the reduced dimension.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return ReduceMean(axes, keepdims)(x)[0]
 
 
 class Slice(Operation):
+    """
+    Init a Slice, which produces a slice of the input tensor along multiple axes.
+    Similar to numpy: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
+    """
 
     def __init__(self, starts, ends, axes=None, steps=None):
         """
-        Init a Slice, Produces a slice of the input tensor along multiple 
axes. Similar to numpy: 
-        https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
-        Args:
-            starts: list of ints, starting indices of corresponding axis
-        Args:
-            ends: list of ints, ending indices of corresponding axis
         Args:
-            axes: list of ints, axes that `starts` and `ends` apply to. 
-            Negative value means counting dimensions from the back. Accepted 
range is [-r, r-1] where r = rank(data).
-        Args:
-            steps: list of ints, slice step of corresponding axis in `axes`. 
-            Negative value means slicing backward. 'steps' cannot be 0. 
Defaults to 1.
+            starts (list of int): starting indices of corresponding axis
+            ends (list of int): ending indices of corresponding axis
+            axes (list of int): axes that `starts` and `ends` apply to. 
+                Negative value means counting dimensions from the back. 
+                Accepted range is [-r, r-1] where r = rank(data).
+            steps (list of int): slice step of corresponding axis in `axes`. 
+                Negative value means slicing backward. 'steps' cannot be 0. 
+                Defaults to 1.
         """
         super(Slice, self).__init__()
         self.starts = starts
@@ -3548,7 +4823,7 @@ class Slice(Operation):
         """
         forward of Slice
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3580,7 +4855,7 @@ class Slice(Operation):
         """
         backward of Slice
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3605,39 +4880,38 @@ class Slice(Operation):
 
 def slice(x, starts, ends, axes=None, steps=None):
     """
-    Init a Slice, Produces a slice of the input tensor along multiple axes. 
Similar to numpy: 
-    https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
-    Args:
-        x: CTensor, input tensor.
-    Args:
-        starts: list of ints, starting indices of corresponding axis
-    Args:
-        ends: list of ints, ending indices of corresponding axis
+    Init a Slice, which produces a slice of the input tensor along multiple axes.
+    Similar to numpy: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
     Args:
-        axes: list of ints, axes that `starts` and `ends` apply to. 
-        Negative value means counting dimensions from the back. Accepted range 
is [-r, r-1] where r = rank(data).
-    Args:
-        steps: list of ints, slice step of corresponding axis in `axes`. 
-        Negative value means slicing backward. 'steps' cannot be 0. Defaults 
to 1.
+        x (Tensor): input tensor.
+        starts (list of int): starting indices of corresponding axis
+        ends (list of int): ending indices of corresponding axis
+        axes (list of int): axes that `starts` and `ends` apply to. 
+            Negative value means counting dimensions from the back. 
+            Accepted range is [-r, r-1] where r = rank(data).
+        steps (list of int): slice step of corresponding axis in `axes`. 
+            Negative value means slicing backward. 'steps' cannot be 0. 
+            Defaults to 1.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Slice(starts, ends, axes, steps)(x)[0]
 
 
 class Ceil(Operation):
+    """
+    Ceil takes one input data (Tensor) and produces one output data (Tensor) 
+    where the ceil is, y = ceil(x), is applied to the tensor elementwise.
+    """
 
     def __init__(self):
-        """
-        Ceil takes one input data (Tensor) and produces one output data 
(Tensor) where the ceil is, y = ceil(x), is applied to the tensor elementwise.
-        """
         super(Ceil, self).__init__()
 
     def forward(self, x):
         """
         forward of Ceil
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3647,7 +4921,7 @@ class Ceil(Operation):
         """
         backward of Ceil
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3658,28 +4932,33 @@ class Ceil(Operation):
 
 def ceil(x):
     """
-    Ceil takes one input data (Tensor) and produces one output data (Tensor) 
where the ceil is, y = ceil(x), is applied to the tensor elementwise.
+    Ceil takes one input data (Tensor) and produces one output data (Tensor) 
+    where the ceil is, y = ceil(x), is applied to the tensor elementwise.
     Args:
-        x: CTensor, input tensor.
+        x (Tensor): input tensor.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Ceil()(x)[0]
 
 
 class Split(Operation):
+    """
+    Init a Split, Split a tensor into a list of tensors, along the specified 
+    'axis'. 
+    """
 
     def __init__(self, axis, parts, num_output=None):
         """
-        Init a Split, Split a tensor into a list of tensors, along the 
specified 'axis'. 
-        Args:
-            axis: int, Which axis to split on. A negative value means counting 
dimensions from the back. 
-            Accepted range is [-rank, rank-1] where r = rank(input).
         Args:
-            parts: list of ints, length of each output, which can be specified 
using argument 'parts'. 
-            Otherwise, the tensor is parts to equal sized parts.
-        Args:
-            num_output: once parts is none, the tensor is split to equal sized 
parts for each output.
+            axis (int): which axis to split on. A negative value means
+                counting dimensions from the back. Accepted range is
+                [-r, r-1] where r = rank(input).
+            parts (list of int): length of each output. If not specified,
+                the tensor is split into equal-sized parts.
+            num_output (int): the number of equal-sized parts to split the
+                tensor into when parts is None.
         """
         super(Split, self).__init__()
         self.axis = axis
@@ -3692,7 +4971,7 @@ class Split(Operation):
         """
         forward of Split
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3722,35 +5001,38 @@ class Split(Operation):
 
 def split(x, axis, parts, num_output=None):
     """
-    Init a Split, Split a tensor into a list of tensors, along the specified 
'axis'. 
-    Args:
-        x: CTensor, input tensor.
-    Args:
-        axis: int, Which axis to split on. A negative value means counting 
dimensions from the back. 
-        Accepted range is [-rank, rank-1] where r = rank(input).
-    Args:
-        parts: list of ints, length of each output, which can be specified 
using argument 'parts'. 
-        Otherwise, the tensor is split to equal sized parts.
+    Init a Split, Split a tensor into a list of tensors, along the specified 
+    'axis'. 
     Args:
-        num_output: once parts is none, the tensor is split to equal sized 
parts for each output.
+        x (Tensor): input tensor.
+        axis (int): which axis to split on. A negative value means
+            counting dimensions from the back. Accepted range is
+            [-r, r-1] where r = rank(input).
+        parts (list of int): length of each output, which can be specified
+            using argument 'parts'. Otherwise, the tensor is split into
+            equal-sized parts.
+        num_output (int): the number of equal-sized parts to split the
+            tensor into when parts is None.
     Returns:
-        the output CTensor.
+        the output Tensors, one for each split part.
     """
     return Split(axis, parts, num_output)(x)
 
 
 class Gather(Operation):
+    """
+    Init a Gather. Given data tensor of rank r >= 1, and indices tensor of rank
+    q, gather entries of the axis dimension of data (by default the outer-most
+    one, axis=0) indexed by indices, and concatenate them in an output tensor of rank q + (r - 1).
+    """
 
     def __init__(self, axis, indices):
         """
-        Init a Gather, Given data tensor of rank r >= 1, and indices tensor of 
rank q, gather entries of 
-        the axis dimension of data (by default outer-most one as axis=0) 
indexed by indices,
-        and concatenates them in an output tensor of rank q + (r - 1).
         Args:
-            axis: int, Which axis to slice on. A negative value means counting 
dimensions from the back. 
-            Accepted range is [-rank, rank-1] where r = rank(input).
-        Args:
-            indices: list of ints, entries of the axis dimension of data.
+            axis (int): which axis to gather on. A negative value means
+                counting dimensions from the back. Accepted range is
+                [-r, r-1] where r = rank(input).
+            indices (list of int): entries of the axis dimension of data.
         """
         super(Gather, self).__init__()
         self.axis = axis
@@ -3760,7 +5042,7 @@ class Gather(Operation):
         """
         forward of Gather
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3792,7 +5074,7 @@ class Gather(Operation):
         """
         backward of Gather
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3839,31 +5121,33 @@ class Gather(Operation):
 
 def gather(x, axis, indices):
     """
-    Init a Gather, Given data tensor of rank r >= 1, and indices tensor of 
rank q, gather entries of 
-    the axis dimension of data (by default outer-most one as axis=0) indexed 
by indices,
-    and concatenates them in an output tensor of rank q + (r - 1).
-    Args:
-        x: CTensor, input tensor.
+    Init a Gather. Given data tensor of rank r >= 1, and indices tensor of rank
+    q, gather entries of the axis dimension of data (by default the outer-most
+    one, axis=0) indexed by indices, and concatenate them in an output tensor of rank q + (r - 1).
     Args:
-        axis: int, Which axis to slice on. A negative value means counting 
dimensions from the back. 
-        Accepted range is [-rank, rank-1] where r = rank(input).
-    Args:
-        indices: list of ints, entries of the axis dimension of data.
+        x (Tensor): input tensor.
+        axis (int): which axis to gather on. A negative value means counting
+            dimensions from the back. Accepted range is [-r, r-1]
+            where r = rank(input).
+        indices (list of int): entries of the axis dimension of data.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Gather(axis, indices)(x)[0]
 
 
 class Tile(Operation):
+    """
+    Init a Tile, which constructs a tensor by tiling a given tensor. This is the same
+    as function tile in Numpy: https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
+    """
 
     def __init__(self, repeats):
         """
-        Init a Tile, Constructs a tensor by tiling a given tensor. This is the 
same as function tile in Numpy:
-        https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
         Args:
-            repeats: 1D int64 matrix of the same length as input's dimension 
number,
-            includes numbers of repeated copies along input's dimensions.
+            repeats (int or list of int): numbers of repeated copies along
+                the input's dimensions, one entry per dimension. A single int
+                is treated as a one-element list.
         """
         super(Tile, self).__init__()
         self.repeats = [repeats] if isinstance(repeats, int) else repeats
@@ -3872,7 +5156,7 @@ class Tile(Operation):
         """
         forward of Tile
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3896,7 +5180,7 @@ class Tile(Operation):
         """
         backward of Tile
         Args:
-            dy: CTensor, gradient tensor.
+            dy (CTensor): gradient tensor.
         Returns:
             the gradient tensor over input tensor.
         """
@@ -3922,33 +5206,33 @@ class Tile(Operation):
 
 def tile(x, repeats):
     """
-    Init a Tile, Constructs a tensor by tiling a given tensor. This is the 
same as function tile in Numpy:
-    https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
-    Args:
-        x: CTensor, input tensor.
+    Constructs a tensor by tiling a given tensor. This is the same as the 
+    function tile in NumPy: https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     Args:
-        repeats: 1D int64 matrix of the same length as input's dimension 
number,
-        includes numbers of repeated copies along input's dimensions.
+        x (Tensor): input tensor.
+        repeats (int or list of int): number of repeated copies along each 
+            of the input's dimensions, one entry per dimension. A single 
+            int is treated as [repeats].
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Tile(repeats)(x)[0]
 
 
 class NonZero(Operation):
+    """
+    Computes the indices of the elements that are non-zero (in row-major 
+    order by dimension), like numpy.nonzero: https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html
+    """
 
     def __init__(self):
-        """
-        Init a NonZero, Constructs a tensor by tiling a given tensor. This is 
the same as function tile in Numpy:
-        https://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
-        """
         super(NonZero, self).__init__()
 
     def forward(self, x):
         """
         forward of NonZero
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -3962,33 +5246,36 @@ class NonZero(Operation):
         """
         backward of NonZero
         Args:
-            dy: CTensor, gradient tensor.
-        Returns:
-            the gradient tensor over input tensor.
+            dy (CTensor): gradient tensor.
+        Raises:
+            AssertionError: no backward function for this operator
         """
         assert False, ('no gradient for backward function')
 
 
 def nonzero(x):
     """
-    Returns the indices of the elements that are non-zero (in row-major order 
- by dimension). 
-    NonZero behaves similar to numpy.nonzero: 
https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html
+    Returns the indices of the elements that are non-zero (in row-major 
+    order by dimension), like numpy.nonzero: https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html
     Args:
-        x: CTensor, input tensor.
+        x (Tensor): input tensor.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return NonZero()(x)[0]
 
 
 class Cast(Operation):
+    """
+    The operator casts the elements of a given input tensor to a data type 
+    specified by the 'to' argument and returns an output tensor of the same 
+    size in the converted type.
+    """
 
     def __init__(self, to):
         """
-        The operator casts the elements of a given input tensor to a data type 
specified by the 'to' argument 
-        and returns an output tensor of the same size in the converted type.
         Args:
-            to: data type, 
+            to (int): data type, float32 = 0; int = 2.
         """
         super(Cast, self).__init__()
         self.to = to
@@ -3997,7 +5284,7 @@ class Cast(Operation):
         """
         forward of Cast
         Args:
-            x: CTensor, input tensor.
+            x (CTensor): input tensor.
         Returns:
             the output CTensor.
         """
@@ -4010,39 +5297,48 @@ class Cast(Operation):
         """
         backward of Cast
         Args:f
-            dy: CTensor, gradient tensor.
-        Returns:
-            the gradient tensor over input tensor.
+            dy (CTensor): gradient tensor.
+        Raises:
+            AssertionError: no backward function for this operator
         """
         assert False, ('no gradient for backward function')
 
 
 def cast(x, to):
     """
-    The operator casts the elements of a given input tensor to a data type 
specified by the 'to' argument 
-    and returns an output tensor of the same size in the converted type.
-    Args:x: 
-        CTensor, input tensor.
+    The operator casts the elements of a given input tensor to a data type 
+    specified by the 'to' argument and returns an output tensor of the same 
+    size in the converted type.
     Args:
-        to: data type
+        x (Tensor): input tensor.
+        to (int): data type, float32 = 0; int = 2.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return Cast(to)(x)[0]
 
 
 class OneHot(Operation):
+    """
+    Produces a one-hot tensor based on inputs. 
+    """
 
     def __init__(self, axis, depth, values):
         """
-        Produces a one-hot tensor based on inputs. 
-        Args:
-            axis: Axis along which one-hot representation in added. Default: 
axis=-1. 
-            axis=-1 means that the additional dimension will be inserted as 
the innermost/last dimension in the output tensor.
         Args:
-            values: Rank 1 tensor containing exactly two elements, in the 
format [off_value, on_value], 
-            where 'on_value' is the value used for filling locations specified 
in 'indices' input tensor, 
-            and 'off_value' is the value used for filling locations other than 
those specified in 'indices' input tensor.
+            axis (int): Axis along which the one-hot representation is 
+                added. axis=-1 means that the additional dimension will be 
+                inserted as the innermost/last dimension in the output 
+                tensor.
+            depth (int): Scalar specifying the number of classes in one-hot 
+                tensor. This is also the size of the one-hot dimension 
+                (specified by 'axis' attribute) added on in the output tensor. 
+                The values in the 'indices' input tensor are expected to be in 
+                the range [-depth, depth-1].
+            values (list of float): Rank 1 tensor containing exactly two 
+                elements, in the format [off_value, on_value], where 
+                'on_value' fills the locations specified in the 'indices' 
+                input tensor and 'off_value' fills all other locations.
         """
         super(OneHot, self).__init__()
         self.axis = axis
@@ -4051,13 +5347,11 @@ class OneHot(Operation):
 
     def forward(self, indices):
         """
-        forward of OneHot
-        ! borrow from onnx
+        forward of OneHot, we borrow this function from onnx
         Args:
-            indices: Scalar specifying the number of classes in one-hot 
tensor. 
-            This is also the size of the one-hot dimension (specified by 
'axis' attribute) added on in the output tensor. 
-            The values in the 'indices' input tensor are expected to be in the 
range [-depth, depth-1]. 
-            In case 'depth' is of non-integer type, it will be casted to int64 
before use.
+            indices (CTensor): input tensor of indices. The values in the 
+                'indices' input tensor are expected to be in the range 
+                [-depth, depth-1].
         Returns:
             the output CTensor.
         """
@@ -4080,10 +5374,10 @@ class OneHot(Operation):
     def backward(self, dy):
         """
         backward of OneHot
-        Args:f
-            dy: CTensor, gradient tensor.
-        Returns:
-            the gradient tensor over input tensor.
+        Args:
+            dy (CTensor): gradient tensor.
+        Raises:
+            AssertionError: no backward function for this operator
         """
         assert False, ('no gradient for backward function')
 
@@ -4092,18 +5386,23 @@ def onehot(axis, indices, depth, values):
     """
     Produces a one-hot tensor based on inputs. 
     Args:
-        axis: Axis along which one-hot representation in added. Default: 
axis=-1. 
-        axis=-1 means that the additional dimension will be inserted as the 
innermost/last dimension in the output tensor.
-    Args:
-        indices: Scalar specifying the number of classes in one-hot tensor. 
-        This is also the size of the one-hot dimension (specified by 'axis' 
attribute) added on in the output tensor. 
-        The values in the 'indices' input tensor are expected to be in the 
range [-depth, depth-1]. 
-        In case 'depth' is of non-integer type, it will be casted to int64 
before use.
-    Args:
-        values: Rank 1 tensor containing exactly two elements, in the format 
[off_value, on_value], 
-        where 'on_value' is the value used for filling locations specified in 
'indices' input tensor, 
-        and 'off_value' is the value used for filling locations other than 
those specified in 'indices' input tensor.
+        axis (int): Axis along which the one-hot representation is 
+            added. axis=-1 means that the additional dimension will be 
+            inserted as the innermost/last dimension in the output 
+            tensor.
+        indices (Tensor): input tensor of indices. The values in the 
+            'indices' input tensor are expected to be in the range 
+            [-depth, depth-1].
+        depth (int): Scalar specifying the number of classes in one-hot 
+            tensor. This is also the size of the one-hot dimension 
+            (specified by 'axis' attribute) added on in the output tensor. 
+            The values in the 'indices' input tensor are expected to be in 
+            the range [-depth, depth-1].
+        values (list of float): Rank 1 tensor containing exactly two 
+            elements, in the format [off_value, on_value], where 
+            'on_value' fills the locations specified in the 'indices' 
+            input tensor and 'off_value' fills all other locations.
     Returns:
-        the output CTensor.
+        the output Tensor.
     """
     return OneHot(axis, depth, values)(indices)[0]

Reply via email to