This is an automated email from the ASF dual-hosted git repository.

goldmedal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b4b267ae4b Support 1 or 3 arg in generate_series() UDTF (#13856)
b4b267ae4b is described below

commit b4b267ae4b2ad326d207609538add9f0f9ead506
Author: UBarney <[email protected]>
AuthorDate: Tue Dec 24 10:13:40 2024 +0800

    Support 1 or 3 arg in generate_series() UDTF (#13856)
    
    * Support 1 or 3 args in generate_series() UDTF
    
    * address comment
---
 datafusion/functions-table/src/generate_series.rs  | 168 ++++++++++++---------
 .../sqllogictest/test_files/table_functions.slt    |  63 +++++++-
 2 files changed, 154 insertions(+), 77 deletions(-)

diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs
index ced43ea8f0..887daa71ec 100644
--- a/datafusion/functions-table/src/generate_series.rs
+++ b/datafusion/functions-table/src/generate_series.rs
@@ -22,7 +22,7 @@ use async_trait::async_trait;
 use datafusion_catalog::Session;
 use datafusion_catalog::TableFunctionImpl;
 use datafusion_catalog::TableProvider;
-use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
+use datafusion_common::{plan_err, Result, ScalarValue};
 use datafusion_expr::{Expr, TableType};
 use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
 use datafusion_physical_plan::ExecutionPlan;
@@ -30,28 +30,45 @@ use parking_lot::RwLock;
 use std::fmt;
 use std::sync::Arc;
 
-/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
+/// Indicates the arguments used for generating a series.
+#[derive(Debug, Clone)]
+enum GenSeriesArgs {
+    /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated.
+    ContainsNull,
+    /// AllNotNullArgs holds the start, end, and step values for generating the series when all arguments are not null.
+    AllNotNullArgs { start: i64, end: i64, step: i64 },
+}
+
+/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
 #[derive(Debug, Clone)]
 struct GenerateSeriesTable {
     schema: SchemaRef,
-    // None if input is Null
-    start: Option<i64>,
-    // None if input is Null
-    end: Option<i64>,
+    args: GenSeriesArgs,
 }
 
-/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
+/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
 #[derive(Debug, Clone)]
 struct GenerateSeriesState {
     schema: SchemaRef,
     start: i64, // Kept for display
     end: i64,
+    step: i64,
     batch_size: usize,
 
     /// Tracks current position when generating table
     current: i64,
 }
 
+impl GenerateSeriesState {
+    fn reach_end(&self, val: i64) -> bool {
+        if self.step > 0 {
+            return val > self.end;
+        }
+
+        val < self.end
+    }
+}
+
 /// Detail to display for 'Explain' plan
 impl fmt::Display for GenerateSeriesState {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -65,19 +82,19 @@ impl fmt::Display for GenerateSeriesState {
 
 impl LazyBatchGenerator for GenerateSeriesState {
     fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
-        // Check if we've reached the end
-        if self.current > self.end {
+        let mut buf = Vec::with_capacity(self.batch_size);
+        while buf.len() < self.batch_size && !self.reach_end(self.current) {
+            buf.push(self.current);
+            self.current += self.step;
+        }
+        let array = Int64Array::from(buf);
+
+        if array.is_empty() {
             return Ok(None);
         }
 
-        // Construct batch
-        let batch_end = (self.current + self.batch_size as i64 - 1).min(self.end);
-        let array = Int64Array::from_iter_values(self.current..=batch_end);
         let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?;
 
-        // Update current position for next batch
-        self.current = batch_end + 1;
-
         Ok(Some(batch))
     }
 }
@@ -104,39 +121,31 @@ impl TableProvider for GenerateSeriesTable {
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let batch_size = state.config_options().execution.batch_size;
-        match (self.start, self.end) {
-            (Some(start), Some(end)) => {
-                if start > end {
-                    return plan_err!(
-                        "End value must be greater than or equal to start value"
-                    );
-                }
-
-                Ok(Arc::new(LazyMemoryExec::try_new(
-                    self.schema.clone(),
-                    vec![Arc::new(RwLock::new(GenerateSeriesState {
-                        schema: self.schema.clone(),
-                        start,
-                        end,
-                        current: start,
-                        batch_size,
-                    }))],
-                )?))
-            }
-            _ => {
-                // Either start or end is None, return a generator that outputs 0 rows
-                Ok(Arc::new(LazyMemoryExec::try_new(
-                    self.schema.clone(),
-                    vec![Arc::new(RwLock::new(GenerateSeriesState {
-                        schema: self.schema.clone(),
-                        start: 0,
-                        end: 0,
-                        current: 1,
-                        batch_size,
-                    }))],
-                )?))
-            }
-        }
+
+        let state = match self.args {
+            // if args have null, then return 0 row
+            GenSeriesArgs::ContainsNull => GenerateSeriesState {
+                schema: self.schema.clone(),
+                start: 0,
+                end: 0,
+                step: 1,
+                current: 1,
+                batch_size,
+            },
+            GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState {
+                schema: self.schema.clone(),
+                start,
+                end,
+                step,
+                current: start,
+                batch_size,
+            },
+        };
+
+        Ok(Arc::new(LazyMemoryExec::try_new(
+            self.schema.clone(),
+            vec![Arc::new(RwLock::new(state))],
+        )?))
     }
 }
 
@@ -144,37 +153,58 @@ impl TableProvider for GenerateSeriesTable {
 pub struct GenerateSeriesFunc {}
 
 impl TableFunctionImpl for GenerateSeriesFunc {
-    // Check input `exprs` type and number. Input validity check (e.g. start <= end)
-    // will be performed in `TableProvider::scan`
     fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
-        // TODO: support 1 or 3 arguments following DuckDB:
-        // <https://duckdb.org/docs/sql/functions/list#generate_series>
-        if exprs.len() == 3 || exprs.len() == 1 {
-            return not_impl_err!("generate_series does not support 1 or 3 arguments");
+        if exprs.is_empty() || exprs.len() > 3 {
+            return plan_err!("generate_series function requires 1 to 3 arguments");
         }
 
-        if exprs.len() != 2 {
-            return plan_err!("generate_series expects 2 arguments");
+        let mut normalize_args = Vec::new();
+        for expr in exprs {
+            match expr {
+                Expr::Literal(ScalarValue::Null) => {}
+                Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n),
+                _ => return plan_err!("First argument must be an integer literal"),
+            };
         }
 
-        let start = match &exprs[0] {
-            Expr::Literal(ScalarValue::Null) => None,
-            Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
-            _ => return plan_err!("First argument must be an integer literal"),
-        };
-
-        let end = match &exprs[1] {
-            Expr::Literal(ScalarValue::Null) => None,
-            Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
-            _ => return plan_err!("Second argument must be an integer literal"),
-        };
-
         let schema = Arc::new(Schema::new(vec![Field::new(
             "value",
             DataType::Int64,
             false,
         )]));
 
-        Ok(Arc::new(GenerateSeriesTable { schema, start, end }))
+        if normalize_args.len() != exprs.len() {
+            // contain null
+            return Ok(Arc::new(GenerateSeriesTable {
+                schema,
+                args: GenSeriesArgs::ContainsNull,
+            }));
+        }
+
+        let (start, end, step) = match &normalize_args[..] {
+            [end] => (0, *end, 1),
+            [start, end] => (*start, *end, 1),
+            [start, end, step] => (*start, *end, *step),
+            _ => {
+                return plan_err!("generate_series function requires 1 to 3 arguments");
+            }
+        };
+
+        if start > end && step > 0 {
+            return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
+        }
+
+        if start < end && step < 0 {
+            return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
+        }
+
+        if step == 0 {
+            return plan_err!("step cannot be zero");
+        }
+
+        Ok(Arc::new(GenerateSeriesTable {
+            schema,
+            args: GenSeriesArgs::AllNotNullArgs { start, end, step },
+        }))
     }
 }
diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt
index 79294993dd..2769da03b8 100644
--- a/datafusion/sqllogictest/test_files/table_functions.slt
+++ b/datafusion/sqllogictest/test_files/table_functions.slt
@@ -16,6 +16,18 @@
 # under the License.
 
 # Test generate_series table function
+query I
+SELECT * FROM generate_series(6)
+----
+0
+1
+2
+3
+4
+5
+6
+
+
 
 query I rowsort
 SELECT * FROM generate_series(1, 5)
@@ -39,11 +51,35 @@ SELECT * FROM generate_series(3, 6)
 5
 6
 
+# #generated_data > batch_size
+query I
+SELECT count(v1) FROM generate_series(-66666,66666) t1(v1)
+----
+133333
+
+
+
+
 query I rowsort
 SELECT SUM(v1) FROM generate_series(1, 5) t1(v1)
 ----
 15
 
+query I
+SELECT * FROM generate_series(6, -1, -2)
+----
+6 
+4 
+2 
+0 
+
+query I
+SELECT * FROM generate_series(6, 66, 666)
+----
+6 
+
+
+
 # Test generate_series with WHERE clause
 query I rowsort
 SELECT * FROM generate_series(1, 10) t1(v1) WHERE v1 % 2 = 0
@@ -93,6 +129,10 @@ ON a.v1 = b.v1 - 1
 2 3
 3 4
 
+#
+# Test generate_series with null arguments
+#
+
 query I
 SELECT * FROM generate_series(NULL, 5)
 ----
@@ -105,6 +145,11 @@ query I
 SELECT * FROM generate_series(NULL, NULL)
 ----
 
+query I
+SELECT * FROM generate_series(1, 5, NULL)
+----
+
+
 query TT
 EXPLAIN SELECT * FROM generate_series(1, 5)
 ----
@@ -115,20 +160,22 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s
 # Test generate_series with invalid arguments
 #
 
-query error DataFusion error: Error during planning: End value must be greater than or equal to start value
+query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
 SELECT * FROM generate_series(5, 1)
 
-statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
-SELECT * FROM generate_series(1, 5, NULL)
+query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series
+SELECT * FROM generate_series(-6, 6, -1)
+
+query error DataFusion error: Error during planning: step cannot be zero
+SELECT * FROM generate_series(-6, 6, 0)
+
+query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
+SELECT * FROM generate_series(6, -6, 1)
 
-statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
-SELECT * FROM generate_series(1)
 
-statement error DataFusion error: Error during planning: generate_series expects 2 arguments
+statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments
 SELECT * FROM generate_series(1, 2, 3, 4)
 
-statement error DataFusion error: Error during planning: Second argument must be an integer literal
-SELECT * FROM generate_series(1, '2')
 
 statement error DataFusion error: Error during planning: First argument must be an integer literal
 SELECT * FROM generate_series('foo', 'bar')


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to