This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c9f621  #52 Support for custom iterators on DataFrame
8c9f621 is described below

commit 8c9f62129461db5a8aad7145def6560cc3820505
Author: Martin Grund <[email protected]>
AuthorDate: Thu Jan 2 18:02:54 2025 +0100

    #52 Support for custom iterators on DataFrame
    
    ### What changes were proposed in this pull request?
    
    Based on the support for [iterators and ranging over functions](https://pkg.go.dev/iter)
    in Go 1.23, this patch adds support for such an iterator for the DataFrame.
    
    Following Go naming conventions, the new function is called
    `All(context.Context)` and allows writing the following idiomatic code:
    
    ```golang
    df, err := spark.Sql(ctx, "select * from range(10)")
    if err != nil {
      panic(err)
    }
    for row, err := range df.All(ctx) {
      // ...
    }
    ```
    
    This avoids calling `Collect(ctx)` and then ranging over the result, but it
    is semantically equivalent.
    
    ### Why are the changes needed?
    Simplicity
    
    ### Does this PR introduce _any_ user-facing change?
    Adds new custom range iterator
    
    ### How was this patch tested?
    Added tests
    
    Closes #103 from grundprinzip/df_custom_iterator.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 go.mod                                       |  4 +---
 internal/tests/integration/dataframe_test.go | 21 +++++++++++++++++++++
 spark/sql/dataframe.go                       | 17 +++++++++++++++++
 3 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/go.mod b/go.mod
index e43afb5..2deb77f 100644
--- a/go.mod
+++ b/go.mod
@@ -15,9 +15,7 @@
 
 module github.com/apache/spark-connect-go/v35
 
-go 1.22.0
-
-toolchain go1.23.2
+go 1.23.2
 
 require (
        github.com/apache/arrow-go/v18 v18.0.0
diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index 1290dc9..6220b8e 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -1039,3 +1039,24 @@ func TestDataFrame_DFNaFunctions(t *testing.T) {
 
        assert.Equal(t, "Bob", rows[0].At(2))
 }
+
+func TestDataFrame_RangeIter(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       cnt := 0
+       for row, err := range df.All(ctx) {
+               assert.NoError(t, err)
+               assert.NotNil(t, row)
+               cnt++
+       }
+       assert.Equal(t, 10, cnt)
+
+       // Check that errors are properly propagated
+       df, err = spark.Sql(ctx, "select if(id = 5, raise_error('handle'), 
false) from range(10)")
+       assert.NoError(t, err)
+       for _, err := range df.All(ctx) {
+               // The error is immediately thrown:
+               assert.Error(t, err)
+       }
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 2612387..1d2b20b 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -18,6 +18,7 @@ package sql
 import (
        "context"
        "fmt"
+       "iter"
        "math/rand/v2"
 
        "github.com/apache/arrow-go/v18/arrow"
@@ -41,6 +42,7 @@ type ResultCollector interface {
 type DataFrame interface {
        // PlanId returns the plan id of the data frame.
        PlanId() int64
+       All(ctx context.Context) iter.Seq2[types.Row, error]
        Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error)
        AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame, 
error)
        // Alias creates a new DataFrame with the specified subquery alias
@@ -1648,3 +1650,18 @@ func (df *dataFrameImpl) DropNaWithThreshold(ctx 
context.Context, thresh int32,
 func (df *dataFrameImpl) Na() DataFrameNaFunctions {
        return &dataFrameNaFunctionsImpl{dataFrame: df}
 }
+
+func (df *dataFrameImpl) All(ctx context.Context) iter.Seq2[types.Row, error] {
+       data, err := df.Collect(ctx)
+       return func(yield func(types.Row, error) bool) {
+               if err != nil {
+                       yield(nil, err)
+                       return
+               }
+               for _, row := range data {
+                       if !yield(row, nil) {
+                               break
+                       }
+               }
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to