lidavidm commented on a change in pull request #11257:
URL: https://github.com/apache/arrow/pull/11257#discussion_r717611700



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,101 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  Status consume_memo(const ArrayData& arr) {
+    auto visit_null = [&]() {
+      //return string_builder.AppendNull();
+      return Status::OK();
+    };
+    auto visit_value = [&](typename Type::c_type arg) {
+      //ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(arg));
+      //return string_builder.Append(std::move(formatted));
+      int y;
+      RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+      //std::cout << "AAAAAAAAAAAAAADFFFFFFFFFFFFFFFFFF\n" << arg << "\n";
+      return Status::OK();
+    };
+    RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+    return Status::OK();
+  }
+
+  explicit CountDistinctImpl(CountOptions options) : 
options(std::move(options)), memo_table_(new MemoTable(default_memory_pool(), 
0)) {
+  }
+
+  Status Consume(KernelContext* ctx, const ExecBatch& batch) override {
+    auto a = batch[0].make_array();

Review comment:
       Note for later: we should also handle scalar inputs.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,101 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  Status consume_memo(const ArrayData& arr) {
+    auto visit_null = [&]() {
+      //return string_builder.AppendNull();
+      return Status::OK();
+    };
+    auto visit_value = [&](typename Type::c_type arg) {
+      //ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(arg));
+      //return string_builder.Append(std::move(formatted));
+      int y;
+      RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+      //std::cout << "AAAAAAAAAAAAAADFFFFFFFFFFFFFFFFFF\n" << arg << "\n";
+      return Status::OK();
+    };
+    RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+    return Status::OK();
+  }
+
+  explicit CountDistinctImpl(CountOptions options) : 
options(std::move(options)), memo_table_(new MemoTable(default_memory_pool(), 
0)) {
+  }
+
+  Status Consume(KernelContext* ctx, const ExecBatch& batch) override {
+    auto a = batch[0].make_array();
+    //std::cout << "LOTEEEEEEEEEEE\n" << a->ToString() << "\n";
+    
+    const ArrayData& arr = *batch[0].array();
+    
+    consume_memo(arr);
+    
+    //std::cout << "MEMOOOOOOOOOOOO\n" << this->memo_table_->size() << "\n";
+    this->result_countd += this->memo_table_->size();
+    return Status::OK();
+    //RETURN_NOT_OK(lookup_table->GetOrInsert());
+    
+    //ARROW_ASSIGN_OR_RAISE(auto grouper, 
internal::Grouper::Make(batch.GetDescriptors(), ctx->exec_context()));
+    //return grouper->Consume(batch).status();
+//    if (options.mode == CountOptions::ALL) {
+//      this->non_nulls += batch.length;
+//    } else if (batch[0].is_array()) {
+//      const ArrayData& input = *batch[0].array();
+//      const int64_t nulls = input.GetNullCount();
+//      this->nulls += nulls;
+//      this->non_nulls += input.length - nulls;
+//    } else {
+//      const Scalar& input = *batch[0].scalar();
+//      this->nulls += !input.is_valid * batch.length;
+//      this->non_nulls += input.is_valid * batch.length;
+//    }
+//    return Status::OK();
+  }
+
+  Status MergeFrom(KernelContext*, KernelState&& src) override {
+    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+    //this->non_nulls += other_state.non_nulls;
+    //this->nulls += other_state.nulls;
+    this->result_countd += other_state.result_countd;
+    return Status::OK();
+  }
+
+  Status Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+    switch (state.options.mode) {
+      case CountOptions::ONLY_VALID:
+      case CountOptions::ALL:
+        // ALL is equivalent since we don't count the null/non-null
+        // separately to avoid potentially computing null count
+        //*out = Datum(state.non_nulls);
+        *out = Datum(state.result_countd);
+        break;
+      case CountOptions::ONLY_NULL:
+        //*out = Datum(state.nulls);
+        *out = Datum(state.result_countd);
+        break;
+      default:
+        DCHECK(false) << "unreachable";
+    }
+    return Status::OK();
+  }
+
+  CountOptions options;
+  int64_t result_countd = 0;

