WillAyd commented on code in PR #12963:
URL: https://github.com/apache/arrow/pull/12963#discussion_r864448320
##########
cpp/src/arrow/compute/kernels/vector_sort.cc:
##########
@@ -1909,6 +1909,110 @@ class SelectKUnstableMetaFunction : public MetaFunction
{
}
};
+// ----------------------------------------------------------------------
+// Rank implementation
+
+// User-facing documentation attached to the "rank" MetaFunction
+// (see RankMetaFunction, which passes &rank_doc to its MetaFunction base).
+// Arguments: one-line summary, long description, argument names, and the
+// name of the options class the function accepts.
+const FunctionDoc rank_doc(
+ "Returns the ranking of an array",
+ ("This function computes a rank of the input array.\n"
+ "By default, null values are considered greater than any other value and\n"
+ "are therefore sorted at the end of the input. For floating-point types,\n"
+ "NaNs are considered greater than any other non-null value, but smaller\n"
+ "than null values. The default tiebreaker is to assign ranks in the order\n"
+ "in which tied values appear in the input.\n"
+ "\n"
+ "The handling of nulls, NaNs and tiebreakers can be changed in RankOptions."),
+ {"input"}, "RankOptions");
+
+class RankMetaFunction : public MetaFunction {
+ public:
+ RankMetaFunction()
+ : MetaFunction("rank", Arity::Unary(), &rank_doc,
GetDefaultSortOptions()) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options, ExecContext* ctx)
const {
+ const RankOptions& sort_options = static_cast<const
RankOptions&>(*options);
+ switch (args[0].kind()) {
+ case Datum::ARRAY: {
+ return Rank(*args[0].make_array(), sort_options, ctx);
+ } break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for rank operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ Result<Datum> Rank(const Array& array, const RankOptions& options,
+ ExecContext* ctx) const {
+ ArraySortOptions array_options(options.order, options.null_placement);
+
+ ARROW_ASSIGN_OR_RAISE(auto sortIndices, CallFunction("array_sort_indices",
{array},
+ &array_options, ctx));
+
+ auto out_size = array.length();
+ ARROW_ASSIGN_OR_RAISE(auto rankings,
+ MakeMutableUInt64Array(uint64(), out_size,
ctx->memory_pool()));
+
+ auto* indices = sortIndices.make_array()->data()->GetValues<uint64_t>(1);
+ auto out_rankings = rankings->GetMutableValues<uint64_t>(1);
+ uint64_t rank = 0;
+ Datum prevValue, currValue;
+
+ if (options.tiebreaker == Tiebreaker::Dense) {
+ for (auto i = 0; i < out_size; i++) {
+ currValue = array.GetScalar(indices[i]).ValueOrDie();
+ if (i > 0 && currValue == prevValue) {
+ } else {
+ ++rank;
+ }
+ out_rankings[indices[i]] = rank;
+ prevValue = currValue;
+ }
+ } else if (options.tiebreaker == Tiebreaker::First) {
+ for (auto i = 0; i < out_size; i++) {
+ rank = i + 1;
+ out_rankings[indices[i]] = rank;
+ }
+ } else if (options.tiebreaker == Tiebreaker::Lowest) {
+ for (auto i = 0; i < out_size; i++) {
+ currValue = array.GetScalar(indices[i]).ValueOrDie();
+ if (i > 0 && currValue == prevValue) {
+ } else {
+ rank = i + 1;
+ }
+ out_rankings[indices[i]] = rank;
+ prevValue = currValue;
+ }
+ } else if (options.tiebreaker == Tiebreaker::Highest) {
+ auto currentTieCount = 0;
+ for (auto i = 0; i < out_size; i++) {
+ currValue = array.GetScalar(indices[i]).ValueOrDie();
+ if (i > 0 && currValue == prevValue) {
+ currentTieCount++;
+ } else {
+ currentTieCount = 0;
+ }
+ rank = i + 1;
+
+ // This can be inefficient when dealing many tied values
Review Comment:
I think this should be good now. Lastly, I'm looking at implementing the visitor
pattern to address the inefficient array access you mentioned in another
comment; after that, it should be ready for another look.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]