masahi commented on a change in pull request #9184:
URL: https://github.com/apache/tvm/pull/9184#discussion_r727476180
##########
File path: include/tvm/relay/attrs/algorithm.h
##########
@@ -76,6 +76,19 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
}
};
+struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
+ std::string side;
+ DataType dtype;
+
+ TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
+ TVM_ATTR_FIELD(side).set_default("left").describe(
+ "Controls which index is returned if a value lands exactly on one of sorted values.");
+ TVM_ATTR_FIELD(dtype)
+ .set_default(DataType::Int(32))
Review comment:
Right, other ops use `NullValue<DataType>()` as the default here, but
the Python definition at
https://github.com/apache/tvm/blob/main/python/tvm/relay/op/algorithm.py#L47
documents the default as int32. So I thought we should make that explicit in
`attrs/algorithm.h` as well.
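The `side` attribute in the hunk above only matters when a value ties with an element of the sorted row. As an illustration outside TVM, Python's `bisect` module and NumPy's `np.searchsorted` use the same left/right convention that the `binary_search` loop in this PR follows (`<` for left, `<=` for right):

```python
import bisect

import numpy as np

sorted_row = [1, 3, 3, 5]

# side="left"  -> lower bound: index of the first element >= v.
# side="right" -> upper bound: index of the first element > v.
for v in (0, 3, 4, 6):
    print(v, bisect.bisect_left(sorted_row, v), bisect.bisect_right(sorted_row, v))

# NumPy's searchsorted follows the same convention for its `side` argument.
assert np.searchsorted(sorted_row, 3, side="left") == 1
assert np.searchsorted(sorted_row, 3, side="right") == 3
```

Only the tied value 3 gives different answers for the two sides (1 vs 3); values that are not in the row get the same index either way.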
##########
File path: python/tvm/topi/searchsorted.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""searchsorted operator"""
+from . import utils
+from . import te
+from ..tir import ir_builder
+from .math import cast
+
+
+def binary_search(
+ ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
+):
+ """Common IR generator for CPU and GPU searchsorted."""
+ lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
+ hi = ib.allocate(out_dtype, (1,), name="hi", scope="local")
+
+ v = values[index]
+ lo[0] = cast(0, out_dtype)
+ hi[0] = cast(search_range, out_dtype)
+
+ # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu
+ def condition(current_val, target_val):
+ if side == "left":
+ return current_val < target_val
+ return current_val <= target_val
+
+ with ib.while_loop(lo[0] < hi[0]):
+ mid = lo[0] + (hi[0] - lo[0] >> 1)
+ with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], v)):
+ lo[0] = mid + 1
+ with ib.else_scope():
+ hi[0] = mid
+
+ out_indices[index] = lo[0]
Review comment:
Due to some peculiarity in the IR builder, that doesn't work on the
vulkan target. It works fine on llvm, but on vulkan I get all-zero output:
```
Mismatched elements: 149857 / 150000 (99.9%)
Max absolute difference: 1024
Max relative difference: 1.
x: array([[[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],...
y: array([[[[ 116, 195, 291, ..., 338, 196, 890],
[ 609, 93, 659, ..., 977, 563, 693],
[ 675, 922, 53, ..., 1019, 429, 486]],...
```
I'll test on the cuda target too.
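Before attributing the all-zero output to a particular backend, it can help to check the search logic on the host. A minimal pure-Python mirror of the `binary_search` loop quoted above, compared against `np.searchsorted` (the name `binary_search_ref` is hypothetical and not part of the PR):

```python
import numpy as np


def binary_search_ref(sorted_row, v, side="left"):
    """Pure-Python mirror of the lo/hi loop in `binary_search`."""
    lo, hi = 0, len(sorted_row)
    while lo < hi:
        mid = lo + ((hi - lo) >> 1)
        # "left" keeps moving right past elements < v; "right" past elements <= v.
        go_right = sorted_row[mid] < v if side == "left" else sorted_row[mid] <= v
        if go_right:
            lo = mid + 1
        else:
            hi = mid
    return lo


rng = np.random.default_rng(0)
seq = np.sort(rng.integers(0, 1024, size=1024))
vals = rng.integers(0, 1024, size=100)
for side in ("left", "right"):
    got = np.array([binary_search_ref(seq, v, side) for v in vals], dtype="int32")
    np.testing.assert_array_equal(got, np.searchsorted(seq, vals, side=side))
```

If the host mirror agrees with NumPy but a device target does not, the problem lies in how the generated kernel is wired up rather than in the search itself.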
##########
File path: python/tvm/relay/op/algorithm.py
##########
@@ -115,3 +115,36 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
if ret_type == "both":
return TupleWrapper(out, 2)
return out
+
+
+def searchsorted(sorted_sequence, values, side="left", dtype="int32"):
+ """Find indices where elements should be inserted to maintain order.
+ If `sorted_sequence` is N-dimensional, the innermost dimension of
+ `values` are searched in the corresponding dimension of `sorted_sequence`.
+
+ Parameters
+ ----------
+ sorted_sequence : relay.Expr
+ N-D or 1-D Tensor, containing monotonically increasing sequence
+ on the innermost dimension.
+
+ values : relay.Expr
+ N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
+ the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
+ and `values` must be the same, and outer N-1 axes must have the same size.
+
+ side : string, optional
+ It can be `left` or `right`. If `left`, gets the lower bound index for each value
+ in `values` on the corresponding innermost dimension of the `sorted_sequence`.
+ If `right`, gets the upper bound index instead.
+
+ dtype : string, optional
+ The data type of the output indices.
+
+ Returns
+ -------
+ indices : relay.Expr
+ Tensor with same shape as values, representing the indices of
+ elements of `values` if they are inserted in `sorted_sequence`.
+ """
Review comment:
made it a boolean, no need to check
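The N-D semantics described in the docstring above can be pinned down with a small NumPy reference, which is handy as a test oracle. This is a sketch, not the PR's implementation: it keeps the string `side` from the quoted docstring (even though the comment says it became a boolean), and the name `searchsorted_ref` is hypothetical.

```python
import numpy as np


def searchsorted_ref(sorted_sequence, values, side="left", dtype="int32"):
    """NumPy reference for the searchsorted semantics in the docstring."""
    sorted_sequence = np.asarray(sorted_sequence)
    values = np.asarray(values)
    if sorted_sequence.ndim == 1:
        # A single 1-D sequence is shared by every value, whatever the shape of `values`.
        return np.searchsorted(sorted_sequence, values, side=side).astype(dtype)
    # N-D case: equal ranks and equal outer N-1 axes; search row by row.
    assert sorted_sequence.ndim == values.ndim
    assert sorted_sequence.shape[:-1] == values.shape[:-1]
    seq_rows = sorted_sequence.reshape(-1, sorted_sequence.shape[-1])
    val_rows = values.reshape(-1, values.shape[-1])
    out = np.empty(val_rows.shape, dtype=dtype)
    for i in range(seq_rows.shape[0]):
        out[i] = np.searchsorted(seq_rows[i], val_rows[i], side=side)
    return out.reshape(values.shape)


seq = np.array([[1, 3, 5, 7], [2, 4, 6, 8]])
vals = np.array([[3, 6], [4, 9]])
left = searchsorted_ref(seq, vals)
right = searchsorted_ref(seq, vals, side="right")
```

Here `left` is `[[1, 3], [1, 4]]` and `right` is `[[2, 3], [2, 4]]`: each row of `vals` is searched only in the matching row of `seq`, and the result has the shape of `vals`.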
##########
File path: src/relay/op/algorithm/searchsorted.cc
##########
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file searchsorted.cc
+ * \brief SearchSorted operators
+ */
+#include <tvm/relay/attrs/algorithm.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(SearchSortedAttrs);
+
+bool SearchSortedRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ const SearchSortedAttrs* param = attrs.as<SearchSortedAttrs>();
+ ICHECK_EQ(types.size(), 3);
+ const auto* sorted_sequence = types[0].as<TensorTypeNode>();
+ const auto* values = types[1].as<TensorTypeNode>();
+ ICHECK(sorted_sequence) << "Expects TensorType in the first input";
+ ICHECK(values) << "Expects TensorType in the second input";
+ ICHECK_GT(values->shape.size(), 0) << "The rank of `values` must be greater than one";
+ ICHECK(param->side == "left" || param->side == "right")
Review comment:
made it a boolean
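The quoted `SearchSortedRel` hunk stops after the first checks. As a hypothetical sketch (plain Python, not a TVM type relation), the remaining shape rules stated in the Python docstring could be expressed like this: `values` must have rank of at least 1, an N-D `sorted_sequence` must match the rank and outer axes of `values`, and the output takes the shape of `values`. The function name is illustrative only.

```python
def searchsorted_out_shape(seq_shape, values_shape):
    """Hypothetical mirror of the docstring's shape rules.

    Raises ValueError on invalid input and returns the output shape otherwise.
    """
    if len(values_shape) == 0:
        raise ValueError("the rank of `values` must be at least 1")
    if len(seq_shape) > 1:
        if len(seq_shape) != len(values_shape):
            raise ValueError("ranks of `sorted_sequence` and `values` must match")
        if tuple(seq_shape[:-1]) != tuple(values_shape[:-1]):
            raise ValueError("outer N-1 axes must have the same size")
    # The output indices always take the shape of `values`.
    return tuple(values_shape)
```

With `side` now a boolean, the string check on `param->side` is no longer needed, so the shape rules above are what a relation like this would still have to enforce.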
##########
File path: python/tvm/topi/searchsorted.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""searchsorted operator"""
+from . import utils
+from . import te
+from ..tir import ir_builder
+from .math import cast
+
+
+def binary_search(
+ ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
Review comment:
Added a description; let me know if anything is unclear.
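The quoted signature takes a flat `index` into `values` together with a `sequence_offset` and a `search_range`. One plausible way a caller could derive these for row-major flattened buffers is sketched below. This is an assumption for illustration, not code from the PR, and the helper name is hypothetical.

```python
import numpy as np


def sequence_offset_for(index, seq_shape, values_shape):
    """Hypothetical helper: flat start of the sorted row searched by flat `index`."""
    if len(seq_shape) == 1:
        return 0  # one 1-D sequence shared by all values
    row = index // values_shape[-1]  # outer position this value belongs to
    return row * seq_shape[-1]  # each sorted row holds seq_shape[-1] elements


seq = np.array([[1, 3, 5, 7], [2, 4, 6, 8]])
vals = np.array([[3, 6], [4, 9]])
flat_seq, flat_vals = seq.ravel(), vals.ravel()
search_range = seq.shape[-1]
for index, v in enumerate(flat_vals):
    off = sequence_offset_for(index, seq.shape, vals.shape)
    # Only the slice [off, off + search_range) of the flat buffer is searched.
    print(index, off, np.searchsorted(flat_seq[off:off + search_range], v))
```

Under this layout, flat indices 0 and 1 search the first sorted row (offset 0), and indices 2 and 3 search the second (offset 4), with `search_range` equal to the innermost size of `sorted_sequence`.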
##########
File path: python/tvm/topi/searchsorted.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""searchsorted operator"""
+from . import utils
+from . import te
+from ..tir import ir_builder
+from .math import cast
+
+
+def binary_search(
+ ib, sequence_offset, search_range, index, sorted_sequence, values, out_indices, side, out_dtype
+):
+ """Common IR generator for CPU and GPU searchsorted."""
+ lo = ib.allocate(out_dtype, (1,), name="lo", scope="local")
+ hi = ib.allocate(out_dtype, (1,), name="hi", scope="local")
+
+ v = values[index]
+ lo[0] = cast(0, out_dtype)
+ hi[0] = cast(search_range, out_dtype)
+
+ # Reference: pytorch/aten/src/ATen/native/cuda/Bucketization.cu
+ def condition(current_val, target_val):
+ if side == "left":
+ return current_val < target_val
+ return current_val <= target_val
+
+ with ib.while_loop(lo[0] < hi[0]):
+ mid = lo[0] + (hi[0] - lo[0] >> 1)
+ with ib.if_scope(condition(sorted_sequence[sequence_offset + mid], v)):
+ lo[0] = mid + 1
+ with ib.else_scope():
+ hi[0] = mid
+
+ out_indices[index] = lo[0]
Review comment:
My bad: when I tested the change above, I forgot to update the GPU
definition in `topi/cuda/searchsorted.py` to assign the returned index to the
output buffer. No wonder I got all-zero output!
I fixed my mistake and it works on vulkan as well. Not only did it remove the
in-place mutation, it also removed some arguments from `binary_search`. Things
look much cleaner now, thanks!
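The refactor described here, where the search returns the index and the caller does the single write, can be sketched in plain Python. Names are hypothetical, `bisect` stands in for the IR-builder loop, and the zero-initialized buffer is an assumption of the sketch (device buffers are not necessarily zeroed):

```python
import bisect

import numpy as np


def find_index(row, v, side="left"):
    """Returns the insertion index instead of writing into an output buffer."""
    # bisect stands in for the generated binary search in this sketch.
    return bisect.bisect_left(row, v) if side == "left" else bisect.bisect_right(row, v)


def searchsorted_1d(row, values, side="left", dtype="int32"):
    # Zero-initialized here, so a forgotten assignment would show up as all zeros.
    out = np.zeros(len(values), dtype=dtype)
    for i, v in enumerate(values):
        out[i] = find_index(row, v, side)  # the caller owns the only write
    return out
```

Keeping the search free of side effects means only one place writes to the output, which makes a missing assignment like the one described above easier to spot.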
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.