Review comment:
       Is this a typo for result_count?

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,101 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  Status consume_memo(const ArrayData& arr) {
+    auto visit_null = [&]() {
+      //return string_builder.AppendNull();
+      return Status::OK();
+    };
+    auto visit_value = [&](typename Type::c_type arg) {
+      //ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(arg));
+      //return string_builder.Append(std::move(formatted));
+      int y;
+      RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+      //std::cout << "AAAAAAAAAAAAAADFFFFFFFFFFFFFFFFFF\n" << arg << "\n";
+      return Status::OK();
+    };
+    RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+    return Status::OK();
+  }
+
+  explicit CountDistinctImpl(CountOptions options) : 
options(std::move(options)), memo_table_(new MemoTable(default_memory_pool(), 
0)) {

Review comment:
       Hmm, we shouldn't use the default memory pool here. CountDistinctInit receives a KernelContext, which can be passed to this constructor; it contains the memory pool to use.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,101 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  Status consume_memo(const ArrayData& arr) {

Review comment:
       Just a note for later, but this should be ConsumeMemo, not consume_memo.

##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -121,6 +122,101 @@ Result<std::unique_ptr<KernelState>> 
CountInit(KernelContext*,
       static_cast<const CountOptions&>(*args.options));
 }
 
+// ----------------------------------------------------------------------
+// Distinct Count implementation
+
+template <typename Type>
+struct CountDistinctImpl : public ScalarAggregator {
+  using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType;
+
+  Status consume_memo(const ArrayData& arr) {
+    auto visit_null = [&]() {
+      //return string_builder.AppendNull();
+      return Status::OK();
+    };
+    auto visit_value = [&](typename Type::c_type arg) {
+      //ARROW_ASSIGN_OR_RAISE(auto formatted, formatter(arg));
+      //return string_builder.Append(std::move(formatted));
+      int y;
+      RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y));
+      //std::cout << "AAAAAAAAAAAAAADFFFFFFFFFFFFFFFFFF\n" << arg << "\n";
+      return Status::OK();
+    };
+    RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null));
+    return Status::OK();
+  }
+
+  explicit CountDistinctImpl(CountOptions options) : 
options(std::move(options)), memo_table_(new MemoTable(default_memory_pool(), 
0)) {
+  }
+
+  Status Consume(KernelContext* ctx, const ExecBatch& batch) override {
+    auto a = batch[0].make_array();
+    //std::cout << "LOTEEEEEEEEEEE\n" << a->ToString() << "\n";
+    
+    const ArrayData& arr = *batch[0].array();
+    
+    consume_memo(arr);
+    
+    //std::cout << "MEMOOOOOOOOOOOO\n" << this->memo_table_->size() << "\n";
+    this->result_countd += this->memo_table_->size();
+    return Status::OK();
+    //RETURN_NOT_OK(lookup_table->GetOrInsert());
+    
+    //ARROW_ASSIGN_OR_RAISE(auto grouper, 
internal::Grouper::Make(batch.GetDescriptors(), ctx->exec_context()));
+    //return grouper->Consume(batch).status();
+//    if (options.mode == CountOptions::ALL) {
+//      this->non_nulls += batch.length;
+//    } else if (batch[0].is_array()) {
+//      const ArrayData& input = *batch[0].array();
+//      const int64_t nulls = input.GetNullCount();
+//      this->nulls += nulls;
+//      this->non_nulls += input.length - nulls;
+//    } else {
+//      const Scalar& input = *batch[0].scalar();
+//      this->nulls += !input.is_valid * batch.length;
+//      this->non_nulls += input.is_valid * batch.length;
+//    }
+//    return Status::OK();
+  }
+
+  Status MergeFrom(KernelContext*, KernelState&& src) override {
+    const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
+    //this->non_nulls += other_state.non_nulls;
+    //this->nulls += other_state.nulls;
+    this->result_countd += other_state.result_countd;
+    return Status::OK();
+  }
+
+  Status Finalize(KernelContext* ctx, Datum* out) override {
+    const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state());
+    switch (state.options.mode) {
+      case CountOptions::ONLY_VALID:
+      case CountOptions::ALL:
+        // ALL is equivalent since we don't count the null/non-null
+        // separately to avoid potentially computing null count
+        //*out = Datum(state.non_nulls);
+        *out = Datum(state.result_countd);
+        break;
+      case CountOptions::ONLY_NULL:
+        //*out = Datum(state.nulls);
+        *out = Datum(state.result_countd);
+        break;
+      default:
+        DCHECK(false) << "unreachable";
+    }
+    return Status::OK();
+  }
+
+  CountOptions options;
+  int64_t result_countd = 0;
+  std::unique_ptr<MemoTable> memo_table_ = nullptr;

Review comment:
       Nit: since we're initializing this in the constructor anyway, there's no need to also initialize it here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to