This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 78922ba  [#58] More DataFrame Functionality
78922ba is described below

commit 78922ba2bbe903e4d740ca15a969d4534d469e31
Author: Martin Grund <[email protected]>
AuthorDate: Fri Sep 27 11:17:48 2024 +0200

    [#58] More DataFrame Functionality
    
    ### What changes were proposed in this pull request?
    This PR adds support for the following DF functions:
    
    * CreateOrReplaceTempView
    * CreateGlobalTempView
    * CreateOrReplaceGlobalTempView
    * CrossTab
    * Cube
    * Describe
    * Distinct
    * First
    * FreqItems
    * FreqItemsWithSupport
    * GroupingSets
    * IsEmpty
    * Join
    * Offset
    * OrderBy
    * RandomSplit
    * Rollup
    * Sort
    * Summary
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
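    Yes: new DataFrame functions (listed above) and GroupedData.Pivot. This is
    also a breaking change: DataFrame.Limit now returns a lazy DataFrame instead
    of ([]Row, error); use Limit(...).Collect(ctx), Head or Take to fetch rows.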
    
    ### How was this patch tested?
    Added e2e tests
    
    Closes #71 from grundprinzip/next_batch_v3.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 207 ++++++++++++++++++-
 spark/sql/dataframe.go                       | 290 ++++++++++++++++++++++++++-
 spark/sql/group.go                           |  26 ++-
 spark/sql/group_test.go                      |   2 +-
 spark/sql/utils/consts.go                    |  33 +++
 5 files changed, 543 insertions(+), 15 deletions(-)

diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index ec4469b..61ebcdf 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -44,8 +44,8 @@ func TestDataFrame_Select(t *testing.T) {
        assert.NoError(t, err)
        assert.Equal(t, 100, len(res))
 
-       row_zero := res[0]
-       vals, err := row_zero.Values()
+       rowZero := res[0]
+       vals, err := rowZero.Values()
        assert.NoError(t, err)
        assert.Equal(t, 2, len(vals))
 
@@ -509,7 +509,9 @@ func TestDataFrame_LimitVersions(t *testing.T) {
        ctx, spark := connect()
        df, err := spark.Sql(ctx, "select * from range(10)")
        assert.NoError(t, err)
-       rows, err := df.Limit(ctx, int32(5))
+       df = df.Limit(ctx, int32(5))
+       assert.NoError(t, err)
+       rows, err := df.Collect(ctx)
        assert.NoError(t, err)
        assert.Len(t, rows, 5)
 
@@ -525,3 +527,202 @@ func TestDataFrame_LimitVersions(t *testing.T) {
        assert.NoError(t, err)
        assert.Len(t, rows, 3)
 }
+
+// TestDataFrame_Sort verifies that Sort orders rows by the given sort
+// expression, checking the first row for both descending and ascending order.
+func TestDataFrame_Sort(t *testing.T) {
+       ctx, spark := connect()
+       src, err := spark.Sql(ctx, "select * from range(10)")
+       if !assert.NoError(t, err) {
+               return
+       }
+
+       df, err := src.Sort(ctx, functions.Col("id").Desc())
+       if !assert.NoError(t, err) {
+               return
+       }
+       res, err := df.Head(ctx, 1)
+       // Abort before indexing so a failed query is reported, not a panic.
+       if !assert.NoError(t, err) || !assert.Len(t, res, 1) {
+               return
+       }
+       vals, err := res[0].Values()
+       if !assert.NoError(t, err) || !assert.NotEmpty(t, vals) {
+               return
+       }
+       assert.Equal(t, int64(9), vals[0])
+
+       df, err = src.Sort(ctx, functions.Col("id").Asc())
+       if !assert.NoError(t, err) {
+               return
+       }
+       res, err = df.Head(ctx, 1)
+       if !assert.NoError(t, err) || !assert.Len(t, res, 1) {
+               return
+       }
+       vals, err = res[0].Values()
+       if !assert.NoError(t, err) || !assert.NotEmpty(t, vals) {
+               return
+       }
+       assert.Equal(t, int64(0), vals[0])
+}
+
+// TestDataFrame_Join inner-joins range(10) with range(5) on their id columns;
+// only the five shared ids should survive.
+func TestDataFrame_Join(t *testing.T) {
+       ctx, spark := connect()
+       left, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       right, err := spark.Sql(ctx, "select * from range(5)")
+       assert.NoError(t, err)
+
+       cond := column.OfDF(left, "id").Eq(column.OfDF(right, "id"))
+       joined, err := left.Join(ctx, right, cond, utils.JoinTypeInner)
+       assert.NoError(t, err)
+       rows, err := joined.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 5, len(rows))
+}
+
+// TestDataFrame_RandomSplits splits 1000 rows with weights 0.3/0.7 and expects
+// the first split to be smaller than the second.
+func TestDataFrame_RandomSplits(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(1000)")
+       assert.NoError(t, err)
+       splits, err := df.RandomSplit(ctx, []float64{0.3, 0.7})
+       assert.NoError(t, err)
+       assert.Len(t, splits, 2)
+
+       small, err := splits[0].Count(ctx)
+       assert.NoError(t, err)
+       large, err := splits[1].Count(ctx)
+       assert.NoError(t, err)
+       assert.Less(t, small, large)
+}
+
+// TestDataFrame_Describe expects one row per basic statistic
+// (count, mean, stddev, min, max) for the described column.
+func TestDataFrame_Describe(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       described := df.Describe(ctx, "id")
+       rows, err := described.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, rows, 5)
+}
+
+func TestDataFrame_Summary(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select id, 'a' as col, 2 as other from 
range(10)")
+       assert.NoError(t, err)
+       res, err := df.Summary(ctx, "count", "stddev").Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, res, 2)
+       v, err := res[0].Values()
+       assert.NoError(t, err)
+       assert.Equal(t, "count", v[0])
+       assert.Len(t, v, 4)
+}
+
+// TestDataFrame_Pivot pivots yearly earnings by course; with two distinct
+// years the pivoted and summed result must contain two rows.
+func TestDataFrame_Pivot(t *testing.T) {
+       ctx, spark := connect()
+
+       schema := types.StructOf(
+               types.NewStructField("course", types.STRING),
+               types.NewStructField("year", types.INTEGER),
+               types.NewStructField("earnings", types.INTEGER))
+       input := [][]any{
+               {"dotNET", 2012, 10000},
+               {"Java", 2012, 20000},
+               {"dotNET", 2012, 5000},
+               {"dotNET", 2013, 48000},
+               {"Java", 2013, 30000},
+       }
+
+       df, err := spark.CreateDataFrame(ctx, input, schema)
+       assert.NoError(t, err)
+       pivoted, err := df.GroupBy(functions.Col("year")).Pivot(ctx, "course", []any{"Java", "dotNET"})
+       assert.NoError(t, err)
+       summed, err := pivoted.Sum(ctx, "earnings")
+       assert.NoError(t, err)
+       rows, err := summed.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, rows, 2)
+}
+
+// TestDataFrame_Offset skips the first five of ten rows and expects the
+// remaining five.
+func TestDataFrame_Offset(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       // Offset is lazy and returns no error; failures surface in Collect.
+       df = df.Offset(ctx, int32(5))
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, res, 5)
+}
+
+// TestDataFrame_IsEmpty checks IsEmpty on a populated and on an empty DataFrame.
+func TestDataFrame_IsEmpty(t *testing.T) {
+       ctx, spark := connect()
+
+       populated, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       isEmpty, err := populated.IsEmpty(ctx)
+       assert.NoError(t, err)
+       assert.False(t, isEmpty)
+
+       empty, err := spark.Sql(ctx, "select * from range(0)")
+       assert.NoError(t, err)
+       isEmpty, err = empty.IsEmpty(ctx)
+       assert.NoError(t, err)
+       assert.True(t, isEmpty)
+}
+
+// TestDataFrame_First expects the first row of range(10) to hold id 0.
+func TestDataFrame_First(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       if !assert.NoError(t, err) {
+               return
+       }
+       row, err := df.First(ctx)
+       // Abort on error: row may be nil and calling Values on it would panic.
+       if !assert.NoError(t, err) {
+               return
+       }
+       vals, err := row.Values()
+       if !assert.NoError(t, err) || !assert.NotEmpty(t, vals) {
+               return
+       }
+       assert.Equal(t, int64(0), vals[0])
+}
+
+// TestDataFrame_Distinct expects all ten rows of range(10) to be distinct.
+func TestDataFrame_Distinct(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       // Distinct is lazy and returns no error; failures surface in Collect.
+       df = df.Distinct(ctx)
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Len(t, res, 10)
+}
+
+// TestDataFrame_CrossTab builds a contingency table of c1 against c2 and checks
+// the row for c1 == 1, whose only pairing in the input is with c2 == 11 (twice).
+func TestDataFrame_CrossTab(t *testing.T) {
+       ctx, spark := connect()
+       data := [][]any{{1, 11}, {1, 11}, {3, 10}, {4, 8}, {4, 8}}
+       schema := types.StructOf(
+               types.NewStructField("c1", types.INTEGER),
+               types.NewStructField("c2", types.INTEGER),
+       )
+
+       df, err := spark.CreateDataFrame(ctx, data, schema)
+       if !assert.NoError(t, err) {
+               return
+       }
+       df = df.CrossTab(ctx, "c1", "c2")
+       // Sort on the "$col1_$col2" key column so the first row is c1 == 1.
+       df, err = df.Sort(ctx, column.OfDF(df, "c1_c2").Asc())
+       if !assert.NoError(t, err) {
+               return
+       }
+       res, err := df.Collect(ctx)
+       // Abort before indexing so a failed query is reported, not a panic.
+       if !assert.NoError(t, err) || !assert.Len(t, res, 3) {
+               return
+       }
+       v, err := res[0].Values()
+       // Key column plus one column for each of the three distinct c2 values.
+       if !assert.NoError(t, err) || !assert.Len(t, v, 4) {
+               return
+       }
+       assert.Equal(t, "1", v[0])
+       assert.Equal(t, int64(0), v[1])
+       assert.Equal(t, int64(2), v[2])
+       assert.Equal(t, int64(0), v[3])
+}
+
+// TestDataFrame_SameSemantics expects two independently created but identical
+// queries to be reported as semantically equal.
+func TestDataFrame_SameSemantics(t *testing.T) {
+       ctx, spark := connect()
+       df1, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       df2, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       // Check the error too: a failed RPC would otherwise surface only as a
+       // confusing "should be true" failure.
+       res, err := df1.SameSemantics(ctx, df2)
+       assert.NoError(t, err)
+       assert.True(t, res)
+}
+
+// TestDataFrame_SemanticHash only checks that a non-zero semantic hash is
+// returned for a simple query.
+func TestDataFrame_SemanticHash(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       h, err := df.SemanticHash(ctx)
+       assert.NoError(t, err)
+       assert.NotEmpty(t, h)
+}
+
+// DISABLED: Because we cannot parse LIST types
+//func TestDataFrame_FreqItems(t *testing.T) {
+//     ctx, spark := connect()
+//     df, err := spark.Sql(ctx, "select * from range(10)")
+//     assert.NoError(t, err)
+//     res, err := df.FreqItems(ctx, "id").Collect(ctx)
+//     assert.NoErrorf(t, err, "%+v", err)
+//     assert.Len(t, res, 1)
+//}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index bcd0633..36da21f 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -19,6 +19,7 @@ package sql
 import (
        "context"
        "fmt"
+       "math/rand/v2"
 
        "github.com/apache/arrow/go/v17/arrow"
        "github.com/apache/spark-connect-go/v35/spark/sql/utils"
@@ -75,8 +76,28 @@ type DataFrame interface {
        Collect(ctx context.Context) ([]Row, error)
        // CreateTempView creates or replaces a temporary view.
        CreateTempView(ctx context.Context, viewName string, replace, global 
bool) error
+       // CreateOrReplaceTempView creates or replaces a temporary view and 
replaces the optional existing view.
+       CreateOrReplaceTempView(ctx context.Context, viewName string) error
+       // CreateGlobalTempView creates a global temporary view.
+       CreateGlobalTempView(ctx context.Context, viewName string) error
+       // CreateOrReplaceGlobalTempView creates or replaces a global temporary 
view and replaces the optional existing view.
+       CreateOrReplaceGlobalTempView(ctx context.Context, viewName string) 
error
        // CrossJoin joins the current DataFrame with another DataFrame using 
the cross product
        CrossJoin(ctx context.Context, other DataFrame) DataFrame
+       // CrossTab computes a pair-wise frequency table of the given columns. 
Also known as a
+       // contingency table.
+       // The first column of each row will be the distinct values of `col1` 
and the column names
+       // will be the distinct values of `col2`. The name of the first column 
will be `$col1_$col2`.
+       // Pairs that have no occurrences will have zero as their counts.
+       CrossTab(ctx context.Context, col1, col2 string) DataFrame
+       // Cube creates a multi-dimensional cube for the current DataFrame using
+       // the specified columns, so we can run aggregations on them.
+       Cube(ctx context.Context, cols ...column.Convertible) *GroupedData
+       // Describe omputes basic statistics for numeric and string columns.
+       // This includes count, mean, stddev, min, and max.
+       Describe(ctx context.Context, cols ...string) DataFrame
+       // Distinct returns a new DataFrame containing the distinct rows in 
this DataFrame.
+       Distinct(ctx context.Context) DataFrame
        // Drop returns a new DataFrame that drops the specified list of 
columns.
        Drop(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
        // DropByName returns a new DataFrame that drops the specified list of 
columns by name.
@@ -85,11 +106,16 @@ type DataFrame interface {
        DropDuplicates(ctx context.Context, columns ...string) (DataFrame, 
error)
        // ExceptAll is similar to Substract but does not perform the distinct 
operation.
        ExceptAll(ctx context.Context, other DataFrame) DataFrame
+       // Explain returns the string explain plan for the current DataFrame 
according to the explainMode.
        Explain(ctx context.Context, explainMode utils.ExplainMode) (string, 
error)
        // Filter filters the data frame by a column condition.
        Filter(ctx context.Context, condition column.Convertible) (DataFrame, 
error)
        // FilterByString filters the data frame by a string condition.
        FilterByString(ctx context.Context, condition string) (DataFrame, error)
+       // Returns the first row of the DataFrame.
+       First(ctx context.Context) (Row, error)
+       FreqItems(ctx context.Context, cols ...string) DataFrame
+       FreqItemsWithSupport(ctx context.Context, support float64, cols 
...string) DataFrame
        // GetStorageLevel returns the storage level of the data frame.
        GetStorageLevel(ctx context.Context) (*utils.StorageLevel, error)
        // GroupBy groups the DataFrame by the spcified columns so that the 
aggregation
@@ -101,13 +127,25 @@ type DataFrame interface {
        Intersect(ctx context.Context, other DataFrame) DataFrame
        // IntersectAll performs the set intersection of two data frames and 
returns all rows.
        IntersectAll(ctx context.Context, other DataFrame) DataFrame
-       // Limit returns the first `limit` rows as a list of Row.
-       Limit(ctx context.Context, limit int32) ([]Row, error)
+       // IsEmpty returns true if the DataFrame is empty.
+       IsEmpty(ctx context.Context) (bool, error)
+       // Join joins the current DataFrame with another DataFrame using the 
specified column using the joinType specified.
+       Join(ctx context.Context, other DataFrame, on column.Convertible, 
joinType utils.JoinType) (DataFrame, error)
+       // Limit applies a limit on the DataFrame
+       Limit(ctx context.Context, limit int32) DataFrame
+       // Offset returns a new DataFrame by skipping the first `offset` rows.
+       Offset(ctx context.Context, offset int32) DataFrame
+       // OrderBy is an alias for Sort
+       OrderBy(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
        Persist(ctx context.Context, storageLevel utils.StorageLevel) error
+       RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error)
        // Repartition re-partitions a data frame.
        Repartition(ctx context.Context, numPartitions int, columns []string) 
(DataFrame, error)
        // RepartitionByRange re-partitions a data frame by range partition.
        RepartitionByRange(ctx context.Context, numPartitions int, columns 
...column.Convertible) (DataFrame, error)
+       // Rollup creates a multi-dimensional rollup for the current DataFrame 
using
+       // the specified columns, so we can run aggregation on them.
+       Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
        // SameSemantics returns true if the other DataFrame has the same 
semantics.
        SameSemantics(ctx context.Context, other DataFrame) (bool, error)
        // Show uses WriteResult to write the data frames to the console output.
@@ -121,9 +159,16 @@ type DataFrame interface {
        // SemanticHash returns the semantic hash of the data frame. The 
semantic hash can be used to
        // understand of the semantic operations are similar.
        SemanticHash(ctx context.Context) (int32, error)
+       // Sort returns a new DataFrame sorted by the specified columns.
+       Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
        // Subtract subtracts the other DataFrame from the current DataFrame. 
And only returns
        // distinct rows.
        Subtract(ctx context.Context, other DataFrame) DataFrame
+       // Summary computes the specified statistics for the current DataFrame 
and returns it
+       // as a new DataFrame. Available statistics are: "count", "mean", 
"stddev", "min", "max" and
+       // arbitrary percentiles specified as a percentage (e.g., "75%"). If no 
statistics are given,
+       // this function computes "count", "mean", "stddev", "min", "25%", 
"50%", "75%", "max".
+       Summary(ctx context.Context, statistics ...string) DataFrame
        // Tail returns the last `limit` rows as a list of Row.
        Tail(ctx context.Context, limit int32) ([]Row, error)
        // Take is an alias for Limit
@@ -507,6 +552,18 @@ func (df *dataFrameImpl) CreateTempView(ctx 
context.Context, viewName string, re
        return err
 }
 
+func (df *dataFrameImpl) CreateOrReplaceTempView(ctx context.Context, viewName 
string) error {
+       return df.CreateTempView(ctx, viewName, true, false)
+}
+
+func (df *dataFrameImpl) CreateGlobalTempView(ctx context.Context, viewName 
string) error {
+       return df.CreateTempView(ctx, viewName, false, true)
+}
+
+func (df *dataFrameImpl) CreateOrReplaceGlobalTempView(ctx context.Context, 
viewName string) error {
+       return df.CreateTempView(ctx, viewName, true, true)
+}
+
 func (df *dataFrameImpl) Repartition(ctx context.Context, numPartitions int, 
columns []string) (DataFrame, error) {
        var partitionExpressions []*proto.Expression
        if columns != nil {
@@ -597,6 +654,10 @@ func (df *dataFrameImpl) FilterByString(ctx 
context.Context, condition string) (
 }
 
 func (df *dataFrameImpl) Select(ctx context.Context, columns 
...column.Convertible) (DataFrame, error) {
+       //
+       if len(columns) == 0 {
+               return df, nil
+       }
        exprs := make([]*proto.Expression, 0, len(columns))
        for _, c := range columns {
                expr, err := c.ToProto(ctx)
@@ -624,7 +685,7 @@ func (df *dataFrameImpl) Select(ctx context.Context, 
columns ...column.Convertib
 // can be performed on them. See GroupedData for all the available aggregate 
functions.
 func (df *dataFrameImpl) GroupBy(cols ...column.Convertible) *GroupedData {
        return &GroupedData{
-               df:           *df,
+               df:           df,
                groupingCols: cols,
                groupType:    "groupby",
        }
@@ -807,7 +868,7 @@ func (df *dataFrameImpl) Tail(ctx context.Context, limit 
int32) ([]Row, error) {
        return data.Collect(ctx)
 }
 
-func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) ([]Row, 
error) {
+func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) DataFrame {
        rel := &proto.Relation{
                Common: &proto.RelationCommon{
                        PlanId: newPlanId(),
@@ -819,16 +880,15 @@ func (df *dataFrameImpl) Limit(ctx context.Context, limit 
int32) ([]Row, error)
                        },
                },
        }
-       data := NewDataFrame(df.session, rel)
-       return data.Collect(ctx)
+       return NewDataFrame(df.session, rel)
 }
 
 func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error) 
{
-       return df.Limit(ctx, limit)
+       return df.Limit(ctx, limit).Collect(ctx)
 }
 
 func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error) 
{
-       return df.Limit(ctx, limit)
+       return df.Limit(ctx, limit).Collect(ctx)
 }
 
 func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) {
@@ -994,7 +1054,12 @@ func (df *dataFrameImpl) Sort(ctx context.Context, 
columns ...column.Convertible
                if err != nil {
                        return nil, err
                }
-               sortExprs = append(sortExprs, expr.GetSortOrder())
+               so := expr.GetSortOrder()
+               if so == nil {
+                       return nil, sparkerrors.WithType(fmt.Errorf(
+                               "sort expression must not be nil"), 
sparkerrors.InvalidArgumentError)
+               }
+               sortExprs = append(sortExprs, so)
        }
 
        rel := &proto.Relation{
@@ -1106,3 +1171,210 @@ func (df *dataFrameImpl) SemanticHash(ctx 
context.Context) (int32, error) {
        }
        return df.session.client.SemanticHash(ctx, plan)
 }
+
+func (df *dataFrameImpl) Join(ctx context.Context, other DataFrame, onExpr 
column.Convertible, joinType utils.JoinType) (DataFrame, error) {
+       otherDf := other.(*dataFrameImpl)
+       onExpression, err := onExpr.ToProto(ctx)
+       if err != nil {
+               return nil, err
+       }
+       joinTypeProto := utils.ToProtoJoinType(joinType)
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Join{
+                       Join: &proto.Join{
+                               Left:          df.relation,
+                               Right:         otherDf.relation,
+                               JoinType:      joinTypeProto,
+                               JoinCondition: onExpression,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
+// CrossTab computes a pair-wise frequency table of col1 and col2. The first
+// column, named "$col1_$col2", holds the distinct values of col1; the remaining
+// columns are named after the distinct values of col2.
+func (df *dataFrameImpl) CrossTab(ctx context.Context, col1, col2 string) DataFrame {
+       crosstab := &proto.StatCrosstab{
+               Input: df.relation,
+               Col1:  col1,
+               Col2:  col2,
+       }
+       return NewDataFrame(df.session, &proto.Relation{
+               Common:  &proto.RelationCommon{PlanId: newPlanId()},
+               RelType: &proto.Relation_Crosstab{Crosstab: crosstab},
+       })
+}
+
+// Cube groups the DataFrame as a multi-dimensional cube over cols so that
+// aggregations can be applied through the returned GroupedData.
+func (df *dataFrameImpl) Cube(ctx context.Context, cols ...column.Convertible) *GroupedData {
+       gd := &GroupedData{df: df, groupType: "cube"}
+       gd.groupingCols = cols
+       return gd
+}
+
+// Rollup groups the DataFrame as a multi-dimensional rollup over cols so that
+// aggregations can be applied through the returned GroupedData.
+func (df *dataFrameImpl) Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData {
+       gd := &GroupedData{df: df, groupType: "rollup"}
+       gd.groupingCols = cols
+       return gd
+}
+
+// Describe computes basic statistics (count, mean, stddev, min, max) for the
+// given columns and returns them as a new DataFrame.
+func (df *dataFrameImpl) Describe(ctx context.Context, cols ...string) DataFrame {
+       stat := &proto.StatDescribe{Input: df.relation, Cols: cols}
+       rel := &proto.Relation{
+               Common:  &proto.RelationCommon{PlanId: newPlanId()},
+               RelType: &proto.Relation_Describe{Describe: stat},
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// Distinct returns a new DataFrame with duplicate rows removed, using every
+// column as part of the deduplication key.
+func (df *dataFrameImpl) Distinct(ctx context.Context) DataFrame {
+       useAllColumns := true
+       dedup := &proto.Deduplicate{
+               Input:            df.relation,
+               AllColumnsAsKeys: &useAllColumns,
+       }
+       return NewDataFrame(df.session, &proto.Relation{
+               Common:  &proto.RelationCommon{PlanId: newPlanId()},
+               RelType: &proto.Relation_Deduplicate{Deduplicate: dedup},
+       })
+}
+
+// First returns the first row of the DataFrame. It returns an
+// InvalidInputError, rather than panicking, when the DataFrame has no rows.
+func (df *dataFrameImpl) First(ctx context.Context) (Row, error) {
+       rows, err := df.Head(ctx, 1)
+       if err != nil {
+               return nil, err
+       }
+       if len(rows) == 0 {
+               return nil, sparkerrors.WithString(sparkerrors.InvalidInputError,
+                       "cannot return the first row of an empty DataFrame")
+       }
+       return rows[0], nil
+}
+
+func (df *dataFrameImpl) FreqItems(ctx context.Context, cols ...string) 
DataFrame {
+       return df.FreqItemsWithSupport(ctx, 0.01, cols...)
+}
+
+func (df *dataFrameImpl) FreqItemsWithSupport(ctx context.Context, support 
float64, cols ...string) DataFrame {
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+
+               RelType: &proto.Relation_FreqItems{
+                       FreqItems: &proto.StatFreqItems{
+                               Input:   df.relation,
+                               Cols:    cols,
+                               Support: &support,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// IsEmpty reports whether the DataFrame has no rows by fetching at most one row.
+func (df *dataFrameImpl) IsEmpty(ctx context.Context) (bool, error) {
+       projected, err := df.Select(ctx)
+       if err != nil {
+               return false, err
+       }
+       sample, err := projected.Take(ctx, 1)
+       if err != nil {
+               return false, err
+       }
+       return len(sample) == 0, nil
+}
+
+// Offset returns a new DataFrame that skips the first offset rows.
+func (df *dataFrameImpl) Offset(ctx context.Context, offset int32) DataFrame {
+       skip := &proto.Offset{Input: df.relation, Offset: offset}
+       return NewDataFrame(df.session, &proto.Relation{
+               Common:  &proto.RelationCommon{PlanId: newPlanId()},
+               RelType: &proto.Relation_Offset{Offset: skip},
+       })
+}
+
+func (df *dataFrameImpl) RandomSplit(ctx context.Context, weights []float64) 
([]DataFrame, error) {
+       // Check that we don't have negative weights:
+       total := 0.0
+       for _, w := range weights {
+               if w < 0.0 {
+                       return nil, sparkerrors.WithType(fmt.Errorf("weights 
must not be negative"), sparkerrors.InvalidArgumentError)
+               }
+               total += w
+       }
+       seed := rand.Int64()
+       normalizedWeights := make([]float64, len(weights))
+       for i, w := range weights {
+               normalizedWeights[i] = w / total
+       }
+
+       // Calculate the cumulative sum of the weights:
+       cumulativeWeights := make([]float64, len(weights)+1)
+       cumulativeWeights[0] = 0.0
+       for i := 0; i < len(normalizedWeights); i++ {
+               cumulativeWeights[i+1] = cumulativeWeights[i] + 
normalizedWeights[i]
+       }
+
+       // Iterate over cumulative weights as the boundaries of the interval 
and create the dataframes:
+       dataFrames := make([]DataFrame, len(weights))
+       withReplacement := false
+       for i := 1; i < len(cumulativeWeights); i++ {
+               sampleRelation := &proto.Relation{
+                       Common: &proto.RelationCommon{
+                               PlanId: newPlanId(),
+                       },
+                       RelType: &proto.Relation_Sample{
+                               Sample: &proto.Sample{
+                                       Input:              df.relation,
+                                       LowerBound:         
cumulativeWeights[i-1],
+                                       UpperBound:         
cumulativeWeights[i],
+                                       WithReplacement:    &withReplacement,
+                                       Seed:               &seed,
+                                       DeterministicOrder: true,
+                               },
+                       },
+               }
+               dataFrames[i-1] = NewDataFrame(df.session, sampleRelation)
+       }
+       return dataFrames, nil
+}
+
+// Summary computes the requested statistics for the DataFrame and returns them
+// as a new DataFrame; with no statistics given, the default set is computed.
+func (df *dataFrameImpl) Summary(ctx context.Context, statistics ...string) DataFrame {
+       summary := &proto.StatSummary{Input: df.relation, Statistics: statistics}
+       return NewDataFrame(df.session, &proto.Relation{
+               Common:  &proto.RelationCommon{PlanId: newPlanId()},
+               RelType: &proto.Relation_Summary{Summary: summary},
+       })
+}
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 4734225..3f67e18 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -26,11 +26,11 @@ import (
 )
 
 type GroupedData struct {
-       df           dataFrameImpl
+       df           *dataFrameImpl
        groupType    string
        groupingCols []column.Convertible
        pivotValues  []any
-       // groupingSets [][]column.Column
+       pivotCol     column.Convertible
 }
 
 // Agg compute aggregates and returns the result as a DataFrame. The aggegrate 
expressions
@@ -71,8 +71,14 @@ func (gd *GroupedData) Agg(ctx context.Context, exprs 
...column.Column) (DataFra
                if len(gd.pivotValues) == 0 {
                        return nil, 
sparkerrors.WithString(sparkerrors.InvalidInputError, "pivotValues should not 
be empty")
                }
+               protoCol, err := gd.pivotCol.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+
                agg.Pivot = &proto.Aggregate_Pivot{
                        Values: make([]*proto.Expression_Literal, 
len(gd.pivotValues)),
+                       Col:    protoCol,
                }
                for i, v := range gd.pivotValues {
                        exp, err := column.NewLiteral(v).ToProto(ctx)
@@ -172,3 +178,19 @@ func (gd *GroupedData) Count(ctx context.Context) 
(DataFrame, error) {
 func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, 
error) {
        return gd.Avg(ctx, cols...)
 }
+
+func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues 
[]any) (*GroupedData, error) {
+       if gd.groupType != "groupby" {
+               if gd.groupType == "pivot" {
+                       return nil, 
sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot cannot be applied 
on pivot")
+               }
+               return nil, 
sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot can only be 
applied on groupby")
+       }
+       return &GroupedData{
+               df:           gd.df,
+               groupType:    "pivot",
+               groupingCols: gd.groupingCols,
+               pivotValues:  pivotValues,
+               pivotCol:     column.NewColumnReferenceWithPlanId(pivotCol, 
gd.df.PlanId()),
+       }, nil
+}
diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go
index 8cdd908..0fedaf1 100644
--- a/spark/sql/group_test.go
+++ b/spark/sql/group_test.go
@@ -26,7 +26,7 @@ import (
        "github.com/stretchr/testify/assert"
 )
 
-var sampleDataFrame = dataFrameImpl{session: nil, relation: &proto.Relation{
+var sampleDataFrame = &dataFrameImpl{session: nil, relation: &proto.Relation{
        RelType: &proto.Relation_Range{
                Range: &proto.Range{
                        End:  10,
diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go
index d3f287f..e070aa6 100644
--- a/spark/sql/utils/consts.go
+++ b/spark/sql/utils/consts.go
@@ -96,3 +96,36 @@ func FromProtoStorageLevel(level *proto.StorageLevel) 
StorageLevel {
        }
        return StorageLevelNone
 }
+
+// JoinType selects how Join combines two DataFrames. Each constant maps to the
+// proto.Join_JOIN_TYPE_* value of the same name via ToProtoJoinType.
+type JoinType int
+
+// Supported join types. Values are sequential starting at zero.
+const (
+       JoinTypeInner JoinType = iota
+       JoinTypeLeftOuter
+       JoinTypeRightOuter
+       JoinTypeFullOuter
+       JoinTypeLeftSemi
+       JoinTypeLeftAnti
+       JoinTypeCross
+)
+
+// ToProtoJoinType converts a JoinType into its Spark Connect protocol value.
+// Values outside the declared JoinType constants fall back to an inner join.
+func ToProtoJoinType(joinType JoinType) proto.Join_JoinType {
+       // Lookup table indexed by the JoinType constants.
+       mapping := [...]proto.Join_JoinType{
+               JoinTypeInner:      proto.Join_JOIN_TYPE_INNER,
+               JoinTypeLeftOuter:  proto.Join_JOIN_TYPE_LEFT_OUTER,
+               JoinTypeRightOuter: proto.Join_JOIN_TYPE_RIGHT_OUTER,
+               JoinTypeFullOuter:  proto.Join_JOIN_TYPE_FULL_OUTER,
+               JoinTypeLeftSemi:   proto.Join_JOIN_TYPE_LEFT_SEMI,
+               JoinTypeLeftAnti:   proto.Join_JOIN_TYPE_LEFT_ANTI,
+               JoinTypeCross:      proto.Join_JOIN_TYPE_CROSS,
+       }
+       if joinType < 0 || int(joinType) >= len(mapping) {
+               return proto.Join_JOIN_TYPE_INNER
+       }
+       return mapping[joinType]
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to