eric-haibin-lin commented on a change in pull request #10208: [MXNET-117] 
Sparse operator broadcast_mul/div(csr, dense) = csr
URL: https://github.com/apache/incubator-mxnet/pull/10208#discussion_r178997359
 
 

 ##########
 File path: python/mxnet/ndarray/sparse.py
 ##########
 @@ -1159,6 +1186,322 @@ def _ndarray_cls(handle, writable=True, 
stype=_STORAGE_TYPE_UNDEFINED):
 _set_ndarray_class(_ndarray_cls)
 
 
+def add(lhs, rhs):
+    """Returns element-wise sum of the input arrays with broadcasting.
+
+    Equivalent to ``lhs + rhs``, ``mx.nd.broadcast_add(lhs, rhs)`` and
+    ``mx.nd.broadcast_plus(lhs, rhs)`` when shapes of lhs and rhs do not
+    match. If lhs.shape == rhs.shape, this is equivalent to
+    ``mx.nd.elemwise_add(lhs, rhs)``
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or 
one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be added.
+    rhs : scalar or array
+        Second array to be added.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise sum of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(2).reshape((2,1))
+    >>> z = mx.nd.arange(2).reshape((1,2))
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.],
+           [ 1.]], dtype=float32)
+    >>> z.asnumpy()
+    array([[ 0.,  1.]], dtype=float32)
+    >>> (x+2).asnumpy()
+    array([[ 3.,  3.,  3.],
+           [ 3.,  3.,  3.]], dtype=float32)
+    >>> (x+y).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    >>> mx.nd.add(x,y).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    >>> (z + y).asnumpy()
+    array([[ 0.,  1.],
+           [ 1.,  2.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == 
rhs.shape:
+        return _ufunc_helper(
+            lhs,
+            rhs,
+            op.elemwise_add,
+            operator.add,
+            _internal._plus_scalar,
+            None)
+
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_add,
+        operator.add,
+        _internal._plus_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
+
+
+def subtract(lhs, rhs):
+    """Returns element-wise difference of the input arrays with broadcasting.
+
+    Equivalent to ``lhs - rhs``, ``mx.nd.broadcast_sub(lhs, rhs)`` and
+    ``mx.nd.broadcast_minus(lhs, rhs)`` when shapes of lhs and rhs do not
+    match. If lhs.shape == rhs.shape, this is equivalent to
+    ``mx.nd.elemwise_sub(lhs, rhs)``
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or 
one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be subtracted.
+    rhs : scalar or array
+        Second array to be subtracted.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise difference of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(2).reshape((2,1))
 
 Review comment:
   csr - dense will fall back.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to