This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 78922ba [#58] More DataFrame Functionality
78922ba is described below
commit 78922ba2bbe903e4d740ca15a969d4534d469e31
Author: Martin Grund <[email protected]>
AuthorDate: Fri Sep 27 11:17:48 2024 +0200
[#58] More DataFrame Functionality
### What changes were proposed in this pull request?
This PR adds support for the following DF functions (a short usage sketch follows the list):
* CreateOrReplaceTempView
* CreateGlobalTempView
* CreateOrReplaceGlobalTempView
* CrossTab
* Cube
* Describe
* Distinct
* First
* FreqItems
* FreqItemsWithSupport
* GroupingSets
* IsEmpty
* Join
* Offset
* OrderBy
* RandomSplit
* Rollup
* Sort
* Summary
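For illustration only (not part of this patch): a minimal sketch combining a few of the new methods, assuming an already-connected `sql.SparkSession`; the `demo` helper is hypothetical and most error handling is elided.

```go
package main

import (
	"context"

	"github.com/apache/spark-connect-go/v35/spark/sql"
	"github.com/apache/spark-connect-go/v35/spark/sql/column"
	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
	"github.com/apache/spark-connect-go/v35/spark/sql/utils"
)

// demo exercises Sort, Join, Offset and IsEmpty (hypothetical helper).
func demo(ctx context.Context, spark sql.SparkSession) error {
	df1, _ := spark.Sql(ctx, "select * from range(10)")
	df2, _ := spark.Sql(ctx, "select * from range(5)")

	// Sort descending and peek at the first row.
	sorted, _ := df1.Sort(ctx, functions.Col("id").Desc())
	_, _ = sorted.Head(ctx, 1)

	// Inner join on the shared id column.
	joined, _ := df1.Join(ctx, df2,
		column.OfDF(df1, "id").Eq(column.OfDF(df2, "id")), utils.JoinTypeInner)

	// Skip the first two rows, then check for emptiness.
	empty, err := joined.Offset(ctx, int32(2)).IsEmpty(ctx)
	_ = empty
	return err
}
```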
### Why are the changes needed?
Compatibility with the existing Spark DataFrame API.
### Does this PR introduce _any_ user-facing change?
Yes, new DF functions are added.
### How was this patch tested?
Added e2e tests
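For reviewers, a standalone sketch of the weight handling in the new `RandomSplit` (mirroring, not taken from, the implementation in this patch): weights are normalized to sum to 1 and accumulated into adjacent `[lower, upper)` sample bounds, one interval per output DataFrame. The `cumulativeBounds` helper is hypothetical.

```go
package main

import "fmt"

// cumulativeBounds normalizes the weights and accumulates them into
// adjacent [lower, upper) intervals, one per resulting DataFrame.
func cumulativeBounds(weights []float64) [][2]float64 {
	total := 0.0
	for _, w := range weights {
		total += w
	}
	bounds := make([][2]float64, len(weights))
	lower := 0.0
	for i, w := range weights {
		upper := lower + w/total
		bounds[i] = [2]float64{lower, upper}
		lower = upper
	}
	return bounds
}

func main() {
	// [0.3 0.7] -> [[0 0.3] [0.3 1]]
	fmt.Println(cumulativeBounds([]float64{0.3, 0.7}))
}
```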
Closes #71 from grundprinzip/next_batch_v3.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 207 ++++++++++++++++++-
spark/sql/dataframe.go | 290 ++++++++++++++++++++++++++-
spark/sql/group.go | 26 ++-
spark/sql/group_test.go | 2 +-
spark/sql/utils/consts.go | 33 +++
5 files changed, 543 insertions(+), 15 deletions(-)
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index ec4469b..61ebcdf 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -44,8 +44,8 @@ func TestDataFrame_Select(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 100, len(res))
- row_zero := res[0]
- vals, err := row_zero.Values()
+ rowZero := res[0]
+ vals, err := rowZero.Values()
assert.NoError(t, err)
assert.Equal(t, 2, len(vals))
@@ -509,7 +509,9 @@ func TestDataFrame_LimitVersions(t *testing.T) {
ctx, spark := connect()
df, err := spark.Sql(ctx, "select * from range(10)")
assert.NoError(t, err)
- rows, err := df.Limit(ctx, int32(5))
+ df = df.Limit(ctx, int32(5))
+ assert.NoError(t, err)
+ rows, err := df.Collect(ctx)
assert.NoError(t, err)
assert.Len(t, rows, 5)
@@ -525,3 +527,202 @@ func TestDataFrame_LimitVersions(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, rows, 3)
}
+
+func TestDataFrame_Sort(t *testing.T) {
+ ctx, spark := connect()
+ src, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df, err := src.Sort(ctx, functions.Col("id").Desc())
+ assert.NoError(t, err)
+ res, err := df.Head(ctx, 1)
+ assert.NoError(t, err)
+ vals, err := res[0].Values()
+ assert.NoError(t, err)
+ assert.Equal(t, int64(9), vals[0])
+
+ df, err = src.Sort(ctx, functions.Col("id").Asc())
+ assert.NoError(t, err)
+ res, err = df.Head(ctx, 1)
+ assert.NoError(t, err)
+ vals, err = res[0].Values()
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), vals[0])
+}
+
+func TestDataFrame_Join(t *testing.T) {
+ ctx, spark := connect()
+ df1, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df2, err := spark.Sql(ctx, "select * from range(5)")
+ assert.NoError(t, err)
+
+ df, err := df1.Join(ctx, df2, column.OfDF(df1, "id").Eq(column.OfDF(df2, "id")), utils.JoinTypeInner)
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, 5, len(res))
+}
+
+func TestDataFrame_RandomSplits(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(1000)")
+ assert.NoError(t, err)
+ dfs, err := df.RandomSplit(ctx, []float64{0.3, 0.7})
+ assert.NoError(t, err)
+ assert.Len(t, dfs, 2)
+ c1, err := dfs[0].Count(ctx)
+ assert.NoError(t, err)
+ c2, err := dfs[1].Count(ctx)
+ assert.NoError(t, err)
+ assert.Less(t, c1, c2)
+}
+
+func TestDataFrame_Describe(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ res, err := df.Describe(ctx, "id").Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 5)
+}
+
+func TestDataFrame_Summary(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select id, 'a' as col, 2 as other from range(10)")
+ assert.NoError(t, err)
+ res, err := df.Summary(ctx, "count", "stddev").Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 2)
+ v, err := res[0].Values()
+ assert.NoError(t, err)
+ assert.Equal(t, "count", v[0])
+ assert.Len(t, v, 4)
+}
+
+func TestDataFrame_Pivot(t *testing.T) {
+ ctx, spark := connect()
+
+ data := [][]any{
+ {"dotNET", 2012, 10000},
+ {"Java", 2012, 20000},
+ {"dotNET", 2012, 5000},
+ {"dotNET", 2013, 48000},
+ {"Java", 2013, 30000},
+ }
+ schema := types.StructOf(
+ types.NewStructField("course", types.STRING),
+ types.NewStructField("year", types.INTEGER),
+ types.NewStructField("earnings", types.INTEGER))
+
+ df, err := spark.CreateDataFrame(ctx, data, schema)
+ assert.NoError(t, err)
+ gd := df.GroupBy(functions.Col("year"))
+ gd, err = gd.Pivot(ctx, "course", []any{"Java", "dotNET"})
+ assert.NoError(t, err)
+ df, err = gd.Sum(ctx, "earnings")
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 2)
+}
+
+func TestDataFrame_Offset(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df = df.Offset(ctx, int32(5))
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 5)
+}
+
+func TestDataFrame_IsEmpty(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ empty, err := df.IsEmpty(ctx)
+ assert.NoError(t, err)
+ assert.False(t, empty)
+
+ df, err = spark.Sql(ctx, "select * from range(0)")
+ assert.NoError(t, err)
+ empty, err = df.IsEmpty(ctx)
+ assert.NoError(t, err)
+ assert.True(t, empty)
+}
+
+func TestDataFrame_First(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ row, err := df.First(ctx)
+ assert.NoError(t, err)
+ vals, err := row.Values()
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), vals[0])
+}
+
+func TestDataFrame_Distinct(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df = df.Distinct(ctx)
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 10)
+}
+
+func TestDataFrame_CrossTab(t *testing.T) {
+ ctx, spark := connect()
+ data := [][]any{{1, 11}, {1, 11}, {3, 10}, {4, 8}, {4, 8}}
+ schema := types.StructOf(
+ types.NewStructField("c1", types.INTEGER),
+ types.NewStructField("c2", types.INTEGER),
+ )
+
+ df, err := spark.CreateDataFrame(ctx, data, schema)
+ assert.NoError(t, err)
+ df = df.CrossTab(ctx, "c1", "c2")
+ df, err = df.Sort(ctx, column.OfDF(df, "c1_c2").Asc())
+ assert.NoError(t, err)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Len(t, res, 3)
+ v, err := res[0].Values()
+ assert.NoError(t, err)
+ assert.Equal(t, "1", v[0])
+ assert.Equal(t, int64(0), v[1])
+ assert.Equal(t, int64(2), v[2])
+ assert.Equal(t, int64(0), v[3])
+}
+
+func TestDataFrame_SameSemantics(t *testing.T) {
+ ctx, spark := connect()
+ df1, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df2, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ res, _ := df1.SameSemantics(ctx, df2)
+ assert.True(t, res)
+}
+
+func TestDataFrame_SemanticHash(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ hash, err := df.SemanticHash(ctx)
+ assert.NoError(t, err)
+ assert.NotEmpty(t, hash)
+}
+
+// DISABLED: Because we cannot parse LIST types
+//func TestDataFrame_FreqItems(t *testing.T) {
+// ctx, spark := connect()
+// df, err := spark.Sql(ctx, "select * from range(10)")
+// assert.NoError(t, err)
+// res, err := df.FreqItems(ctx, "id").Collect(ctx)
+// assert.NoErrorf(t, err, "%+v", err)
+// assert.Len(t, res, 1)
+//}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index bcd0633..36da21f 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -19,6 +19,7 @@ package sql
import (
"context"
"fmt"
+ "math/rand/v2"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/spark-connect-go/v35/spark/sql/utils"
@@ -75,8 +76,28 @@ type DataFrame interface {
Collect(ctx context.Context) ([]Row, error)
// CreateTempView creates or replaces a temporary view.
CreateTempView(ctx context.Context, viewName string, replace, global bool) error
+ // CreateOrReplaceTempView creates a temporary view, replacing an existing view of the same name.
+ CreateOrReplaceTempView(ctx context.Context, viewName string) error
+ // CreateGlobalTempView creates a global temporary view.
+ CreateGlobalTempView(ctx context.Context, viewName string) error
+ // CreateOrReplaceGlobalTempView creates a global temporary view, replacing an existing view of the same name.
+ CreateOrReplaceGlobalTempView(ctx context.Context, viewName string) error
// CrossJoin joins the current DataFrame with another DataFrame using the cross product.
CrossJoin(ctx context.Context, other DataFrame) DataFrame
+ // CrossTab computes a pair-wise frequency table of the given columns. Also known as a
+ // contingency table.
+ // The first column of each row will be the distinct values of `col1` and the column names
+ // will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
+ // Pairs that have no occurrences will have zero as their counts.
+ CrossTab(ctx context.Context, col1, col2 string) DataFrame
+ // Cube creates a multi-dimensional cube for the current DataFrame using
+ // the specified columns, so we can run aggregations on them.
+ Cube(ctx context.Context, cols ...column.Convertible) *GroupedData
+ // Describe computes basic statistics for numeric and string columns.
+ // This includes count, mean, stddev, min, and max.
+ Describe(ctx context.Context, cols ...string) DataFrame
+ // Distinct returns a new DataFrame containing the distinct rows in this DataFrame.
+ Distinct(ctx context.Context) DataFrame
// Drop returns a new DataFrame that drops the specified list of columns.
Drop(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
// DropByName returns a new DataFrame that drops the specified list of columns by name.
@@ -85,11 +106,16 @@ type DataFrame interface {
DropDuplicates(ctx context.Context, columns ...string) (DataFrame, error)
// ExceptAll is similar to Subtract but does not perform the distinct operation.
ExceptAll(ctx context.Context, other DataFrame) DataFrame
+ // Explain returns the string explain plan for the current DataFrame according to the explainMode.
Explain(ctx context.Context, explainMode utils.ExplainMode) (string, error)
// Filter filters the data frame by a column condition.
Filter(ctx context.Context, condition column.Convertible) (DataFrame, error)
// FilterByString filters the data frame by a string condition.
FilterByString(ctx context.Context, condition string) (DataFrame, error)
+ // First returns the first row of the DataFrame.
+ First(ctx context.Context) (Row, error)
+ // FreqItems finds frequent items for the given columns, using the default support of 1%.
+ FreqItems(ctx context.Context, cols ...string) DataFrame
+ // FreqItemsWithSupport finds frequent items for the given columns with the given minimum support.
+ FreqItemsWithSupport(ctx context.Context, support float64, cols ...string) DataFrame
// GetStorageLevel returns the storage level of the data frame.
GetStorageLevel(ctx context.Context) (*utils.StorageLevel, error)
// GroupBy groups the DataFrame by the specified columns so that the aggregation
@@ -101,13 +127,25 @@ type DataFrame interface {
Intersect(ctx context.Context, other DataFrame) DataFrame
// IntersectAll performs the set intersection of two data frames and returns all rows.
IntersectAll(ctx context.Context, other DataFrame) DataFrame
- // Limit returns the first `limit` rows as a list of Row.
- Limit(ctx context.Context, limit int32) ([]Row, error)
+ // IsEmpty returns true if the DataFrame is empty.
+ IsEmpty(ctx context.Context) (bool, error)
+ // Join joins the current DataFrame with another DataFrame on the given join condition, using the specified join type.
+ Join(ctx context.Context, other DataFrame, on column.Convertible, joinType utils.JoinType) (DataFrame, error)
+ // Limit returns a new DataFrame with at most `limit` rows.
+ Limit(ctx context.Context, limit int32) DataFrame
+ // Offset returns a new DataFrame by skipping the first `offset` rows.
+ Offset(ctx context.Context, offset int32) DataFrame
+ // OrderBy is an alias for Sort.
+ OrderBy(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
Persist(ctx context.Context, storageLevel utils.StorageLevel) error
+ // RandomSplit randomly splits this DataFrame with the provided weights.
+ RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error)
// Repartition re-partitions a data frame.
Repartition(ctx context.Context, numPartitions int, columns []string) (DataFrame, error)
// RepartitionByRange re-partitions a data frame by range partition.
RepartitionByRange(ctx context.Context, numPartitions int, columns ...column.Convertible) (DataFrame, error)
+ // Rollup creates a multi-dimensional rollup for the current DataFrame using
+ // the specified columns, so we can run aggregations on them.
+ Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData
// SameSemantics returns true if the other DataFrame has the same semantics.
SameSemantics(ctx context.Context, other DataFrame) (bool, error)
// Show uses WriteResult to write the data frames to the console output.
@@ -121,9 +159,16 @@ type DataFrame interface {
// SemanticHash returns the semantic hash of the data frame. The semantic hash can be used to
// understand if the semantic operations are similar.
SemanticHash(ctx context.Context) (int32, error)
+ // Sort returns a new DataFrame sorted by the specified columns.
+ Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
// Subtract subtracts the other DataFrame from the current DataFrame, returning only
// distinct rows.
Subtract(ctx context.Context, other DataFrame) DataFrame
+ // Summary computes the specified statistics for the current DataFrame and returns it
+ // as a new DataFrame. Available statistics are: "count", "mean", "stddev", "min", "max" and
+ // arbitrary percentiles specified as a percentage (e.g., "75%"). If no statistics are given,
+ // this function computes "count", "mean", "stddev", "min", "25%", "50%", "75%", "max".
+ Summary(ctx context.Context, statistics ...string) DataFrame
// Tail returns the last `limit` rows as a list of Row.
Tail(ctx context.Context, limit int32) ([]Row, error)
// Take returns the first `limit` rows as a list of Row.
@@ -507,6 +552,18 @@ func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, re
return err
}
+func (df *dataFrameImpl) CreateOrReplaceTempView(ctx context.Context, viewName string) error {
+ return df.CreateTempView(ctx, viewName, true, false)
+}
+
+func (df *dataFrameImpl) CreateGlobalTempView(ctx context.Context, viewName string) error {
+ return df.CreateTempView(ctx, viewName, false, true)
+}
+
+func (df *dataFrameImpl) CreateOrReplaceGlobalTempView(ctx context.Context, viewName string) error {
+ return df.CreateTempView(ctx, viewName, true, true)
+}
+
func (df *dataFrameImpl) Repartition(ctx context.Context, numPartitions int, columns []string) (DataFrame, error) {
var partitionExpressions []*proto.Expression
if columns != nil {
@@ -597,6 +654,10 @@ func (df *dataFrameImpl) FilterByString(ctx context.Context, condition string) (
}
func (df *dataFrameImpl) Select(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+ // Selecting no columns is a no-op and returns the DataFrame unchanged.
+ if len(columns) == 0 {
+ return df, nil
+ }
exprs := make([]*proto.Expression, 0, len(columns))
for _, c := range columns {
expr, err := c.ToProto(ctx)
@@ -624,7 +685,7 @@ func (df *dataFrameImpl) Select(ctx context.Context, columns ...column.Convertib
// can be performed on them. See GroupedData for all the available aggregate functions.
func (df *dataFrameImpl) GroupBy(cols ...column.Convertible) *GroupedData {
return &GroupedData{
- df: *df,
+ df: df,
groupingCols: cols,
groupType: "groupby",
}
@@ -807,7 +868,7 @@ func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]Row, error) {
return data.Collect(ctx)
}
-func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) ([]Row, error) {
+func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) DataFrame {
rel := &proto.Relation{
Common: &proto.RelationCommon{
PlanId: newPlanId(),
@@ -819,16 +880,15 @@ func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) ([]Row, error)
},
},
}
- data := NewDataFrame(df.session, rel)
- return data.Collect(ctx)
+ return NewDataFrame(df.session, rel)
}
func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error) {
- return df.Limit(ctx, limit)
+ return df.Limit(ctx, limit).Collect(ctx)
}
func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error) {
- return df.Limit(ctx, limit)
+ return df.Limit(ctx, limit).Collect(ctx)
}
func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) {
@@ -994,7 +1054,12 @@ func (df *dataFrameImpl) Sort(ctx context.Context, columns ...column.Convertible
if err != nil {
return nil, err
}
- sortExprs = append(sortExprs, expr.GetSortOrder())
+ so := expr.GetSortOrder()
+ if so == nil {
+ return nil, sparkerrors.WithType(fmt.Errorf(
+ "sort expression must not be nil"), sparkerrors.InvalidArgumentError)
+ }
+ sortExprs = append(sortExprs, so)
}
rel := &proto.Relation{
@@ -1106,3 +1171,210 @@ func (df *dataFrameImpl) SemanticHash(ctx context.Context) (int32, error) {
}
return df.session.client.SemanticHash(ctx, plan)
}
+
+func (df *dataFrameImpl) Join(ctx context.Context, other DataFrame, onExpr column.Convertible, joinType utils.JoinType) (DataFrame, error) {
+ otherDf := other.(*dataFrameImpl)
+ onExpression, err := onExpr.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+ joinTypeProto := utils.ToProtoJoinType(joinType)
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Join{
+ Join: &proto.Join{
+ Left: df.relation,
+ Right: otherDf.relation,
+ JoinType: joinTypeProto,
+ JoinCondition: onExpression,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel), nil
+}
+
+func (df *dataFrameImpl) CrossTab(ctx context.Context, col1, col2 string) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+
+ RelType: &proto.Relation_Crosstab{
+ Crosstab: &proto.StatCrosstab{
+ Input: df.relation,
+ Col1: col1,
+ Col2: col2,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Cube(ctx context.Context, cols ...column.Convertible) *GroupedData {
+ return &GroupedData{
+ df: df,
+ groupingCols: cols,
+ groupType: "cube",
+ }
+}
+
+func (df *dataFrameImpl) Rollup(ctx context.Context, cols ...column.Convertible) *GroupedData {
+ return &GroupedData{
+ df: df,
+ groupingCols: cols,
+ groupType: "rollup",
+ }
+}
+
+func (df *dataFrameImpl) Describe(ctx context.Context, cols ...string) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+
+ RelType: &proto.Relation_Describe{
+ Describe: &proto.StatDescribe{
+ Input: df.relation,
+ Cols: cols,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Distinct(ctx context.Context) DataFrame {
+ allColumnsAsKeys := true
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Deduplicate{
+ Deduplicate: &proto.Deduplicate{
+ Input: df.relation,
+ AllColumnsAsKeys: &allColumnsAsKeys,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) First(ctx context.Context) (Row, error) {
+ rows, err := df.Head(ctx, 1)
+ if err != nil {
+ return nil, err
+ }
+ // An empty DataFrame yields no rows; guard against indexing out of range.
+ if len(rows) == 0 {
+ return nil, nil
+ }
+ return rows[0], nil
+}
+
+func (df *dataFrameImpl) FreqItems(ctx context.Context, cols ...string) DataFrame {
+ return df.FreqItemsWithSupport(ctx, 0.01, cols...)
+}
+
+func (df *dataFrameImpl) FreqItemsWithSupport(ctx context.Context, support float64, cols ...string) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+
+ RelType: &proto.Relation_FreqItems{
+ FreqItems: &proto.StatFreqItems{
+ Input: df.relation,
+ Cols: cols,
+ Support: &support,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) IsEmpty(ctx context.Context) (bool, error) {
+ d, err := df.Select(ctx)
+ if err != nil {
+ return false, err
+ }
+ rows, err := d.Take(ctx, int32(1))
+ if err != nil {
+ return false, err
+ }
+ return len(rows) == 0, nil
+}
+
+func (df *dataFrameImpl) Offset(ctx context.Context, offset int32) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+
+ RelType: &proto.Relation_Offset{
+ Offset: &proto.Offset{
+ Input: df.relation,
+ Offset: offset,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error) {
+ // Check that we don't have negative weights:
+ total := 0.0
+ for _, w := range weights {
+ if w < 0.0 {
+ return nil, sparkerrors.WithType(fmt.Errorf("weights
must not be negative"), sparkerrors.InvalidArgumentError)
+ }
+ total += w
+ }
+ seed := rand.Int64()
+ normalizedWeights := make([]float64, len(weights))
+ for i, w := range weights {
+ normalizedWeights[i] = w / total
+ }
+
+ // Calculate the cumulative sum of the weights:
+ cumulativeWeights := make([]float64, len(weights)+1)
+ cumulativeWeights[0] = 0.0
+ for i := 0; i < len(normalizedWeights); i++ {
+ cumulativeWeights[i+1] = cumulativeWeights[i] + normalizedWeights[i]
+ }
+
+ // Iterate over cumulative weights as the boundaries of the interval and create the dataframes:
+ dataFrames := make([]DataFrame, len(weights))
+ withReplacement := false
+ for i := 1; i < len(cumulativeWeights); i++ {
+ sampleRelation := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Sample{
+ Sample: &proto.Sample{
+ Input: df.relation,
+ LowerBound: cumulativeWeights[i-1],
+ UpperBound: cumulativeWeights[i],
+ WithReplacement: &withReplacement,
+ Seed: &seed,
+ DeterministicOrder: true,
+ },
+ },
+ }
+ dataFrames[i-1] = NewDataFrame(df.session, sampleRelation)
+ }
+ return dataFrames, nil
+}
+
+func (df *dataFrameImpl) Summary(ctx context.Context, statistics ...string) DataFrame {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+
+ RelType: &proto.Relation_Summary{
+ Summary: &proto.StatSummary{
+ Input: df.relation,
+ Statistics: statistics,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
diff --git a/spark/sql/group.go b/spark/sql/group.go
index 4734225..3f67e18 100644
--- a/spark/sql/group.go
+++ b/spark/sql/group.go
@@ -26,11 +26,11 @@ import (
)
type GroupedData struct {
- df dataFrameImpl
+ df *dataFrameImpl
groupType string
groupingCols []column.Convertible
pivotValues []any
- // groupingSets [][]column.Column
+ pivotCol column.Convertible
}
// Agg computes aggregates and returns the result as a DataFrame. The aggregate expressions
@@ -71,8 +71,14 @@ func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Column) (DataFra
if len(gd.pivotValues) == 0 {
return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivotValues should not be empty")
}
+ protoCol, err := gd.pivotCol.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+
agg.Pivot = &proto.Aggregate_Pivot{
Values: make([]*proto.Expression_Literal, len(gd.pivotValues)),
+ Col: protoCol,
}
for i, v := range gd.pivotValues {
exp, err := column.NewLiteral(v).ToProto(ctx)
@@ -172,3 +178,19 @@ func (gd *GroupedData) Count(ctx context.Context) (DataFrame, error) {
func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, error) {
return gd.Avg(ctx, cols...)
}
+
+func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []any) (*GroupedData, error) {
+ if gd.groupType != "groupby" {
+ if gd.groupType == "pivot" {
+ return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot cannot be applied on pivot")
+ }
+ return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot can only be applied on groupby")
+ }
+ return &GroupedData{
+ df: gd.df,
+ groupType: "pivot",
+ groupingCols: gd.groupingCols,
+ pivotValues: pivotValues,
+ pivotCol: column.NewColumnReferenceWithPlanId(pivotCol, gd.df.PlanId()),
+ }, nil
+}
diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go
index 8cdd908..0fedaf1 100644
--- a/spark/sql/group_test.go
+++ b/spark/sql/group_test.go
@@ -26,7 +26,7 @@ import (
"github.com/stretchr/testify/assert"
)
-var sampleDataFrame = dataFrameImpl{session: nil, relation: &proto.Relation{
+var sampleDataFrame = &dataFrameImpl{session: nil, relation: &proto.Relation{
RelType: &proto.Relation_Range{
Range: &proto.Range{
End: 10,
diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go
index d3f287f..e070aa6 100644
--- a/spark/sql/utils/consts.go
+++ b/spark/sql/utils/consts.go
@@ -96,3 +96,36 @@ func FromProtoStorageLevel(level *proto.StorageLevel) StorageLevel {
}
return StorageLevelNone
}
+
+type JoinType int
+
+const (
+ JoinTypeInner JoinType = iota
+ JoinTypeLeftOuter
+ JoinTypeRightOuter
+ JoinTypeFullOuter
+ JoinTypeLeftSemi
+ JoinTypeLeftAnti
+ JoinTypeCross
+)
+
+func ToProtoJoinType(joinType JoinType) proto.Join_JoinType {
+ switch joinType {
+ case JoinTypeInner:
+ return proto.Join_JOIN_TYPE_INNER
+ case JoinTypeLeftOuter:
+ return proto.Join_JOIN_TYPE_LEFT_OUTER
+ case JoinTypeRightOuter:
+ return proto.Join_JOIN_TYPE_RIGHT_OUTER
+ case JoinTypeFullOuter:
+ return proto.Join_JOIN_TYPE_FULL_OUTER
+ case JoinTypeLeftSemi:
+ return proto.Join_JOIN_TYPE_LEFT_SEMI
+ case JoinTypeLeftAnti:
+ return proto.Join_JOIN_TYPE_LEFT_ANTI
+ case JoinTypeCross:
+ return proto.Join_JOIN_TYPE_CROSS
+ default:
+ return proto.Join_JOIN_TYPE_INNER
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]