mbrookhart commented on a change in pull request #7477:
URL: https://github.com/apache/tvm/pull/7477#discussion_r583167241



##########
File path: tests/python/relay/test_op_level3.py
##########
@@ -1311,6 +1311,232 @@ def verify_sparse_to_dense(sparse_indices, 
sparse_values, default_value, output_
     # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 
3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
 
 
[email protected]_gpu
[email protected](
+    "sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np",
+    [
+        (
+            np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]], 
dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6], dtype=np.int64),
+            np.array([9, -1], dtype=np.int64),
+        ),
+        (
+            np.array(
+                [[0, 0, 0, 0], [0, 0, 1, 2], [0, 1, 0, 3], [1, 0, 0, 4], [1, 
2, 3, 6]],
+                dtype=np.int64,
+            ),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6, 7], dtype=np.int64),
+            np.array([9, -1, 7], dtype=np.int64),
+        ),
+        (
+            np.array(
+                [
+                    [0, 0, 0, 0, 0],
+                    [0, 0, 1, 2, 3],
+                    [0, 1, 0, 3, 5],
+                    [1, 0, 0, 4, 6],
+                    [1, 2, 3, 6, 8],
+                ],
+                dtype=np.int64,
+            ),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([2, 3, 6, 7, 9], dtype=np.int64),
+            np.array([9, -1, 7], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([9, 4], dtype=np.int64),
+            np.array([2, -1, 6], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 0], [0, 1], [3, 4], [4, 3], [7, 3]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([9, 4], dtype=np.int64),
+            np.array([-1], dtype=np.int64),
+        ),
+        (
+            np.array([[0], [5], [10], [20], [24]], dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([25], dtype=np.int64),
+            np.array([5, 5], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], 
dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], 
dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([500, -1], dtype=np.int64),
+        ),
+        (
+            np.array([[0, 100], [200, 100], [300, 400], [50, 20], [400, 50]], 
dtype=np.int64),
+            np.array([7, 5, 6, 3, 9], dtype=np.int64),
+            np.array([500, 20], dtype=np.int64),
+            np.array([250, 40], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+            np.array([2, -1], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+            np.array([2, 2], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 2), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([3, 6], dtype=np.int64),
+            np.array([-1, 2], dtype=np.int64),
+        ),
+    ],
+)
[email protected]("dtype", [np.int32, np.int64])
[email protected]("use_dyn", [True, False])
+def test_sparse_reshape(
+    sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, dtype, 
use_dyn
+):
+    def ref_sparse_reshape(

Review comment:
       It's probably not worth duplicating all of these test cases between the 
TF tests and the topi tests. Maybe run only a subset here?

##########
File path: python/tvm/topi/sparse_reshape.py
##########
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Sparse_Reshape operator"""
+from ..tir import decl_buffer, ir_builder, Cast
+from ..te import extern, div, floordiv, floormod
+
+
+def sparse_reshape(
+    sparse_indices,
+    prev_shape,
+    new_shape,
+    new_sparse_indices_shape,
+    new_shape_shape,
+):
+    """
+    Reshape a Sparse Tensor
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 2-D tensor[N, n_dim] of integers containing location of sparse 
values, where N is the
+        number of sparse values and n_dim is the number of dimensions of the 
dense_shape
+    prev_shape : relay.Expr
+        A 1-D tensor containing the previous shape of the dense tensor
+    new_shape : relay.Expr
+        A 1-D tensor containing the new shape of the dense tensor
+    Returns
+    -------
+    result: relay.Expr
+        Output tensor.
+    Examples
+    --------
+    .. code-block:: python
+        sparse_indices = [[0, 0, 0],
+                            [0, 0, 1],
+                            [0, 1, 0],
+                            [1, 0, 0],
+                            [1, 2, 3]]
+        prev_shape = [2, 3, 4]
+        new_shape = [9, -1]
+        new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
+                            prev_shape,
+                            new_shape)
+        new_sparse_indices = [[0, 0],
+                              [0, 1],
+                              [1, 2],
+                              [4, 2],
+                              [8, 1]]
+        new_shape = [9, 4]
+    """
+
+    def gen_ir(

Review comment:
       I'm wondering whether it would help to parallelize some of the for loops 
on CPU, as we already do on GPU. On CPU you can request a parallel loop via 
ir_builder.for_range. I'm not sure the performance improvement would be large 
enough to justify it, though.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to