[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r526417838
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r519055922
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.
[GitHub] [incubator-tvm] kevinthesun commented on a change in pull request #6868: [WIP][TOPI][OP] cuda for argwhere
# kevinthesun commented on a change in pull request #6868: URL: https://github.com/apache/incubator-tvm/pull/6868#discussion_r519055922
# ## File path: python/tvm/topi/cuda/argwhere.py
# ## @@ -0,0 +1,621 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments, invalid-name
"""Argwhere operator"""

import logging

import tvm
from tvm import te
from tvm._ffi import get_global_func
from .injective import schedule_injective_from_existing
from .nms import atomic_add
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze

logger = logging.getLogger("topi")


def _get_sort_func(mode=0):
    """Get sort function for argwhere.

    Parameters
    ----------
    mode : int
        0 selects the topk-style sort; any other value selects argsort.

    Returns
    -------
    func : callable
        The thrust-backed variant when the ``tvm.contrib.thrust.sort`` packed
        function is registered, otherwise the native TOPI CUDA variant.
    """
    if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
        return topk_thrust if mode == 0 else argsort_thrust

    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning(
        "It's highly recommended to enable thrust library with set(USE_THRUST ON)"
        " when compiling argwhere for cuda target. Otherwise, it can result in"
        " significant performance degradation or incorrect result"
    )
    return topk if mode == 0 else argsort


def argwhere_1d_ir(condition, out):
    """Low level IR for argwhere 1D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    valid_index = ib.allocate("int32", (1,), name="valid_index", scope="global")
    tmp = ib.allocate("int32", (1,), name="tmp", scope="local")
    one_count = tvm.tir.const(1, dtype="int32")

    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Limit threads to a single block to make sure atomic_add works normally.
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    # Ceiling division: each thread handles a contiguous run of this many
    # elements, enough for nthread_tx threads to cover all a0 elements.
    len_inner_for = (a0 + nthread_tx - 1) // nthread_tx

    # Reset the output counter exactly once and make the reset visible to
    # every thread of the block before any thread increments it. If every
    # thread stored 0 without a barrier, a late thread could wipe increments
    # already made by other threads.
    with ib.if_scope(tx == 0):
        valid_index[0] = 0
    ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))

    with ib.for_range(0, len_inner_for, name="i") as i:
        idx = tx * len_inner_for + i
        with ib.if_scope(idx < a0):
            with ib.if_scope(condition[idx] != 0):
                # The value returned by atomic_add is used as this hit's output
                # slot (presumably the pre-increment counter, as with CUDA
                # atomicAdd -- confirm against .nms.atomic_add).
                tmp[0] = atomic_add(
                    tvm.tir.call_intrin("handle", "tir.address_of", valid_index[0]),
                    one_count,
                )
                out[tmp[0]] = idx

    return ib.get()


def argwhere_1d(output_shape, condition):
    """Compute for argwhere 1D

    Parameters
    ----------
    output_shape : list of int or tvm.tir.Any
        The output shape.

    condition : tvm.te.Tensor
        Tensor with boolean values; elements != 0 are selected.

    Returns
    -------
    out : tvm.te.Tensor
        int32 indices of the selected elements, sorted ascending when the
        output may hold more than one entry.
    """
    condition_buf = tvm.tir.decl_buffer(
        condition.shape, condition.dtype, "data_buf", data_alignment=8
    )
    out_buf = tvm.tir.decl_buffer(output_shape, "int32", "out_buf", data_alignment=8)

    out = te.extern(
        [output_shape],
        [condition],
        lambda ins, outs: argwhere_1d_ir(ins[0], outs[0]),
        dtype=["int32"],
        in_buffers=[condition_buf],
        out_buffers=[out_buf],
        name="argwhere_1d",
        tag="argwhere1d_gpu",
    )

    # Only a statically known extent can be compared in Python; comparing a
    # symbolic extent (e.g. tvm.tir.Any) yields an unresolved TIR expression
    # that cannot be used as an `if` condition.
    num_rows = out.shape[0]
    if isinstance(num_rows, (int, tvm.tir.IntImm)) and int(num_rows) <= 1:
        return out

    # Threads claim output slots via atomic_add, so the order is not
    # deterministic; sort the indices ascending.
    sorted_out = _get_sort_func()(
        out, k=0, axis=0, ret_type="values", is_ascend=True, dtype="int32"
    )

    return sorted_out


def argwhere_2d_ir(condition, out):
    """Low level IR for argwhere 2D

    Parameters
    ----------
    condition : Buffer
        The condition buffer.

    out : Buffer
        The output buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    ib = tvm.tir.ir_builder.create()
    a0 = condition.shape[0]
    a1 = condition.shape[1]

    condition = ib.buffer_ptr(condition)
    out = ib.buffer_ptr(out)

    # NOTE(review): the remainder of argwhere_2d_ir is cut off in this excerpt.