kevinthesun commented on a change in pull request #6868:
URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838



##########
File path: python/tvm/topi/cuda/argwhere.py
##########
@@ -0,0 +1,621 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments, invalid-name
+"""Argwhere operator"""
+
+import logging
+
+import tvm
+from tvm import te
+from tvm._ffi import get_global_func
+from .injective import schedule_injective_from_existing
+from .nms import atomic_add
+from .sort import topk, topk_thrust, argsort, argsort_thrust
+from .. import tag
+from ..transform import strided_slice, adv_index, squeeze
+
+logger = logging.getLogger("topi")
+
+
def _get_sort_func(mode=0):
    """Return the sort function used by argwhere.

    Parameters
    ----------
    mode : int
        0 selects a topk function; any other value selects an argsort
        function.

    Returns
    -------
    ret : callable
        The thrust-backed implementation when the thrust contrib library is
        available, otherwise the generic CUDA implementation.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        ret = topk_thrust if mode == 0 else argsort_thrust
    else:
        # BUG FIX: Logger.warn is deprecated; the supported API is warning().
        logger.warning(
            "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
            " when compiling argwhere for cuda target. Otherwise, it can result in"
            " significant performance degradation or incorrect result"
        )
        ret = topk if mode == 0 else argsort

    return ret
+
+
def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    # Number of elements in the 1-D condition tensor.
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # Single-element counter shared by all threads ("global" scope so every
    # thread sees the same cell); atomically incremented to claim output slots.
    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    # Per-thread scratch holding the slot returned by atomic_add.
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Each thread scans a contiguous chunk of ceil(a0 / nthread_tx) elements.
    len_inner_for = a0 // nthread_tx + 1
    valid_index[0] = 0

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # Claim the next free output slot.  NOTE(review): this relies
                # on atomic_add returning the pre-increment value — verify
                # against the atomic_add helper in nms.py.  Output order is
                # nondeterministic; the caller sorts the result afterwards.
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()
+
+
def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values.

    Returns
    -------
    out : tvm.te.Tensor
        Indices of the non-zero elements, sorted in ascending order.
    """
    # NOTE: the original docstring mislabeled the parameters (it described
    # `condition` as the output shape and named the second parameter `out`)
    # and claimed a Stmt return; fixed above.
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Zero or one result: nothing to sort.
    if out.shape[0] <= 1:
        return out

    # The IR routine fills the output in nondeterministic (atomic) order, so
    # sort the indices.  BUG FIX: is_ascend expects a bool; the string "True"
    # was only accidentally truthy.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out
+
+
+def argwhere_2d_ir(condition, out):
+    """Low level IR for argwhere 2D
+
+    Parameters
+    ----------
+    condition : Buffer
+        The condition buffer.
+
+    out : Buffer
+        The output buffer.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    ib = tvm.tir.ir_builder.create()
+    a0 = condition.shape[0]
+    a1 = condition.shape[1]
+
+    condition = ib.buffer_ptr(condition)
+    out = ib.buffer_ptr(out)
+
+    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="local")
+    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
+    one_count = tvm.tir.const(1, dtype="int32")
+
+    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    # Limit threads to a single block to make sure atomic_add works normally.

Review comment:
       The observation is that if the input data size is large (> 300 * 300, for
example) and we don't limit the number of blocks, the output of the IR
routine is incorrect. I haven't dug deeper into it at this time.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to