pitrou commented on code in PR #14753:
URL: https://github.com/apache/arrow/pull/14753#discussion_r1048475757


##########
cpp/src/arrow/compute/kernels/vector_sort_internal.h:
##########
@@ -456,6 +457,243 @@ Status SortChunkedArray(ExecContext* ctx, uint64_t* 
indices_begin, uint64_t* ind
                         const ChunkedArray& values, SortOrder sort_order,
                         NullPlacement null_placement);
 
+// ----------------------------------------------------------------------
+// Helpers for Sort/SelectK/Rank implementations
+
+struct SortField {  // One resolved sort key: the index of a field plus the direction to sort it in.
+  int field_index;  // Used as the argument to `column(...)` when a ResolvedSortKey is built.
+  SortOrder order;  // Carried over unchanged into the ResolvedSortKey.
+};
+
+inline Status CheckNonNested(const FieldRef& ref) {  // Validates that `ref` is not a nested field reference.
+  if (ref.IsNested()) {
+    return Status::KeyError("Nested keys not supported for SortKeys");  // nested references are rejected with a KeyError
+  }
+  return Status::OK();  // any non-nested reference is accepted
+}
+
+template <typename T>
+Result<T> PrependInvalidColumn(Result<T> res) {  // Adds an invalid-sort-key-column prefix to the message of a failed Result.
+  if (res.ok()) return res;  // successful results are returned unchanged
+  return res.status().WithMessage("Invalid sort key column: ", 
res.status().message());
+}
+
+// Return the field indices of the sort keys, deduplicating them along the way. (Declaration only; the definition is not in this hunk.)
+Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
+                                            const std::vector<SortKey>& 
sort_keys);
+
+template <typename ResolvedSortKey, typename ResolvedSortKeyFactory>
+Result<std::vector<ResolvedSortKey>> ResolveSortKeys(  // Builds one ResolvedSortKey per SortField by calling `factory` on it.
+    const Schema& schema, const std::vector<SortKey>& sort_keys,
+    ResolvedSortKeyFactory&& factory) {
+  ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys));  // a FindSortKeys error is returned to the caller
+  std::vector<ResolvedSortKey> resolved;
+  resolved.reserve(fields.size());  // exactly one output element per SortField
+  std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), 
factory);
+  return resolved;
+}
+
+template <typename ResolvedSortKey, typename TableOrBatch>
+Result<std::vector<ResolvedSortKey>> ResolveSortKeys(  // Convenience overload: resolves the keys against `table_or_batch.schema()`.
+    const TableOrBatch& table_or_batch, const std::vector<SortKey>& sort_keys) 
{
+  return ResolveSortKeys<ResolvedSortKey>(
+      *table_or_batch.schema(), sort_keys, [&](const SortField& f) {  // factory passed to the schema-based overload
+        return ResolvedSortKey{table_or_batch.column(f.field_index), f.order};  // pair the selected column with its sort order
+      });
+}
+
+// Returns nullptr if no column matching `ref` is found, or if the FieldRef is
+// a nested reference.
+inline std::shared_ptr<ChunkedArray> GetTableColumn(const Table& table,
+                                                    const FieldRef& ref) {
+  if (ref.IsNested()) return nullptr;
+
+  if (auto name = ref.name()) {  // name-based reference: delegate the lookup to the table
+    return table.GetColumnByName(*name);
+  }
+
+  auto index = ref.field_path()->indices()[0];  // NOTE(review): assumes the field path holds at least one index; confirm
+  if (index >= table.num_columns()) return nullptr;  // NOTE(review): only the upper bound is checked; confirm a negative index cannot reach here
+  return table.column(index);
+}
+
+// We could try to reproduce the concrete Array classes' facilities
+// (such as cached raw values pointer) in a separate hierarchy of
+// physical accessors, but doing so ends up too cumbersome.
+// Instead, we simply create the desired concrete Array objects.
+inline std::shared_ptr<Array> GetPhysicalArray(
+    const Array& array, const std::shared_ptr<DataType>& physical_type) {
+  auto new_data = array.data()->Copy();  // modify a copy of the ArrayData, not the input array's own data
+  new_data->type = physical_type;  // the type is the only field changed on the copy
+  return MakeArray(std::move(new_data));  // build a new Array object from the retyped data
+}
+
+inline ArrayVector GetPhysicalChunks(const ArrayVector& chunks,  // Applies GetPhysicalArray to every chunk, preserving order.
+                                     const std::shared_ptr<DataType>& 
physical_type) {
+  ArrayVector physical(chunks.size());  // pre-sized so std::transform can write through begin()
+  std::transform(chunks.begin(), chunks.end(), physical.begin(),
+                 [&](const std::shared_ptr<Array>& array) {
+                   return GetPhysicalArray(*array, physical_type);
+                 });
+  return physical;
+}
+
+inline ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array,  // Overload taking a ChunkedArray; forwards its chunks().
+                                     const std::shared_ptr<DataType>& 
physical_type) {
+  return GetPhysicalChunks(chunked_array.chunks(), physical_type);
+}
+
+// Compare two records in a single column (either from a batch or table). Abstract base; concrete comparators implement Compare().
+template <typename ResolvedSortKey>
+struct ColumnComparator {
+  using Location = typename ResolvedSortKey::LocationType;  // how a record is addressed is defined by the sort key type
+
+  ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement 
null_placement)
+      : sort_key_(sort_key), null_placement_(null_placement) {}
+
+  virtual ~ColumnComparator() = default;  // virtual because instances are held through base-class unique_ptrs
+
+  virtual int Compare(const Location& left, const Location& right) const = 0;  // result < 0: left orders first; 0: equal; > 0: right orders first
+
+  ResolvedSortKey sort_key_;  // stored by value, i.e. copied from the constructor argument
+  NullPlacement null_placement_;  // decides on which side nulls are ordered
+};
+
+template <typename ResolvedSortKey, typename Type>
+struct ConcreteColumnComparator : public ColumnComparator<ResolvedSortKey> {  // Comparator for one column whose values have Arrow type `Type`.
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using Location = typename ResolvedSortKey::LocationType;
+
+  using ColumnComparator<ResolvedSortKey>::ColumnComparator;  // inherit the (sort_key, null_placement) constructor
+
+  int Compare(const Location& left, const Location& right) const override {  // nulls are ordered first, then values are compared
+    const auto& sort_key = this->sort_key_;
+
+    const auto chunk_left = sort_key.template GetChunk<ArrayType>(left);  // resolve both locations through the sort key
+    const auto chunk_right = sort_key.template GetChunk<ArrayType>(right);
+    if (sort_key.null_count > 0) {  // null checks are skipped when the key reports no nulls
+      const bool is_null_left = chunk_left.IsNull();
+      const bool is_null_right = chunk_right.IsNull();
+      if (is_null_left && is_null_right) {
+        return 0;  // two nulls compare equal
+      } else if (is_null_left) {
+        return this->null_placement_ == NullPlacement::AtStart ? -1 : 1;  // left null sorts first only with AtStart
+      } else if (is_null_right) {
+        return this->null_placement_ == NullPlacement::AtStart ? 1 : -1;  // mirror image of the case above
+      }
+    }
+    return CompareTypeValues<Type>(chunk_left.Value(), chunk_right.Value(),
+                                   sort_key.order, this->null_placement_);  // both values non-null (or key has no nulls)
+  }
+};
+
+template <typename ResolvedSortKey>
+struct ConcreteColumnComparator<ResolvedSortKey, NullType>
+    : public ColumnComparator<ResolvedSortKey> {  // NullType specialization: Compare() always returns 0, so every pair ties
+  using Location = typename ResolvedSortKey::LocationType;
+
+  using ColumnComparator<ResolvedSortKey>::ColumnComparator;  // inherit the (sort_key, null_placement) constructor
+
+  int Compare(const Location& left, const Location& right) const override { 
return 0; }
+};
+
+// Compare two records in the same RecordBatch or Table
+// (indexing is handled through ResolvedSortKey). Check status() after construction.
+template <typename ResolvedSortKey>
+class MultipleKeyComparator {
+ public:
+  using Location = typename ResolvedSortKey::LocationType;
+
+  MultipleKeyComparator(const std::vector<ResolvedSortKey>& sort_keys,
+                        NullPlacement null_placement)
+      : sort_keys_(sort_keys), null_placement_(null_placement) {
+    status_ &= MakeComparators();  // a constructor cannot return a Status, so failures are recorded for status()
+  }
+
+  Status status() const { return status_; }  // if not OK, column_comparators_ may hold fewer entries than sort_keys_; do not compare
+
+  // Returns true if the left-th value should be ordered before the
+  // right-th value, false otherwise. The start_sort_key_index-th
+  // sort key and subsequent sort keys are used for comparison.
+  bool Compare(const Location& left, const Location& right, size_t 
start_sort_key_index) {
+    return CompareInternal(left, right, start_sort_key_index) < 0;
+  }
+
+  bool Equals(const Location& left, const Location& right, size_t 
start_sort_key_index) {
+    return CompareInternal(left, right, start_sort_key_index) == 0;  // true when every key from start_sort_key_index on ties
+  }
+
+ private:
+  struct ColumnComparatorFactory {  // type visitor that creates the comparator matching a sort key's type
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return VisitGeneric(type); }
+
+    VISIT_SORTABLE_PHYSICAL_TYPES(VISIT)
+    VISIT(NullType)
+
+#undef VISIT
+
+    Status Visit(const DataType& type) {  // fallback overload for types without a dedicated Visit above
+      return Status::TypeError("Unsupported type for batch or table sorting: ",
+                               type.ToString());
+    }
+
+    template <typename Type>
+    Status VisitGeneric(const Type& type) {
+      res.reset(
+          new ConcreteColumnComparator<ResolvedSortKey, Type>{sort_key, 
null_placement});
+      return Status::OK();
+    }
+
+    const ResolvedSortKey& sort_key;  // borrowed; the comparator created from it keeps its own copy
+    NullPlacement null_placement;
+    std::unique_ptr<ColumnComparator<ResolvedSortKey>> res;  // output: set by VisitGeneric, moved out by MakeComparators
+  };
+
+  Status MakeComparators() {
+    column_comparators_.reserve(sort_keys_.size());
+
+    for (const auto& sort_key : sort_keys_) {
+      ColumnComparatorFactory factory{sort_key, null_placement_, nullptr};
+      RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory));  // stops at the first unsupported key type
+      column_comparators_.push_back(std::move(factory.res));  // comparator i corresponds to sort key i
+    }
+    return Status::OK();
+  }
+
+  // Compare two records in the same table and return -1, 0 or 1.
+  //
+  // -1: The left is less than the right.
+  // 0: The left is equal to the right.
+  // 1: The left is greater than the right.
+  //
+  // This supports null and NaN. Nulls are handled by the column comparators
+  // and NaN is handled in CompareTypeValues().
+  int CompareInternal(const Location& left, const Location& right,
+                      size_t start_sort_key_index) {
+    const auto num_sort_keys = sort_keys_.size();
+    for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+      const int r = column_comparators_[i]->Compare(left, right);
+      if (r != 0) {
+        return r;  // the first key that does not tie decides the ordering
+      }
+    }
+    return 0;  // all considered keys tie (also the result when no key is left to consider)
+  }
+
+  const std::vector<ResolvedSortKey>& sort_keys_;  // non-owning reference: the caller's vector must outlive this comparator
+  const NullPlacement null_placement_;
+  std::vector<std::unique_ptr<ColumnComparator<ResolvedSortKey>>> 
column_comparators_;
+  Status status_;
+};
+
+inline Result<std::shared_ptr<ArrayData>> MakeMutableUInt64Array(
+    std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* 
memory_pool) {

Review Comment:
   Is `out_type` unused here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to