mbrookhart commented on a change in pull request #8816:
URL: https://github.com/apache/tvm/pull/8816#discussion_r698720804
##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -207,9 +208,29 @@ Array<te::Tensor> ReduceCompute(const Attrs& attrs, const
Array<te::Tensor>& inp
return {topi::identity(inputs[0])};
}
}
+
return {f(inputs[0], axes, param->keepdims, false)};
}
+template <typename F>
+Array<te::Tensor> OneElementReduceCompute(const Attrs& attrs, const
Array<te::Tensor>& inputs,
Review comment:
What does `OneElement` mean here? I don't understand the name.
##########
File path: include/tvm/topi/reduction.h
##########
@@ -442,35 +481,49 @@ inline Tensor max(const Tensor& data, const
Array<Integer>& axis, bool keepdims
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
* \param atleast1d Whether the output need to be atleast1d.
+ * \param select_last_index Whether to select the last index if the minimum
element
+ * appears multiple times, else select the first index.
*
* \return A Tensor whose op member is the argmin operation
*/
inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool
keepdims = false,
- bool atleast1d = false) {
- auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
- Array<PrimExpr> result;
- result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); //
idx
- result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); //
val
- return result;
- };
- auto fidentity = [](std::vector<DataType> types) {
- Array<PrimExpr> result;
- result.push_back(tvm::tir::make_const(types[0], -1)); // idx
- result.push_back(tvm::max_value(types[1])); // val
- return result;
- };
- auto func = MakeCommReducer(fcombine, fidentity, "argmin");
- return CommReduceIdx(data, axis, func, keepdims, atleast1d);
+ bool atleast1d = false, bool select_last_index = false) {
+ auto reducer = MakeArgminReducer(select_last_index);
+ return CommReduceIdx(data, axis, reducer, keepdims, atleast1d);
}
-inline FCommReduce MakeArgmaxReducer() {
- auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
+inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
+ // Create a Commutative Reducer with a comparison operation, and method to
get the initial value.
+ auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
Array<PrimExpr> result;
- result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); //
idx
- result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); //
val
+
+ // Casting to avoid operator ambiguity
+ PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
+ PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
+ PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
+ PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);
+
+ // These variables compare the actual values of the array
+ auto is_bigger = lhs_val > rhs_val;
+ auto is_same = lhs_val == rhs_val;
+
+ // This checks if the indices are correct for the reduction. E.g. for
select_last_index
+ // it gives precedence for later indices of the same element and
precedence for sooner
+ // indices if not select_last_index;
+ PrimExpr proper_index;
+ if (select_last_index) {
+ proper_index = lhs_idx > rhs_idx;
+ } else {
+ proper_index = lhs_idx < rhs_idx;
+ }
+
+ PrimExpr update_index = is_bigger || (is_same && proper_index);
+ result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx
+ result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val
+ LOG(WARNING) << result;
Review comment:
Remove the `LOG(WARNING)` call? I think it will end up being noisy.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]