masahi commented on a change in pull request #7441: URL: https://github.com/apache/tvm/pull/7441#discussion_r581459643
########## File path: python/tvm/topi/cuda/unique.py ########## @@ -0,0 +1,384 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-else-return +"""Unique operator""" +from tvm import te, tir +import tvm + +from ...te import hybrid +from .scan import cumsum +from .sort import sort, argsort +from ..utils import ceil_div + + +def _calc_adjacent_diff_ir(data, adjacent_diff): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + adjacent_diff_ptr = ib.buffer_ptr(adjacent_diff) + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + adjacent_diff_ptr[tid] = 0 + with ib.else_scope(): + with ib.if_scope(data_ptr[tid] != data_ptr[tid - 1]): + adjacent_diff_ptr[tid] = 1 + with ib.else_scope(): + adjacent_diff_ptr[tid] = 0 + return ib.get() + + +@hybrid.script
+def _calc_num_unique(data): + output = output_tensor((1,), "int32") + for i in bind("threadIdx.x", 1): + output[i] = data[data.shape[0] - 1] + int32(1) + return output + + +def _calc_unique_sorted_ir(data, argsorted_indices, inc_scan, unique_elements, indices): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + indices_ptr[argsorted_indices_ptr[tid]] = inc_scan_ptr[tid] + with ib.if_scope(tid == 0): + unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[inc_scan_ptr[tid]] = data_ptr[argsorted_indices_ptr[tid]] + return ib.get() + + +def _calc_counts_sorted_ir(inc_scan, counts): + ib = tvm.tir.ir_builder.create() + inc_scan_ptr = ib.buffer_ptr(inc_scan) + counts_ptr = ib.buffer_ptr(counts) + + batch_size = inc_scan.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + 
counts_ptr[tid] = 0 + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + atomic_add_return = ib.allocate(counts.dtype, (1,), name="atomic_add_return", scope="local") + with ib.if_scope(tid < batch_size): + index = inc_scan_ptr[tid] + atomic_add_return[0] = tvm.tir.call_intrin( + counts.dtype, + "tir.atomic_add", + tvm.tir.call_intrin("handle", "tir.address_of", counts_ptr[index]), + 1, + ) + return ib.get() + + +def _calc_first_occurence_ir(argsorted_indices, inc_scan, first_occurence): + ib = tvm.tir.ir_builder.create() + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + first_occurence_ptr = ib.buffer_ptr(first_occurence) + batch_size = argsorted_indices.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + first_occurence_ptr[tid] = batch_size + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + 
first_occurence_ptr[inc_scan_ptr[tid]] = argsorted_indices_ptr[tid] + return ib.get() + + +def _calc_unique_unsorted_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices +): + ib = tvm.tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + index_converter_ptr = ib.buffer_ptr(index_converter) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + indices_ptr[argsorted_indices_ptr[tid]] = index_converter_ptr[inc_scan_ptr[tid]] + with ib.if_scope(tid == 0): + unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ + argsorted_indices_ptr[tid] + ] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[index_converter_ptr[inc_scan_ptr[tid]]] = data_ptr[ + argsorted_indices_ptr[tid] + ] + return ib.get() + + +def _calc_counts_unsorted_ir(inc_scan, index_converter, counts): Review comment: This looks similar to `_calc_counts_sorted_ir`, maybe you can do some trick around `index_converter` to share the implementation? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
