ymwangg commented on a change in pull request #7441: URL: https://github.com/apache/tvm/pull/7441#discussion_r583124937
########## File path: python/tvm/topi/cuda/unique.py ########## @@ -0,0 +1,394 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Unique operator""" +import tvm +from tvm import te, tir +from ...te import hybrid +from .scan import cumsum +from .sort import sort, argsort +from ..utils import ceil_div + + +def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): + """Low level IR to calculate adjacent difference in a 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + output: Buffer + A buffer to store adjacent difference, of the same shape as data. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). + + binop: function, optional + A binary associative op to use for calculating adjacent difference. The function takes two + TIR expressions and produces a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. 
+ """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + output_ptr = ib.buffer_ptr(output) + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + with ib.if_scope(tid == 0): + output_ptr[tid] = 0 + with ib.else_scope(): + output_ptr[tid] = tir.Cast(output.dtype, binop(data_ptr[tid], data_ptr[tid - 1])) + return ib.get() + + +def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): + """Function calculate adjacent difference in an 1-D array. + + Parameters + ---------- + data : tvm.te.Tensor + Input 1-D tensor. + + output_dtype : str + The output tensor data type. + + binop: function, optional + A binary associative op to use for calculating difference. The function takes two + TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + compute the adjacent difference. + + Returns + ------- + output : tvm.te.Tensor + 1-D tensor storing the adjacent difference of the input tensor. The adjacent difference + is defined as: output[0] = 0, output[i] = binop(data[i], data[i-1]) + where i > 0 and i < len(data). 
+ """ + data_buf = tir.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8) + output_buf = tir.decl_buffer(data.shape, out_dtype, "output_buf", data_alignment=8) + return te.extern( + [data.shape], + [data], + lambda ins, outs: _calc_adjacent_diff_ir(ins[0], outs[0], binop=binop), + dtype=[out_dtype], + in_buffers=[data_buf], + out_buffers=[output_buf], + name="_calc_adjacent_diff", + tag="_calc_adjacent_diff_gpu", + ) + + [email protected] +def _calc_num_unique(inc_scan): + """Helper function to get the number of unique elements fron inc_scan tensor""" + output = output_tensor((1,), "int32") + for i in bind("threadIdx.x", 1): + output[i] = inc_scan[inc_scan.shape[0] - 1] + int32(1) + return output + + +def _calc_unique_ir( + data, argsorted_indices, inc_scan, index_converter, unique_elements, indices, counts +): + """Low level IR to calculate unique elements, inverse indices, and counts (optional) of + unique elements of 1-D array. + + Parameters + ---------- + data : Buffer + Input 1-D Buffer. + + argsorted_indices : Buffer + A buffer that stores the argsorted indices of the input data. + + inc_scan : Buffer + A buffer that stores the inclusive scan of the binary tir.NE adjacent difference + of the sorted data. + + index_converter (optional) : Buffer + An optional index converter that transforms the unique element index + such that new_idx = index_converter[old_idx]. + + unique_elements : Buffer + A buffer that stores the unique elements. + + indices : Buffer + A buffer that stores the the index of each input data element in the unique element array. + + counts (optional) : Buffer + A buffer that stores the count of each unique element. 
+ """ + ib = tir.ir_builder.create() + data_ptr = ib.buffer_ptr(data) + argsorted_indices_ptr = ib.buffer_ptr(argsorted_indices) + inc_scan_ptr = ib.buffer_ptr(inc_scan) + unique_elements_ptr = ib.buffer_ptr(unique_elements) + indices_ptr = ib.buffer_ptr(indices) + + index_converter_ptr = None + if isinstance(index_converter, tir.Buffer): + index_converter_ptr = ib.buffer_ptr(index_converter) + + if isinstance(counts, tir.Buffer): + counts_ptr = ib.buffer_ptr(counts) + arange_ptr = ib.allocate(counts_ptr.dtype, counts.shape, name="arange_buf", scope="global") + + batch_size = data.shape[0] + max_threads = tir.min(batch_size, tvm.target.Target.current(allow_none=False).max_num_threads) + + # calculate unique elements and inverse indices + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + data_idx = argsorted_indices_ptr[tid] + unique_idx = ( + inc_scan_ptr[tid] + if not index_converter_ptr + else index_converter_ptr[inc_scan_ptr[tid]] + ) + indices_ptr[data_idx] = unique_idx + with ib.if_scope(tid == 0): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + unique_elements_ptr[unique_idx] = data_ptr[data_idx] + + # if need to return counts + if isinstance(counts, tir.Buffer): + num_unique = inc_scan_ptr[inc_scan.shape[0] - 1] + 1 + num_elements = data.shape[0] + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size): + 
with ib.if_scope(tid == 0): + arange_ptr[num_unique - 1] = num_elements + with ib.else_scope(): + with ib.if_scope(inc_scan_ptr[tid] != inc_scan_ptr[tid - 1]): + arange_ptr[inc_scan_ptr[tid] - 1] = tid + with ib.new_scope(): + nthread_tx = max_threads + nthread_bx = ceil_div(batch_size, max_threads) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + with ib.if_scope(tid < num_unique): + unique_idx = tid if not index_converter_ptr else index_converter_ptr[tid] + with ib.if_scope(tid == 0): + counts_ptr[unique_idx] = arange_ptr[tid] + with ib.else_scope(): Review comment: I checked the CUDA kernels and if they are executed in the order they appear in the IR (I think they are), then I don't see any potential issues. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
