tkonolige commented on a change in pull request #7562:
URL: https://github.com/apache/tvm/pull/7562#discussion_r585724228



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1450,6 +1450,61 @@ def sparse_reshape(sparse_indices, prev_shape, 
new_shape):
     return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, 
new_shape), 2)
 
 
+def segment_sum(data, indices, num_segments=None):
+    """
+    Computes the sum along segments of a tensor. This op is much better 
understood with
+    visualization articulated in the following links and examples at the end 
of this docstring.
+
+    
https://www.tensorflow.org/api_docs/python/tf/raw_ops/UnsortedSegmentSum?hl=fr
+    
https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops
+
+    Parameters
+    ----------
+    data : relay.Expr

Review comment:
       Is this tensor 1-D? If not, is the reduction performed over all dimensions or just one?

##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1450,6 +1450,61 @@ def sparse_reshape(sparse_indices, prev_shape, 
new_shape):
     return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, 
new_shape), 2)
 
 
+def segment_sum(data, indices, num_segments=None):
+    """
+    Computes the sum along segments of a tensor. This op is much better 
understood with
+    visualization articulated in the following links and examples at the end 
of this docstring.
+
+    
https://www.tensorflow.org/api_docs/python/tf/raw_ops/UnsortedSegmentSum?hl=fr

Review comment:
       Did you mean to link to the French docs?

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -2080,6 +2080,140 @@ def test_forward_sparse_reshape(
     _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, 
new_shape_np, use_dyn)
 
 
+#######################################################################
+# Sparse SegmentSum
+# ------------
+
+
+def _test_sparse_segment_sum(data_np, indices_np, segment_ids_np, 
num_segments, use_dyn=False):
+    with tf.Graph().as_default():
+        if use_dyn:
+            data = tf.placeholder(
+                shape=[None for _ in data_np.shape], dtype=data_np.dtype, 
name="data"
+            )
+            indices = tf.placeholder(shape=[None], dtype=indices_np.dtype, 
name="indices")
+            segment_ids = tf.placeholder(
+                shape=(None), dtype=segment_ids_np.dtype, name="segment_ids"
+            )
+        else:
+            data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, 
name="data")
+            indices = tf.placeholder(shape=indices_np.shape, 
dtype=indices_np.dtype, name="indices")
+            segment_ids = tf.placeholder(
+                shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, 
name="segment_ids"
+            )
+
+        _ = tf.sparse.segment_sum(
+            data, indices, segment_ids, num_segments=num_segments, 
name="sparse_segment_sum"
+        )
+        compare_tf_with_tvm(
+            [data_np, indices_np, segment_ids_np],
+            [data.name, indices.name, segment_ids.name],
+            ["sparse_segment_sum:0"],
+            mode="vm",
+        )
+
+
[email protected](
+    "data_np, indices_np, segment_ids_np, num_segments",

Review comment:
       Could you add one test case that uses `int64` inputs for the indices array?

##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -1515,6 +1516,95 @@ def verify_sparse_reshape(
     )
 
 
