This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3353c06183 Add Aggregation fuzzer framework (#12667)
3353c06183 is described below
commit 3353c061830dbd1ef912736fee8426ee0405a3f9
Author: kamille <[email protected]>
AuthorDate: Wed Oct 9 22:07:06 2024 +0800
Add Aggregation fuzzer framework (#12667)
* impl primitive arrays generator.
* sort out the test record batch generating codes.
* draft for `DataSetsGenerator`.
* tmp
* improve the data generator, and start to impl the session context
generator.
* impl context generator.
* tmp
* define the `AggregationFuzzer`.
* add ut for data generator.
* improve comments for `SessionContextGenerator`.
* define `GeneratedSessionContextBuilder` to reduce repeated codes.
* extract the check equality logic for reusing.
* add ut for `SessionContextGenerator`.
* tmp
* finish the main logic of `AggregationFuzzer`.
* try to rewrite some test using the fuzzer.
* fix header.
* expose table name through `AggregationFuzzerBuilder`.
* throw err to aggr fuzzer, and expect them then.
* switch to Arc<str> to slightly improve performance.
* throw more errors to fuzzer.
* print task information before panic.
* improve comments.
* support printing generated session context params in error reporting.
* add todo.
* add some new fuzz case based on `AggregationFuzzer`.
* fix lint.
* print more information in error report.
* fix clippy.
* improve comment of `SessionContextGenerator`.
* just use fixed `data_gen_rounds` and `ctx_gen_rounds` currently, because
we will hardly ever change them.
* improve comments for rounds constants.
* small improvements.
* select sql from some candidates rather than a fixed one.
* make `data_gen_rounds` able to set again, and add more tests.
* add no group cases.
* add fuzz test for basic string aggr.
* make `data_gen_rounds` smaller.
* add comments.
* fix typo.
* fix comment.
---
datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 302 ++++++++++++++
.../aggregation_fuzzer/context_generator.rs | 343 +++++++++++++++
.../aggregation_fuzzer/data_generator.rs | 459 +++++++++++++++++++++
.../tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs | 281 +++++++++++++
.../tests/fuzz_cases/aggregation_fuzzer/mod.rs | 69 ++++
datafusion/core/tests/fuzz_cases/mod.rs | 1 +
test-utils/Cargo.toml | 1 +
.../fuzz_cases => test-utils/src/array_gen}/mod.rs | 12 +-
test-utils/src/array_gen/primitive.rs | 80 ++++
.../src/{string_gen.rs => array_gen/string.rs} | 73 +---
test-utils/src/lib.rs | 1 +
test-utils/src/string_gen.rs | 72 +---
12 files changed, 1556 insertions(+), 138 deletions(-)
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 62e9be6398..5cc5157c3a 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -44,6 +44,307 @@ use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::task::JoinSet;
+use crate::fuzz_cases::aggregation_fuzzer::{
+ AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig,
+};
+
+// ========================================================================
+// The new aggregation fuzz tests based on [`AggregationFuzzer`]
+// ========================================================================
+
+// TODO: write more test cases to cover more `group by`s and `aggregation function`s
+// TODO: maybe we can use a macro to simplify the case creation
+
+/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `no group by`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_prim_aggr_no_group() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![ColumnDescr::new("a", DataType::Int32)];
+
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set: Vec::new(),
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ .add_sql("SELECT sum(a) FROM fuzz_table")
+ .add_sql("SELECT sum(distinct a) FROM fuzz_table")
+ .add_sql("SELECT max(a) FROM fuzz_table")
+ .add_sql("SELECT min(a) FROM fuzz_table")
+ .add_sql("SELECT count(a) FROM fuzz_table")
+ .add_sql("SELECT count(distinct a) FROM fuzz_table")
+ .add_sql("SELECT avg(a) FROM fuzz_table")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single int64`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_prim_aggr_group_by_single_int64() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Int32),
+ ColumnDescr::new("b", DataType::Int64),
+ ColumnDescr::new("c", DataType::Int64),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string()],
+ vec!["c".to_string(), "b".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single string`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_prim_aggr_group_by_single_string() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Int32),
+ ColumnDescr::new("b", DataType::Utf8),
+ ColumnDescr::new("c", DataType::Int64),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string()],
+ vec!["c".to_string(), "b".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by string + int64`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_prim_aggr_group_by_mixed_string_int64() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Int32),
+ ColumnDescr::new("b", DataType::Utf8),
+ ColumnDescr::new("c", DataType::Int64),
+ ColumnDescr::new("d", DataType::Int32),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string(), "c".to_string()],
+ vec!["d".to_string(), "b".to_string(), "c".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ .add_sql("SELECT b, c, sum(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, sum(distinct a) FROM fuzz_table GROUP BY b,c")
+ .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b,
c")
+ .add_sql("SELECT b, c, avg(a) FROM fuzz_table GROUP BY b, c")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `no group by`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_string_aggr_no_group() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![ColumnDescr::new("a", DataType::Utf8)];
+
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set: Vec::new(),
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(8)
+ .add_sql("SELECT max(a) FROM fuzz_table")
+ .add_sql("SELECT min(a) FROM fuzz_table")
+ .add_sql("SELECT count(a) FROM fuzz_table")
+ .add_sql("SELECT count(distinct a) FROM fuzz_table")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single int64`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_string_aggr_group_by_single_int64() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Utf8),
+ ColumnDescr::new("b", DataType::Int64),
+ ColumnDescr::new("c", DataType::Int64),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string()],
+ vec!["c".to_string(), "b".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(8)
+ // FIXME: Encountered error in min/max
+ // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema"))
+ // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b")
+ // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single string`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_string_aggr_group_by_single_string() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Utf8),
+ ColumnDescr::new("b", DataType::Utf8),
+ ColumnDescr::new("c", DataType::Int64),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string()],
+ vec!["c".to_string(), "b".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ // FIXME: Encountered error in min/max
+ // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema"))
+ // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b")
+ // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b")
+ .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by string + int64`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_string_aggr_group_by_mixed_string_int64() {
+ let builder = AggregationFuzzerBuilder::default();
+
+ // Define data generator config
+ let columns = vec![
+ ColumnDescr::new("a", DataType::Utf8),
+ ColumnDescr::new("b", DataType::Utf8),
+ ColumnDescr::new("c", DataType::Int64),
+ ColumnDescr::new("d", DataType::Int32),
+ ];
+ let sort_keys_set = vec![
+ vec!["b".to_string(), "c".to_string()],
+ vec!["d".to_string(), "b".to_string(), "c".to_string()],
+ ];
+ let data_gen_config = DatasetGeneratorConfig {
+ columns,
+ rows_num_range: (512, 1024),
+ sort_keys_set,
+ };
+
+ // Build fuzzer
+ let fuzzer = builder
+ .data_gen_config(data_gen_config)
+ .data_gen_rounds(16)
+ // FIXME: Encountered error in min/max
+ // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema"))
+ // .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c")
+ // .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c")
+ .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b,
c")
+ .table_name("fuzz_table")
+ .build();
+
+ fuzzer.run().await;
+}
+
+// ========================================================================
+// The old aggregation fuzz tests
+// ========================================================================
+/// Tracks if this stream is generating input or output
/// Tests that streaming aggregate and batch (non streaming) aggregate produce
/// same results
#[tokio::test(flavor = "multi_thread")]
@@ -311,6 +612,7 @@ async fn group_by_string_test(
let actual = extract_result_counts(results);
assert_eq!(expected, actual);
}
+
async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
struct Visitor {
expected_sort: bool,
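All the new cases above share the same builder pattern. For reference, a minimal sketch of what adding another case looks like (the test name, column, and query below are illustrative, not part of this commit; imports are the same as in the tests above):

#[tokio::test(flavor = "multi_thread")]
async fn test_new_aggr_case_sketch() {
    // Describe the generated table: one Int64 column named "a"
    let data_gen_config = DatasetGeneratorConfig {
        columns: vec![ColumnDescr::new("a", DataType::Int64)],
        rows_num_range: (512, 1024),
        sort_keys_set: Vec::new(), // no pre-sorted dataset variants
    };

    // Register candidate queries and run the fuzzer
    AggregationFuzzerBuilder::default()
        .data_gen_config(data_gen_config)
        .data_gen_rounds(8)
        .add_sql("SELECT sum(a) FROM fuzz_table")
        .table_name("fuzz_table")
        .build()
        .run()
        .await;
}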
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs
new file mode 100644
index 0000000000..af454bee7c
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs
@@ -0,0 +1,343 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{cmp, sync::Arc};
+
+use datafusion::{
+ datasource::MemTable,
+ prelude::{SessionConfig, SessionContext},
+};
+use datafusion_catalog::TableProvider;
+use datafusion_common::error::Result;
+use datafusion_common::ScalarValue;
+use datafusion_expr::col;
+use rand::{thread_rng, Rng};
+
+use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset;
+
+/// SessionContext generator
+///
+/// During testing, `generate_baseline` will be called first to generate a standard [`SessionContext`],
+/// and we will run `sql` on it to get the `expected result`. Then `generate` will be called several times to
+/// generate random [`SessionContext`]s, and we will run the same `sql` on them to get the `actual results`.
+/// Finally, we compare the `actual results` with the `expected result`; the test succeeds only if
+/// all of them match the expected one.
+///
+/// The following parameters of the [`SessionContext`] used in query running will be generated randomly:
+/// - `batch_size`
+/// - `target_partitions`
+/// - `skip_partial` parameters
+/// - hint `sorted` or not
+/// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed
+/// to support this)
+///
+pub struct SessionContextGenerator {
+ /// Current testing dataset
+ dataset: Arc<Dataset>,
+
+ /// Table name of the test table
+ table_name: String,
+
+ /// Used to generate the random `batch_size`
+ ///
+ /// The generated `batch_size` is between (0, total_rows_num]
+ max_batch_size: usize,
+
+ /// Candidate `SkipPartialParams` which will be picked randomly
+ candidate_skip_partial_params: Vec<SkipPartialParams>,
+
+ /// The upper bound of the randomly generated target partitions,
+ /// and the lower bound is 1
+ max_target_partitions: usize,
+}
+
+impl SessionContextGenerator {
+ pub fn new(dataset_ref: Arc<Dataset>, table_name: &str) -> Self {
+ let candidate_skip_partial_params = vec![
+ SkipPartialParams::ensure_trigger(),
+ SkipPartialParams::ensure_not_trigger(),
+ ];
+
+ let max_batch_size = cmp::max(1, dataset_ref.total_rows_num);
+ let max_target_partitions = num_cpus::get();
+
+ Self {
+ dataset: dataset_ref,
+ table_name: table_name.to_string(),
+ max_batch_size,
+ candidate_skip_partial_params,
+ max_target_partitions,
+ }
+ }
+}
+
+impl SessionContextGenerator {
+ /// Generate the `SessionContext` for the baseline run
+ pub fn generate_baseline(&self) -> Result<SessionContextWithParams> {
+ let schema = self.dataset.batches[0].schema();
+ let batches = self.dataset.batches.clone();
+ let provider = MemTable::try_new(schema, vec![batches])?;
+
+ // The baseline context should try its best to disable all optimizations,
+ // pursuing correctness.
+ let batch_size = self.max_batch_size;
+ let target_partitions = 1;
+ let skip_partial_params = SkipPartialParams::ensure_not_trigger();
+
+ let builder = GeneratedSessionContextBuilder {
+ batch_size,
+ target_partitions,
+ skip_partial_params,
+ sort_hint: false,
+ table_name: self.table_name.clone(),
+ table_provider: Arc::new(provider),
+ };
+
+ builder.build()
+ }
+
+ /// Randomly generate session context
+ pub fn generate(&self) -> Result<SessionContextWithParams> {
+ let mut rng = thread_rng();
+ let schema = self.dataset.batches[0].schema();
+ let batches = self.dataset.batches.clone();
+ let provider = MemTable::try_new(schema, vec![batches])?;
+
+ // We will randomly generate the following options:
+ // - `batch_size`, from range: [1, `total_rows_num`]
+ // - `target_partitions`, from range: [1, cpu_num]
+ // - `skip_partial`, trigger or not (just these two cases currently, for simplicity)
+ // - `sorted`, if the dataset is sorted, randomly decide whether to push down this information
+ // - `spilling` (TODO)
+ let batch_size = rng.gen_range(1..=self.max_batch_size);
+
+ let target_partitions = rng.gen_range(1..=self.max_target_partitions);
+
+ let skip_partial_params_idx =
+ rng.gen_range(0..self.candidate_skip_partial_params.len());
+ let skip_partial_params =
+ self.candidate_skip_partial_params[skip_partial_params_idx];
+
+ let (provider, sort_hint) =
+ if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() {
+ // Sort keys exist, and we randomly chose to push them down
+ let sort_exprs = self
+ .dataset
+ .sort_keys
+ .iter()
+ .map(|key| col(key).sort(true, true))
+ .collect::<Vec<_>>();
+ (provider.with_sort_order(vec![sort_exprs]), true)
+ } else {
+ (provider, false)
+ };
+
+ let builder = GeneratedSessionContextBuilder {
+ batch_size,
+ target_partitions,
+ sort_hint,
+ skip_partial_params,
+ table_name: self.table_name.clone(),
+ table_provider: Arc::new(provider),
+ };
+
+ builder.build()
+ }
+}
+
+/// The generated [`SessionContext`] with its params
+///
+/// Storing the generated `params` is necessary for
+/// reporting the broken test case.
+pub struct SessionContextWithParams {
+ pub ctx: SessionContext,
+ pub params: SessionContextParams,
+}
+
+/// Collect the generated params, and build the [`SessionContext`]
+struct GeneratedSessionContextBuilder {
+ batch_size: usize,
+ target_partitions: usize,
+ sort_hint: bool,
+ skip_partial_params: SkipPartialParams,
+ table_name: String,
+ table_provider: Arc<dyn TableProvider>,
+}
+
+impl GeneratedSessionContextBuilder {
+ fn build(self) -> Result<SessionContextWithParams> {
+ // Build session context
+ let mut session_config = SessionConfig::default();
+ session_config = session_config.set(
+ "datafusion.execution.batch_size",
+ &ScalarValue::UInt64(Some(self.batch_size as u64)),
+ );
+ session_config = session_config.set(
+ "datafusion.execution.target_partitions",
+ &ScalarValue::UInt64(Some(self.target_partitions as u64)),
+ );
+ session_config = session_config.set(
+ "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
+ &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)),
+ );
+ session_config = session_config.set(
+ "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
+ &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)),
+ );
+
+ let ctx = SessionContext::new_with_config(session_config);
+ ctx.register_table(self.table_name, self.table_provider)?;
+
+ let params = SessionContextParams {
+ batch_size: self.batch_size,
+ target_partitions: self.target_partitions,
+ sort_hint: self.sort_hint,
+ skip_partial_params: self.skip_partial_params,
+ };
+
+ Ok(SessionContextWithParams { ctx, params })
+ }
+}
+
+/// The generated params for [`SessionContext`]
+#[derive(Debug)]
+#[allow(dead_code)]
+pub struct SessionContextParams {
+ batch_size: usize,
+ target_partitions: usize,
+ sort_hint: bool,
+ skip_partial_params: SkipPartialParams,
+}
+
+/// Partial skipping parameters
+#[derive(Debug, Clone, Copy)]
+pub struct SkipPartialParams {
+ /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions`
+ pub ratio_threshold: f64,
+
+ /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions`
+ pub rows_threshold: usize,
+}
+
+impl SkipPartialParams {
+ /// Generate `SkipPartialParams` ensuring to trigger partial skipping
+ pub fn ensure_trigger() -> Self {
+ Self {
+ ratio_threshold: 0.0,
+ rows_threshold: 0,
+ }
+ }
+
+ /// Generate `SkipPartialParams` ensuring not to trigger partial skipping
+ pub fn ensure_not_trigger() -> Self {
+ Self {
+ ratio_threshold: 1.0,
+ rows_threshold: usize::MAX,
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use arrow_array::{RecordBatch, StringArray, UInt32Array};
+ use arrow_schema::{DataType, Field, Schema};
+
+ use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches;
+
+ use super::*;
+
+ #[tokio::test]
+ async fn test_generated_context() {
+ // 1. Define a test dataset firstly
+ let a_col: StringArray = [
+ Some("rust"),
+ Some("java"),
+ Some("cpp"),
+ Some("go"),
+ Some("go1"),
+ Some("python"),
+ Some("python1"),
+ Some("python2"),
+ ]
+ .into_iter()
+ .collect();
+ // Sort by "b"
+ let b_col: UInt32Array = [
+ Some(1),
+ Some(2),
+ Some(4),
+ Some(8),
+ Some(8),
+ Some(16),
+ Some(16),
+ Some(16),
+ ]
+ .into_iter()
+ .collect();
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Utf8, true),
+ Field::new("b", DataType::UInt32, true),
+ ]);
+ let batch = RecordBatch::try_new(
+ Arc::new(schema),
+ vec![Arc::new(a_col), Arc::new(b_col)],
+ )
+ .unwrap();
+
+ // Split into batches of one row each
+ let mut batches = Vec::with_capacity(batch.num_rows());
+ for start in 0..batch.num_rows() {
+ let sub_batch = batch.slice(start, 1);
+ batches.push(sub_batch);
+ }
+
+ let dataset = Dataset::new(batches, vec!["b".to_string()]);
+
+ // 2. Generate the baseline context, and some random session contexts.
+ // Run the same query on them; all the random results should equal the baseline's
+ let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table");
+
+ let query = "select b, count(a) from fuzz_table group by b";
+ let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap();
+ let mut random_wrapped_ctxs = Vec::with_capacity(8);
+ for _ in 0..8 {
+ let ctx = ctx_generator.generate().unwrap();
+ random_wrapped_ctxs.push(ctx);
+ }
+
+ let base_result = baseline_wrapped_ctx
+ .ctx
+ .sql(query)
+ .await
+ .unwrap()
+ .collect()
+ .await
+ .unwrap();
+
+ for wrapped_ctx in random_wrapped_ctxs {
+ let random_result = wrapped_ctx
+ .ctx
+ .sql(query)
+ .await
+ .unwrap()
+ .collect()
+ .await
+ .unwrap();
+ check_equality_of_batches(&base_result, &random_result).unwrap();
+ }
+ }
+}
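To make the generated parameters concrete: a minimal sketch of how the `SkipPartialParams` values above end up in a `SessionConfig`, mirroring `GeneratedSessionContextBuilder::build` (a standalone illustration, not code from the commit):

// Force partial aggregation skipping to trigger: ratio 0.0, rows 0
let params = SkipPartialParams::ensure_trigger();
let config = SessionConfig::default()
    .set(
        "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
        &ScalarValue::UInt64(Some(params.rows_threshold as u64)),
    )
    .set(
        "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
        &ScalarValue::Float64(Some(params.ratio_threshold)),
    );
let ctx = SessionContext::new_with_config(config);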
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
new file mode 100644
index 0000000000..9d45779295
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs
@@ -0,0 +1,459 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
+use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
+use datafusion_physical_plan::sorts::sort::sort_batch;
+use rand::{
+ rngs::{StdRng, ThreadRng},
+ thread_rng, Rng, SeedableRng,
+};
+use test_utils::{
+ array_gen::{PrimitiveArrayGenerator, StringArrayGenerator},
+ stagger_batch,
+};
+
+/// Config for the dataset generator
+///
+/// # Parameters
+/// - `columns`, you just need to define the `column name`s and `column data type`s
+/// for the test datasets, and then they will be randomly generated by the generator
+/// when you call the `generate` function
+///
+/// - `rows_num_range`, the row count of the datasets will be randomly chosen
+/// within this range
+///
+/// - `sort_keys`, if `sort_keys` are defined, when you call `generate`, the generator
+/// will generate one `base dataset` first. Then the `base dataset` will be sorted
+/// by each `sort_key` respectively, and finally `len(sort_keys) + 1` datasets
+/// will be returned
+///
+#[derive(Debug, Clone)]
+pub struct DatasetGeneratorConfig {
+ // Descriptions of columns in datasets, it's `required`
+ pub columns: Vec<ColumnDescr>,
+
+ // Rows num range of the generated datasets, it's `required`
+ pub rows_num_range: (usize, usize),
+
+ // Sort keys used to generate the sorted data set, it's optional
+ pub sort_keys_set: Vec<Vec<String>>,
+}
+
+/// Dataset generator
+///
+/// It generates random [`Dataset`]s when the `generate` function is called.
+///
+/// The generation logic in `generate`:
+///
+/// - Randomly generate a base record batch from `batch_generator` first.
+/// The `columns` and `rows_num_range` in `config` (see `DatasetGeneratorConfig`
+/// for details) are used in the generation.
+///
+/// - Sort the batch according to each `sort_keys` entry in `config` to generate
+/// another `len(sort_keys)` sorted batches.
+///
+/// - Split each batch into multiple sub-batches, each with a randomly chosen
+/// row count, and use these sub-batches to create the `Dataset`.
+///
+pub struct DatasetGenerator {
+ batch_generator: RecordBatchGenerator,
+ sort_keys_set: Vec<Vec<String>>,
+}
+
+impl DatasetGenerator {
+ pub fn new(config: DatasetGeneratorConfig) -> Self {
+ let batch_generator = RecordBatchGenerator::new(
+ config.rows_num_range.0,
+ config.rows_num_range.1,
+ config.columns,
+ );
+
+ Self {
+ batch_generator,
+ sort_keys_set: config.sort_keys_set,
+ }
+ }
+
+ pub fn generate(&self) -> Result<Vec<Dataset>> {
+ let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1);
+
+ // Generate the base batch
+ let base_batch = self.batch_generator.generate()?;
+ let batches = stagger_batch(base_batch.clone());
+ let dataset = Dataset::new(batches, Vec::new());
+ datasets.push(dataset);
+
+ // Generate the related sorted batches
+ let schema = base_batch.schema_ref();
+ for sort_keys in self.sort_keys_set.clone() {
+ let sort_exprs = sort_keys
+ .iter()
+ .map(|key| {
+ let col_expr = col(key, schema)?;
+ Ok(PhysicalSortExpr::new_default(col_expr))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let sorted_batch = sort_batch(&base_batch, &sort_exprs, None)?;
+
+ let batches = stagger_batch(sorted_batch);
+ let dataset = Dataset::new(batches, sort_keys);
+ datasets.push(dataset);
+ }
+
+ Ok(datasets)
+ }
+}
+
+/// Single test data set
+#[derive(Debug)]
+pub struct Dataset {
+ pub batches: Vec<RecordBatch>,
+ pub total_rows_num: usize,
+ pub sort_keys: Vec<String>,
+}
+
+impl Dataset {
+ pub fn new(batches: Vec<RecordBatch>, sort_keys: Vec<String>) -> Self {
+ let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
+
+ Self {
+ batches,
+ total_rows_num,
+ sort_keys,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ColumnDescr {
+ // Column name
+ name: String,
+
+ // Data type of this column
+ column_type: DataType,
+}
+
+impl ColumnDescr {
+ #[inline]
+ pub fn new(name: &str, column_type: DataType) -> Self {
+ Self {
+ name: name.to_string(),
+ column_type,
+ }
+ }
+}
+
+/// Record batch generator
+struct RecordBatchGenerator {
+ min_rows_num: usize,
+
+ max_rows_num: usize,
+
+ columns: Vec<ColumnDescr>,
+
+ candidate_null_pcts: Vec<f64>,
+}
+
+macro_rules! generate_string_array {
+ ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{
+ let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len());
+ let null_pct = $SELF.candidate_null_pcts[null_pct_idx];
+ let max_len = $BATCH_GEN_RNG.gen_range(1..50);
+ let num_distinct_strings = if $NUM_ROWS > 1 {
+ $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS)
+ } else {
+ $NUM_ROWS
+ };
+
+ let mut generator = StringArrayGenerator {
+ max_len,
+ num_strings: $NUM_ROWS,
+ num_distinct_strings,
+ null_pct,
+ rng: $ARRAY_GEN_RNG,
+ };
+
+ generator.gen_data::<$OFFSET_TYPE>()
+ }};
+}
+
+macro_rules! generate_primitive_array {
+ ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident) => {
+ paste::paste! {{
+ let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len());
+ let null_pct = $SELF.candidate_null_pcts[null_pct_idx];
+ let num_distinct_primitives = if $NUM_ROWS > 1 {
+ $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS)
+ } else {
+ $NUM_ROWS
+ };
+
+ let mut generator = PrimitiveArrayGenerator {
+ num_primitives: $NUM_ROWS,
+ num_distinct_primitives,
+ null_pct,
+ rng: $ARRAY_GEN_RNG,
+ };
+
+ generator.[< gen_data_ $DATA_TYPE >]()
+ }}}
+}
+
+impl RecordBatchGenerator {
+ fn new(min_rows_num: usize, max_rows_num: usize, columns: Vec<ColumnDescr>) -> Self {
+ let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5];
+
+ Self {
+ min_rows_num,
+ max_rows_num,
+ columns,
+ candidate_null_pcts,
+ }
+ }
+
+ fn generate(&self) -> Result<RecordBatch> {
+ let mut rng = thread_rng();
+ let num_rows = rng.gen_range(self.min_rows_num..=self.max_rows_num);
+ let array_gen_rng = StdRng::from_seed(rng.gen());
+
+ // Build arrays
+ let mut arrays = Vec::with_capacity(self.columns.len());
+ for col in self.columns.iter() {
+ let array = self.generate_array_of_type(
+ col.column_type.clone(),
+ num_rows,
+ &mut rng,
+ array_gen_rng.clone(),
+ );
+ arrays.push(array);
+ }
+
+ // Build schema
+ let fields = self
+ .columns
+ .iter()
+ .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true))
+ .collect::<Vec<_>>();
+ let schema = Arc::new(Schema::new(fields));
+
+ RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e))
+ }
+
+ fn generate_array_of_type(
+ &self,
+ data_type: DataType,
+ num_rows: usize,
+ batch_gen_rng: &mut ThreadRng,
+ array_gen_rng: StdRng,
+ ) -> ArrayRef {
+ match data_type {
+ DataType::Int8 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ i8
+ )
+ }
+ DataType::Int16 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ i16
+ )
+ }
+ DataType::Int32 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ i32
+ )
+ }
+ DataType::Int64 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ i64
+ )
+ }
+ DataType::UInt8 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ u8
+ )
+ }
+ DataType::UInt16 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ u16
+ )
+ }
+ DataType::UInt32 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ u32
+ )
+ }
+ DataType::UInt64 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ u64
+ )
+ }
+ DataType::Float32 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ f32
+ )
+ }
+ DataType::Float64 => {
+ generate_primitive_array!(
+ self,
+ num_rows,
+ batch_gen_rng,
+ array_gen_rng,
+ f64
+ )
+ }
+ DataType::Utf8 => {
+ generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i32)
+ }
+ DataType::LargeUtf8 => {
+ generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i64)
+ }
+ _ => unreachable!(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use arrow_array::UInt32Array;
+
+ use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches;
+
+ use super::*;
+
+ #[test]
+ fn test_generated_datasets() {
+ // The test dataset generation config
+ // We expect that after calling `generate`:
+ // - It generates 2 datasets
+ // - They have 2 columns "a" and "b",
+ // "a"'s type is `Utf8`, and "b"'s type is `UInt32`
+ // - One of them is unsorted, the other is sorted by column "b"
+ // - Their row counts should be the same and within [16, 32]
+ let config = DatasetGeneratorConfig {
+ columns: vec![
+ ColumnDescr {
+ name: "a".to_string(),
+ column_type: DataType::Utf8,
+ },
+ ColumnDescr {
+ name: "b".to_string(),
+ column_type: DataType::UInt32,
+ },
+ ],
+ rows_num_range: (16, 32),
+ sort_keys_set: vec![vec!["b".to_string()]],
+ };
+
+ let gen = DatasetGenerator::new(config);
+ let datasets = gen.generate().unwrap();
+
+ // Should generate 2 datasets
+ assert_eq!(datasets.len(), 2);
+
+ // Should have 2 columns "a" and "b",
+ // "a"'s type is `Utf8`, and "b"'s type is `UInt32`
+ let check_fields = |batch: &RecordBatch| {
+ assert_eq!(batch.num_columns(), 2);
+ let fields = batch.schema().fields().clone();
+ assert_eq!(fields[0].name(), "a");
+ assert_eq!(*fields[0].data_type(), DataType::Utf8);
+ assert_eq!(fields[1].name(), "b");
+ assert_eq!(*fields[1].data_type(), DataType::UInt32);
+ };
+
+ let batch = &datasets[0].batches[0];
+ check_fields(batch);
+ let batch = &datasets[1].batches[0];
+ check_fields(batch);
+
+ // One dataset's batches should be sorted by "b"
+ let sorted_batches = &datasets[1].batches;
+ let b_vals = sorted_batches.iter().flat_map(|batch| {
+ let uint_array = batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<UInt32Array>()
+ .unwrap();
+ uint_array.iter()
+ });
+ let mut prev_b_val = u32::MIN;
+ for b_val in b_vals {
+ let b_val = b_val.unwrap_or(u32::MIN);
+ assert!(b_val >= prev_b_val);
+ prev_b_val = b_val;
+ }
+
+ // The two datasets should be the same after sorting
+ check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap();
+
+ // Row counts should be within [16, 32]
+ let rows_num0 = datasets[0]
+ .batches
+ .iter()
+ .map(|batch| batch.num_rows())
+ .sum::<usize>();
+ let rows_num1 = datasets[1]
+ .batches
+ .iter()
+ .map(|batch| batch.num_rows())
+ .sum::<usize>();
+ assert_eq!(rows_num0, rows_num1);
+ assert!(rows_num0 >= 16);
+ assert!(rows_num0 <= 32);
+ }
+}
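A minimal sketch of driving the generator directly, following the unit test above (illustrative only; `generate` returns the unsorted base dataset plus one sorted dataset per entry in `sort_keys_set`):

let config = DatasetGeneratorConfig {
    columns: vec![
        ColumnDescr::new("a", DataType::Utf8),
        ColumnDescr::new("b", DataType::UInt32),
    ],
    rows_num_range: (16, 32),
    sort_keys_set: vec![vec!["b".to_string()]],
};
let datasets = DatasetGenerator::new(config).generate().unwrap();
assert_eq!(datasets.len(), 2); // base dataset + one sorted by "b"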
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
new file mode 100644
index 0000000000..abb3404828
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
@@ -0,0 +1,281 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::util::pretty::pretty_format_batches;
+use arrow_array::RecordBatch;
+use rand::{thread_rng, Rng};
+use tokio::task::JoinSet;
+
+use crate::fuzz_cases::aggregation_fuzzer::{
+ check_equality_of_batches,
+ context_generator::{SessionContextGenerator, SessionContextWithParams},
+ data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig},
+ run_sql,
+};
+
+/// Rounds of calling `generate` of [`SessionContextGenerator`]
+/// in [`AggregationFuzzer`]; `ctx_gen_rounds` random [`SessionContext`]s
+/// will be generated for each dataset for testing.
+const CTX_GEN_ROUNDS: usize = 16;
+
+/// Aggregation fuzzer's builder
+pub struct AggregationFuzzerBuilder {
+ /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set
+ candidate_sqls: Vec<Arc<str>>,
+
+ /// See `table_name` in [`AggregationFuzzer`], no default, and required to set
+ table_name: Option<Arc<str>>,
+
+ /// Used to generate `dataset_generator` in [`AggregationFuzzer`],
+ /// no default, and required to set
+ data_gen_config: Option<DatasetGeneratorConfig>,
+
+ /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16
+ data_gen_rounds: usize,
+}
+
+impl AggregationFuzzerBuilder {
+ fn new() -> Self {
+ Self {
+ candidate_sqls: Vec::new(),
+ table_name: None,
+ data_gen_config: None,
+ data_gen_rounds: 16,
+ }
+ }
+
+ pub fn add_sql(mut self, sql: &str) -> Self {
+ self.candidate_sqls.push(Arc::from(sql));
+ self
+ }
+
+ pub fn table_name(mut self, table_name: &str) -> Self {
+ self.table_name = Some(Arc::from(table_name));
+ self
+ }
+
+ pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self {
+ self.data_gen_config = Some(data_gen_config);
+ self
+ }
+
+ pub fn data_gen_rounds(mut self, data_gen_rounds: usize) -> Self {
+ self.data_gen_rounds = data_gen_rounds;
+ self
+ }
+
+ pub fn build(self) -> AggregationFuzzer {
+ assert!(!self.candidate_sqls.is_empty());
+ let candidate_sqls = self.candidate_sqls;
+ let table_name = self.table_name.expect("table_name is required");
+ let data_gen_config = self.data_gen_config.expect("data_gen_config is required");
+ let data_gen_rounds = self.data_gen_rounds;
+
+ let dataset_generator = DatasetGenerator::new(data_gen_config);
+
+ AggregationFuzzer {
+ candidate_sqls,
+ table_name,
+ dataset_generator,
+ data_gen_rounds,
+ }
+ }
+}
+
+impl Default for AggregationFuzzerBuilder {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// AggregationFuzzer randomly generates multiple [`AggregationFuzzTestTask`]s,
+/// and runs them to check the correctness of the optimizations
+/// (e.g. sorted, partial skipping, spilling...)
+pub struct AggregationFuzzer {
+ /// Candidate test queries represented by sqls
+ candidate_sqls: Vec<Arc<str>>,
+
+ /// The queried table name
+ table_name: Arc<str>,
+
+ /// Dataset generator used to randomly generate datasets
+ dataset_generator: DatasetGenerator,
+
+ /// Rounds of calling `generate` of [`DatasetGenerator`];
+ /// `len(sort_keys_set) + 1` datasets will be generated per round for testing.
+ ///
+ /// It is suggested to set this 2x or more bigger than the number of
+ /// `candidate_sqls` for better test coverage.
+ data_gen_rounds: usize,
+}
+
+/// Query group including the tested dataset and its sql query
+struct QueryGroup {
+ dataset: Dataset,
+ sql: Arc<str>,
+}
+
+impl AggregationFuzzer {
+ pub async fn run(&self) {
+ let mut join_set = JoinSet::new();
+ let mut rng = thread_rng();
+
+ // Loop to generate datasets and their queries
+ for _ in 0..self.data_gen_rounds {
+ // Generate datasets first
+ let datasets = self
+ .dataset_generator
+ .generate()
+ .expect("should success to generate dataset");
+
+ // Then for each of them, we randomly select a test sql for it
+ let query_groups = datasets
+ .into_iter()
+ .map(|dataset| {
+ let sql_idx = rng.gen_range(0..self.candidate_sqls.len());
+ let sql = self.candidate_sqls[sql_idx].clone();
+
+ QueryGroup { dataset, sql }
+ })
+ .collect::<Vec<_>>();
+
+ let tasks = self.generate_fuzz_tasks(query_groups).await;
+ for task in tasks {
+ join_set.spawn(async move {
+ task.run().await;
+ });
+ }
+ }
+
+ while let Some(join_handle) = join_set.join_next().await {
+ // propagate errors
+ join_handle.unwrap();
+ }
+ }
+
+ async fn generate_fuzz_tasks(
+ &self,
+ query_groups: Vec<QueryGroup>,
+ ) -> Vec<AggregationFuzzTestTask> {
+ let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS);
+ for QueryGroup { dataset, sql } in query_groups {
+ let dataset_ref = Arc::new(dataset);
+ let ctx_generator =
+ SessionContextGenerator::new(dataset_ref.clone(), &self.table_name);
+
+ // Generate the baseline context, and get the baseline result first
+ let baseline_ctx_with_params = ctx_generator
+ .generate_baseline()
+ .expect("should success to generate baseline session context");
+ let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx)
+ .await
+ .expect("should success to run baseline sql");
+ let baseline_result = Arc::new(baseline_result);
+ // Generate test tasks
+ for _ in 0..CTX_GEN_ROUNDS {
+ let ctx_with_params = ctx_generator
+ .generate()
+ .expect("should success to generate session context");
+ let task = AggregationFuzzTestTask {
+ dataset_ref: dataset_ref.clone(),
+ expected_result: baseline_result.clone(),
+ sql: sql.clone(),
+ ctx_with_params,
+ };
+
+ tasks.push(task);
+ }
+ }
+ tasks
+ }
+}
+
+/// One test task generated by [`AggregationFuzzer`]
+///
+/// It includes:
+/// - `expected_result`, the expected result generated by the baseline [`SessionContext`]
+/// (with all possible optimizations disabled to ensure correctness).
+///
+/// - `ctx`, a randomly generated [`SessionContext`]; `sql` will be run
+/// on it afterwards, and the result checked against the expected one.
+///
+/// - `sql`, the selected test sql
+///
+/// - `dataset_ref`, the input dataset, stored for error reporting when an
+/// inconsistency is found between the result for `ctx` and the `expected result`.
+///
+struct AggregationFuzzTestTask {
+ /// Generated session context in current test case
+ ctx_with_params: SessionContextWithParams,
+
+ /// Expected result in current test case
+ /// It is generated from `query` + `baseline session context`
+ expected_result: Arc<Vec<RecordBatch>>,
+
+ /// The test query
+ /// Currently represented as sql.
+ sql: Arc<str>,
+
+ /// The test dataset for error reporting
+ dataset_ref: Arc<Dataset>,
+}
+
+impl AggregationFuzzTestTask {
+ async fn run(&self) {
+ let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx)
+ .await
+ .expect("should success to run sql");
+ self.check_result(&task_result, &self.expected_result);
+ }
+
+ // TODO: maybe we should persist the `expected_result` and `task_result`,
+ // because readability is poor if we just print them.
+ fn check_result(&self, task_result: &[RecordBatch], expected_result: &[RecordBatch]) {
+ let result = check_equality_of_batches(task_result, expected_result);
+ if let Err(e) = result {
+ // If we find an inconsistent result, we first print the test details for reproduction
+ println!(
+ "##### AggregationFuzzer error report #####
+ ### Sql:\n{}\n\
+ ### Schema:\n{}\n\
+ ### Session context params:\n{:?}\n\
+ ### Inconsistent row:\n\
+ - row_idx:{}\n\
+ - task_row:{}\n\
+ - expected_row:{}\n\
+ ### Task total result:\n{}\n\
+ ### Expected total result:\n{}\n\
+ ### Input:\n{}\n\
+ ",
+ self.sql,
+ self.dataset_ref.batches[0].schema_ref(),
+ self.ctx_with_params.params,
+ e.row_idx,
+ e.lhs_row,
+ e.rhs_row,
+ pretty_format_batches(task_result).unwrap(),
+ pretty_format_batches(expected_result).unwrap(),
+ pretty_format_batches(&self.dataset_ref.batches).unwrap(),
+ );
+
+ // Then we just panic
+ panic!();
+ }
+ }
+}
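A worked example of how many comparison tasks one `run` produces under the constants above (assuming `data_gen_rounds = 16` and a config with 2 sort-key sets):

let data_gen_rounds = 16;
let datasets_per_round = 2 + 1; // len(sort_keys_set) + 1
let ctx_gen_rounds = 16; // CTX_GEN_ROUNDS
// Each dataset is paired with one randomly selected sql, then checked
// against the baseline under CTX_GEN_ROUNDS random session contexts.
assert_eq!(data_gen_rounds * datasets_per_round * ctx_gen_rounds, 768);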
diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs
new file mode 100644
index 0000000000..d93a5b7b93
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::util::pretty::pretty_format_batches;
+use arrow_array::RecordBatch;
+use datafusion::prelude::SessionContext;
+use datafusion_common::error::Result;
+
+mod context_generator;
+mod data_generator;
+mod fuzzer;
+
+pub use data_generator::{ColumnDescr, DatasetGeneratorConfig};
+pub use fuzzer::*;
+
+#[derive(Debug)]
+pub(crate) struct InconsistentResult {
+ pub row_idx: usize,
+ pub lhs_row: String,
+ pub rhs_row: String,
+}
+
+pub(crate) fn check_equality_of_batches(
+ lhs: &[RecordBatch],
+ rhs: &[RecordBatch],
+) -> std::result::Result<(), InconsistentResult> {
+ let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string();
+ let mut lhs_formatted_batches_sorted: Vec<&str> =
+ lhs_formatted_batches.trim().lines().collect();
+ lhs_formatted_batches_sorted.sort_unstable();
+ let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string();
+ let mut rhs_formatted_batches_sorted: Vec<&str> =
+ rhs_formatted_batches.trim().lines().collect();
+ rhs_formatted_batches_sorted.sort_unstable();
+
+ for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted
+ .iter()
+ .zip(&rhs_formatted_batches_sorted)
+ .enumerate()
+ {
+ if lhs_row != rhs_row {
+ return Err(InconsistentResult {
+ row_idx,
+ lhs_row: lhs_row.to_string(),
+ rhs_row: rhs_row.to_string(),
+ });
+ }
+ }
+
+ Ok(())
+}
+
+pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result<Vec<RecordBatch>> {
+ ctx.sql(sql).await?.collect().await
+}
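A minimal sketch of the comparison primitive these helpers provide (within the fuzz_cases crate; `ctx_a`, `ctx_b`, and `sql` are placeholders for any two generated contexts and a query):

async fn compare_contexts_sketch(sql: &str, ctx_a: &SessionContext, ctx_b: &SessionContext) {
    let lhs = run_sql(sql, ctx_a).await.unwrap();
    let rhs = run_sql(sql, ctx_b).await.unwrap();
    // Formatted rows are sorted before comparison, so differences in row
    // order between the two results are ignored
    if let Err(inconsistent) = check_equality_of_batches(&lhs, &rhs) {
        panic!(
            "row {} differs: {} vs {}",
            inconsistent.row_idx, inconsistent.lhs_row, inconsistent.rhs_row
        );
    }
}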
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index 69241571b4..5bc36b963c 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -21,6 +21,7 @@ mod join_fuzz;
mod merge_fuzz;
mod sort_fuzz;
+mod aggregation_fuzzer;
mod limit_fuzz;
mod sort_preserving_repartition_fuzz;
mod window_fuzz;
diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml
index 325a2cc2fc..414fa5569c 100644
--- a/test-utils/Cargo.toml
+++ b/test-utils/Cargo.toml
@@ -29,4 +29,5 @@ workspace = true
arrow = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
env_logger = { workspace = true }
+paste = "1.0.15"
rand = { workspace = true }
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/test-utils/src/array_gen/mod.rs
similarity index 82%
copy from datafusion/core/tests/fuzz_cases/mod.rs
copy to test-utils/src/array_gen/mod.rs
index 69241571b4..4a799ae737 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/test-utils/src/array_gen/mod.rs
@@ -15,12 +15,8 @@
// specific language governing permissions and limitations
// under the License.
-mod aggregate_fuzz;
-mod distinct_count_string_fuzz;
-mod join_fuzz;
-mod merge_fuzz;
-mod sort_fuzz;
+mod primitive;
+mod string;
-mod limit_fuzz;
-mod sort_preserving_repartition_fuzz;
-mod window_fuzz;
+pub use primitive::PrimitiveArrayGenerator;
+pub use string::StringArrayGenerator;
diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs
new file mode 100644
index 0000000000..f70ebf6686
--- /dev/null
+++ b/test-utils/src/array_gen/primitive.rs
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array};
+use arrow::datatypes::{
+ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type,
+ UInt32Type, UInt64Type, UInt8Type,
+};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate primitive arrays
+pub struct PrimitiveArrayGenerator {
+ /// the total number of primitives in the output
+ pub num_primitives: usize,
+ /// The number of distinct primitive values in the column
+ pub num_distinct_primitives: usize,
+ /// The percentage of nulls in the columns
+ pub null_pct: f64,
+ /// Random number generator
+ pub rng: StdRng,
+}
+
+macro_rules! impl_gen_data {
+ ($NATIVE_TYPE:ty, $ARROW_TYPE:ident) => {
+ paste::paste! {
+ pub fn [< gen_data_ $NATIVE_TYPE >](&mut self) -> ArrayRef {
+ // table of primitive values from which to draw
+ let distinct_primitives: PrimitiveArray<$ARROW_TYPE> = (0..self.num_distinct_primitives)
+ .map(|_| Some(self.rng.gen::<$NATIVE_TYPE>()))
+ .collect();
+
+ // pick num_primitives randomly from the distinct value table
+ let indices: UInt32Array = (0..self.num_primitives)
+ .map(|_| {
+ if self.rng.gen::<f64>() < self.null_pct {
+ None
+ } else if self.num_distinct_primitives > 1 {
+ let range = 1..(self.num_distinct_primitives as u32);
+ Some(self.rng.gen_range(range))
+ } else {
+ Some(0)
+ }
+ })
+ .collect();
+
+ let options = None;
+ arrow::compute::take(&distinct_primitives, &indices, options).unwrap()
+ }
+ }
+ };
+}
+
+// TODO: support generating more primitive arrays
+impl PrimitiveArrayGenerator {
+ impl_gen_data!(i8, Int8Type);
+ impl_gen_data!(i16, Int16Type);
+ impl_gen_data!(i32, Int32Type);
+ impl_gen_data!(i64, Int64Type);
+ impl_gen_data!(u8, UInt8Type);
+ impl_gen_data!(u16, UInt16Type);
+ impl_gen_data!(u32, UInt32Type);
+ impl_gen_data!(u64, UInt64Type);
+ impl_gen_data!(f32, Float32Type);
+ impl_gen_data!(f64, Float64Type);
+}
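A minimal sketch of using the new generator on its own (the parameter values and seed are arbitrary, for illustration):

use rand::rngs::StdRng;
use rand::SeedableRng;
use test_utils::array_gen::PrimitiveArrayGenerator;

// 1000-row Int32 array with ~10% nulls drawn from up to 100 distinct values
let mut generator = PrimitiveArrayGenerator {
    num_primitives: 1000,
    num_distinct_primitives: 100,
    null_pct: 0.1,
    rng: StdRng::seed_from_u64(42),
};
let array = generator.gen_data_i32(); // ArrayRef holding an Int32Array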
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/array_gen/string.rs
similarity index 54%
copy from test-utils/src/string_gen.rs
copy to test-utils/src/array_gen/string.rs
index 530fc15353..fbfa2bb941 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/array_gen/string.rs
@@ -14,16 +14,13 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-//
-// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array};
-use crate::stagger_batch;
+
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
-use arrow::record_batch::RecordBatch;
use rand::rngs::StdRng;
-use rand::{thread_rng, Rng, SeedableRng};
+use rand::Rng;
-/// Randomly generate strings
-pub struct StringBatchGenerator {
+/// Randomly generate string arrays
+pub struct StringArrayGenerator {
//// The maximum length of the strings
pub max_len: usize,
/// the total number of strings in the output
@@ -36,41 +33,10 @@ pub struct StringBatchGenerator {
pub rng: StdRng,
}
-impl StringBatchGenerator {
- /// Make batches of random strings with a random length columns "a" and "b".
- ///
- /// * "a" is a StringArray
- /// * "b" is a LargeStringArray
- pub fn make_input_batches(&mut self) -> Vec<RecordBatch> {
- // use a random number generator to pick a random sized output
- let batch = RecordBatch::try_from_iter(vec![
- ("a", self.gen_data::<i32>()),
- ("b", self.gen_data::<i64>()),
- ])
- .unwrap();
- stagger_batch(batch)
- }
-
- /// Return a column sorted array of random strings, sorted by a
- ///
- /// if large is false, the array is a StringArray
- /// if large is true, the array is a LargeStringArray
- pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec<RecordBatch> {
- let array = if large {
- self.gen_data::<i32>()
- } else {
- self.gen_data::<i64>()
- };
-
- let array = arrow::compute::sort(&array, None).unwrap();
-
- let batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap();
- stagger_batch(batch)
- }
-
+impl StringArrayGenerator {
/// Creates a StringArray or LargeStringArray with random strings according
/// to the parameters of the BatchGenerator
- fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+ pub fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
// table of strings from which to draw
let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
.map(|_| Some(random_string(&mut self.rng, self.max_len)))
@@ -93,33 +59,6 @@ impl StringBatchGenerator {
let options = None;
arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
}
-
- /// Return an set of `BatchGenerator`s that cover a range of interesting
- /// cases
- pub fn interesting_cases() -> Vec<Self> {
- let mut cases = vec![];
- let mut rng = thread_rng();
- for null_pct in [0.0, 0.01, 0.1, 0.5] {
- for _ in 0..100 {
- // max length of generated strings
- let max_len = rng.gen_range(1..50);
- let num_strings = rng.gen_range(1..100);
- let num_distinct_strings = if num_strings > 1 {
- rng.gen_range(1..num_strings)
- } else {
- num_strings
- };
- cases.push(StringBatchGenerator {
- max_len,
- num_strings,
- num_distinct_strings,
- null_pct,
- rng: StdRng::from_seed(rng.gen()),
- })
- }
- }
- cases
- }
}
/// Return a string of random characters of length 1..=max_len
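Since `gen_data` is now public, the extracted string generator can also be used directly; a small sketch (arbitrary parameter values, for illustration), where the offset type selects a StringArray (i32) vs a LargeStringArray (i64):

use rand::rngs::StdRng;
use rand::SeedableRng;
use test_utils::array_gen::StringArrayGenerator;

let mut generator = StringArrayGenerator {
    max_len: 16,
    num_strings: 100,
    num_distinct_strings: 10,
    null_pct: 0.1,
    rng: StdRng::seed_from_u64(7),
};
let utf8_array = generator.gen_data::<i32>(); // StringArray
let large_utf8_array = generator.gen_data::<i64>(); // LargeStringArray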
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 3ddba2fec8..9db8920833 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -22,6 +22,7 @@ use datafusion_common::cast::as_int32_array;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
+pub mod array_gen;
mod data_gen;
mod string_gen;
pub mod tpcds;
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs
index 530fc15353..725eb22b85 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/string_gen.rs
@@ -1,3 +1,4 @@
+use crate::array_gen::StringArrayGenerator;
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
@@ -14,27 +15,14 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-//
-// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array};
+
use crate::stagger_batch;
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
use arrow::record_batch::RecordBatch;
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
/// Randomly generate strings
-pub struct StringBatchGenerator {
- //// The maximum length of the strings
- pub max_len: usize,
- /// the total number of strings in the output
- pub num_strings: usize,
- /// The number of distinct strings in the columns
- pub num_distinct_strings: usize,
- /// The percentage of nulls in the columns
- pub null_pct: f64,
- /// Random number generator
- pub rng: StdRng,
-}
+pub struct StringBatchGenerator(StringArrayGenerator);
impl StringBatchGenerator {
/// Make batches of random strings with a random length columns "a" and "b".
@@ -44,8 +32,8 @@ impl StringBatchGenerator {
pub fn make_input_batches(&mut self) -> Vec<RecordBatch> {
// use a random number generator to pick a random sized output
let batch = RecordBatch::try_from_iter(vec![
- ("a", self.gen_data::<i32>()),
- ("b", self.gen_data::<i64>()),
+ ("a", self.0.gen_data::<i32>()),
+ ("b", self.0.gen_data::<i64>()),
])
.unwrap();
stagger_batch(batch)
@@ -57,9 +45,9 @@ impl StringBatchGenerator {
/// if large is true, the array is a LargeStringArray
pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec<RecordBatch> {
let array = if large {
- self.gen_data::<i32>()
+ self.0.gen_data::<i32>()
} else {
- self.gen_data::<i64>()
+ self.0.gen_data::<i64>()
};
let array = arrow::compute::sort(&array, None).unwrap();
@@ -68,32 +56,6 @@ impl StringBatchGenerator {
stagger_batch(batch)
}
- /// Creates a StringArray or LargeStringArray with random strings according
- /// to the parameters of the BatchGenerator
- fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
- // table of strings from which to draw
- let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
- .map(|_| Some(random_string(&mut self.rng, self.max_len)))
- .collect();
-
- // pick num_strings randomly from the distinct string table
- let indicies: UInt32Array = (0..self.num_strings)
- .map(|_| {
- if self.rng.gen::<f64>() < self.null_pct {
- None
- } else if self.num_distinct_strings > 1 {
- let range = 1..(self.num_distinct_strings as u32);
- Some(self.rng.gen_range(range))
- } else {
- Some(0)
- }
- })
- .collect();
-
- let options = None;
- arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
- }
-
/// Return an set of `BatchGenerator`s that cover a range of interesting
/// cases
pub fn interesting_cases() -> Vec<Self> {
@@ -109,31 +71,15 @@ impl StringBatchGenerator {
} else {
num_strings
};
- cases.push(StringBatchGenerator {
+ cases.push(StringBatchGenerator(StringArrayGenerator {
max_len,
num_strings,
num_distinct_strings,
null_pct,
rng: StdRng::from_seed(rng.gen()),
- })
+ }))
}
}
cases
}
}
-
-/// Return a string of random characters of length 1..=max_len
-fn random_string(rng: &mut StdRng, max_len: usize) -> String {
- // pick characters at random (not just ascii)
- match max_len {
- 0 => "".to_string(),
- 1 => String::from(rng.gen::<char>()),
- _ => {
- let len = rng.gen_range(1..=max_len);
- rng.sample_iter::<char, _>(rand::distributions::Standard)
- .take(len)
- .map(char::from)
- .collect::<String>()
- }
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]