AndrewZhaoLuo commented on a change in pull request #9184:
URL: https://github.com/apache/tvm/pull/9184#discussion_r726673423

########## File path: python/tvm/topi/searchsorted.py ##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""searchsorted operator"""
+from . import utils
+from . import te
+from ..tir import ir_builder
+from .math import cast
+
+
+def binary_search(
+    ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype

Review comment:
might want to provide a brief docstring on these variables since they are not immediately obvious to me and this uses mutation which is kind of odd

########## File path: include/tvm/relay/attrs/algorithm.h ##########
@@ -76,6 +76,19 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
   }
 };

+struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
+  std::string side;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
+    TVM_ATTR_FIELD(side).set_default("left").describe(
+        "Controls which index is returned if a value lands exactly on one of sorted values.");

Review comment:
I think it would be nice to give more detail, just copy the docstrings from the other frameworks. E.g.
```
If ‘left’, the index of the first suitable location found is given.
If ‘right’, return the last such index. If there is no suitable index, return either 0 or N (where N is the length of a).
```

########## File path: include/tvm/relay/attrs/algorithm.h ##########
@@ -76,6 +76,19 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
   }
 };

+struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
+  std::string side;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
+    TVM_ATTR_FIELD(side).set_default("left").describe(
+        "Controls which index is returned if a value lands exactly on one of sorted values.");
+    TVM_ATTR_FIELD(dtype)
+        .set_default(DataType::Int(32))

Review comment:
Hmmm, just curious if there is any convention on the dtype of indices. There is a lot of index code (e.g. with dyn gather) that I believe has all the indices in Int(64), so Int(64) might be a better default. The other attributes in this file have `NullValue<DataType>()` as the default value, which is interesting.

########## File path: tests/python/relay/test_op_level6.py ##########
@@ -149,5 +150,28 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"):
                 verify_topk(k, axis, ret_type, False, "int64", "float16")


+@tvm.testing.uses_gpu
+def test_searchsorted():
+    def verify_searchsorted(side, dtype):
+        shape = (10, 20, 100)

Review comment:
can we make this smaller, e.g. (8, 9, 10)

########## File path: python/tvm/relay/op/algorithm.py ##########
@@ -115,3 +115,36 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
     if ret_type == "both":
         return TupleWrapper(out, 2)
     return out
+
+
+def searchsorted(sorted_sequence, values, side="left", dtype="int32"):
+    """Find indices where elements should be inserted to maintain order.
+    If `sorted_sequence` is N-dimensional, the innermost dimension of
+    `values` are searched in the corresponding dimension of `sorted_sequence`.
+
+    Parameters
+    ----------
+    sorted_sequence : relay.Expr
+        N-D or 1-D Tensor, containing monotonically increasing sequence
+        on the innermost dimension.
+
+    values : relay.Expr
+        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
+        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
+        and `values` must be the same, and outer N-1 axes must have the same size.
+
+    side : string, optional
+        It can be `left` or `right`. If `left`, gets the lower bound index for each value
+        in `values` on the corresponding innermost dimension of the `sorted_sequence`.
+        If `right`, gets the upper bound index instead.
+
+    dtype : string, optional
+        The data type of the output indices.
+
+    Returns
+    -------
+    indices : relay.Expr
+        Tensor with same shape as values, representing the indices of
+        elements of `values` if they are inserted in `sorted_sequence`.
+    """

Review comment:
nit: check if side is valid

########## File path: include/tvm/relay/attrs/algorithm.h ##########
@@ -76,6 +76,19 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
   }
 };

+struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
+  std::string side;

Review comment:
Can we make this a bool? Since it's either left or right. In general throughout the code (except maybe at the frontend level) can we make this bool?

########## File path: python/tvm/topi/searchsorted.py ##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""searchsorted operator"""
+from . import utils
+from . import te
+from ..tir import ir_builder
+from .math import cast
+
+
+def binary_search(
+    ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
+):
+    """Common IR generator for CPU and GPU searchsorted."""
+    lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
+    hi = ib.allocate(out_dtype, (1,), name="hi", scope="local")
+
+    v = values[index]
+    lo[0] = cast(0, out_dtype)
+    hi[0] = cast(search_range, out_dtype)
+
+    # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu
+    def condition(current_val, target_val):
+        if side == "left":
+            return current_val < target_val
+        return current_val <= target_val
+
+    with ib.while_loop(lo[0] < hi[0]):
+        mid = lo[0] + (hi[0] - lo[0] >> 1)
+        with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], v)):
+            lo[0] = mid + 1
+        with ib.else_scope():
+            hi[0] = mid
+
+    out_indices[index] = lo[0]

Review comment:
can we just return lo[0] and set out_indices[index] = ... below? Might be more reusable.

########## File path: src/relay/op/algorithm/searchsorted.cc ##########
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file searchsorted.cc
+ * \brief SearchSorted operators
+ */
+#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(SearchSortedAttrs);
+
+bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  const SearchSortedAttrs* param = attrs.as<SearchSortedAttrs>();
+  ICHECK_EQ(types.size(), 3);
+  const auto* sorted_sequence = types[0].as<TensorTypeNode>();
+  const auto* values = types[1].as<TensorTypeNode>();
+  ICHECK(sorted_sequence) << "Expects TensorType in the first input";
+  ICHECK(values) << "Expects TensorType in the second input";
+  ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one";
+  ICHECK(param->side == "left" || param->side == "right")

Review comment:
In general I think it would be better to put these sorts of checks as far up near the frontend as possible and remove the need for the check by making the side a boolean.
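
For illustration only, here is a rough plain-Python sketch of how the suggestions in this review might fit together: validate `side` once at the frontend, pass a bool further down, and have the search helper return the index instead of writing into an output buffer. This is not TVM IR builder code. The names `search_one`, `searchsorted_1d` and `right` are hypothetical and do not come from the PR.

```python
def search_one(sorted_sequence, value, right):
    """Return the insertion index of `value` in `sorted_sequence`.

    `right=False` gives the lower bound (first suitable index),
    `right=True` gives the upper bound (last suitable index).
    Hypothetical helper, only meant to mirror the loop in the diff.
    """
    lo, hi = 0, len(sorted_sequence)
    while lo < hi:
        mid = lo + ((hi - lo) >> 1)
        current = sorted_sequence[mid]
        # Same comparison as `condition` in the diff, keyed on a bool.
        move_right = current <= value if right else current < value
        if move_right:
            lo = mid + 1
        else:
            hi = mid
    # The caller decides where the result is stored.
    return lo


def searchsorted_1d(sorted_sequence, values, side="left"):
    """Frontend-level wrapper (hypothetical): the string is checked here only."""
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    right = side == "right"
    return [search_one(sorted_sequence, v, right) for v in values]
```

With this split, only the wrapper ever sees the string, so the lower layers would not need a `side == "left" || side == "right"` style check.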