[email protected]_gpu
[email protected](
+    "data_np, indices_np, num_segments",
+    [
+        (
+            np.array([5, 1, 7, 2, 3, 4], dtype=np.float32),
+            np.array([0, 0, 1, 1, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], 
dtype=np.float64),
+            np.array([0, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.random.random((6, 4, 5)),
+            np.array([2, 0, 1, 0, 3, 2], dtype=np.int32),
+            None,
+        ),
+        (
+            np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32),
+            np.array([0, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.random.random((9, 4, 5, 7)),
+            np.array([5, 0, 1, 0, 3, 6, 8, 7, 7], dtype=np.int32),
+            9,
+        ),
+    ],
+)
[email protected]("use_dyn", [True, False])
+def test_segment_sum(data_np, indices_np, num_segments, use_dyn):
+    def ref_segment_sum(
+        data: np.ndarray,
+        indices: np.ndarray,
+        num_segments: Optional[int] = None,
+    ):
+        """
+        This function calculates the expected output of sparseshape operator 
given the inputs.

Review comment:
       ```suggestion
           This function calculates the expected output of segment_sum operator 
given the inputs.
   ```

##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1450,6 +1450,61 @@ def sparse_reshape(sparse_indices, prev_shape, 
new_shape):
     return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, 
new_shape), 2)
 
 
+def segment_sum(data, indices, num_segments=None):
+    """
+    Computes the sum along segments of a tensor. This op is much better 
understood with
+    visualization articulated in the following links and examples at the end 
of this docstring.
+
+    
https://www.tensorflow.org/api_docs/python/tf/raw_ops/UnsortedSegmentSum?hl=fr
+    
https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops
+
+    Parameters
+    ----------
+    data : relay.Expr
+        Input floating point data
+    indices : relay.Expr
+        A 1-D tensor containing the indices of the rows to calculate the 
output sum upon.
+        This tensor doesn't need to be sorted
+    num_segments : Optional[int]
+        An integer describing the shape of the zeroth dimension. If 
unspecified, its calculated
+        equivalent to the number of unique indices
+    Returns
+    -------
+    result: relay.Expr
+        Output tensor.
+    Examples
+    --------
+    .. code-block:: python
+        data = [[1, 2, 3, 4],
+                [4, -3, 2, -1],
+                [5, 6, 7, 8]]
+        indices = [0, 0, 1]
+        result = segment_sum(data, indices)
+        result = [[5, -1, 5, 3],[5, 6, 7, 8]]
+
+        data = [[1, 2, 3, 4],
+                [4, -3, 2, -1],
+                [5, 6, 7, 8]]
+        indices = [2, 0, 0]
+        num_segments = 3
+        result = segment_sum(data, indices, num_segments)
+        result = [[5, 6, 7, 8],[0, 0, 0, 0], [5, -1, 5, 3]]
+    """
+
+    if num_segments:
+        num_unique = const([num_segments])
+    else:
+        _, _, num_unique = unique(reshape(indices, -1))
+    data_offrow_shape = strided_slice(_make.shape_of(data, "int64"), [1], 
[-1], slice_mode="size")

Review comment:
       Should the `int64` here match the input datatype?

##########
File path: python/tvm/relay/op/transform.py
##########
@@ -1450,6 +1450,61 @@ def sparse_reshape(sparse_indices, prev_shape, 
new_shape):
     return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, 
new_shape), 2)
 
 
+def segment_sum(data, indices, num_segments=None):
+    """
+    Computes the sum along segments of a tensor. This op is much better 
understood with
+    visualization articulated in the following links and examples at the end 
of this docstring.

Review comment:
       I think it would be best to include a mathematical description of what this op computes.

##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -1515,6 +1516,95 @@ def verify_sparse_reshape(
     )
 
 
[email protected]_gpu
[email protected](
+    "data_np, indices_np, num_segments",
+    [
+        (
+            np.array([5, 1, 7, 2, 3, 4], dtype=np.float32),
+            np.array([0, 0, 1, 1, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], 
dtype=np.float64),
+            np.array([0, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.random.random((6, 4, 5)),
+            np.array([2, 0, 1, 0, 3, 2], dtype=np.int32),
+            None,
+        ),
+        (
+            np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32),
+            np.array([0, 0, 1], dtype=np.int32),
+            None,
+        ),
+        (
+            np.random.random((9, 4, 5, 7)),
+            np.array([5, 0, 1, 0, 3, 6, 8, 7, 7], dtype=np.int32),
+            9,
+        ),
+    ],
+)
[email protected]("use_dyn", [True, False])
+def test_segment_sum(data_np, indices_np, num_segments, use_dyn):
+    def ref_segment_sum(
+        data: np.ndarray,
+        indices: np.ndarray,
+        num_segments: Optional[int] = None,
+    ):
+        """
+        This function calculates the expected output of sparseshape operator 
given the inputs.
+        """
+        if not num_segments:
+            num_segments = np.unique(indices).shape[0]
+
+        result = np.zeros((num_segments,) + data.shape[1:], data.dtype)
+        for i, index in enumerate(indices):
+            result[index] += data[i]
+        return result
+
+    def verify_segment_sum(
+        data_np: np.ndarray, indices_np: np.ndarray, num_segments: 
Optional[int]
+    ):
+        """
+        This function verifies the relay output of sparse_reshape with its 
expected output.

Review comment:
       ```suggestion
           This function verifies the relay output of segment_sum with its 
expected output.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to