This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 2774c11  [#58] More DataFrame Features
2774c11 is described below

commit 2774c1128fe12bbb18af02260ca90275c174bd39
Author: Martin Grund <[email protected]>
AuthorDate: Thu Sep 26 09:05:58 2024 +0200

    [#58] More DataFrame Features
    
    ### What changes were proposed in this pull request?
    This patch adds support for the following DataFrame operations:
    
    * `ExceptAll`
    * `Head`
    * `Intersect`
    * `IntersectAll`
    * `Limit`
    * `Subtract`
    * `Tail`
    * `Take`
    * `ToArrow`
    * `Union`
    * `UnionAll`
    * `UnionByName`
    * `UnionByNameWithMissingColumns`
    * `Explain`
    * `StorageLevel`
    * `Persist`
    * `Unpersist`
    * `Cache`
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    Adding the above new features.
    
    ### How was this patch tested?
    Added E2E tests.
    
    Closes #70 from grundprinzip/next_batch_v2.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go | 120 ++++++++-
 spark/client/base/base.go                    |  10 +
 spark/client/client.go                       | 202 +++++++++++++++
 spark/mocks/mock_executor.go                 |  37 +++
 spark/sparkerrors/errors.go                  |   1 +
 spark/sql/dataframe.go                       | 365 +++++++++++++++++++++++++++
 spark/sql/utils/consts.go                    |  98 +++++++
 spark/sql/utils/consts_test.go               |  41 +++
 8 files changed, 872 insertions(+), 2 deletions(-)

diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index a61d9ee..ec4469b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -19,6 +19,8 @@ import (
        "context"
        "testing"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
        "github.com/apache/spark-connect-go/v35/spark/sql/types"
 
        "github.com/apache/spark-connect-go/v35/spark/sql/column"
@@ -248,8 +250,8 @@ func TestDataFrame_WithColumns(t *testing.T) {
                vals, err := row.Values()
                assert.NoError(t, err)
                assert.Equal(t, 3, len(vals))
-               assert.Equal(t, int64(1), vals[1])
-               assert.Equal(t, int64(2), vals[2])
+               assert.Equal(t, int64(1), vals[1], "%v", vals)
+               assert.Equal(t, int64(2), vals[2], "%v", vals)
        }
 }
 
@@ -409,3 +411,117 @@ func TestDataFrame_DropDuplicates(t *testing.T) {
        assert.NoError(t, err)
        assert.Equal(t, int64(1), res)
 }
+
+func TestDataFrame_Explain(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       res, err := df.Explain(ctx, utils.ExplainModeSimple)
+       assert.NoError(t, err)
+       assert.Contains(t, res, "Physical Plan")
+
+       res, err = df.Explain(ctx, utils.ExplainModeExtended)
+       assert.NoError(t, err)
+       assert.Contains(t, res, "Physical Plan")
+
+       res, err = df.Explain(ctx, utils.ExplainModeCodegen)
+       assert.NoError(t, err)
+       assert.Contains(t, res, "WholeStageCodegen")
+
+       res, err = df.Explain(ctx, utils.ExplainModeCost)
+       assert.NoError(t, err)
+       assert.Contains(t, res, "Physical Plan")
+
+       res, err = df.Explain(ctx, utils.ExplainModeFormatted)
+       assert.NoError(t, err)
+       assert.Contains(t, res, "Physical Plan")
+}
+
+func TestDataFrame_CachingAndPersistence(t *testing.T) {
+       ctx, spark := connect()
+       levels := []utils.StorageLevel{
+               utils.StorageLevelDiskOnly,
+               utils.StorageLevelDiskOnly2,
+               utils.StorageLevelDiskOnly3,
+               utils.StorageLevelMemoryAndDisk,
+               utils.StorageLevelMemoryAndDisk2,
+               utils.StorageLevelMemoryOnly,
+               utils.StorageLevelMemoryOnly2,
+               utils.StorageLevelMemoyAndDiskDeser,
+               utils.StorageLevelOffHeap,
+       }
+
+       for _, lvl := range levels {
+               df, err := spark.Sql(ctx, "select * from range(10)")
+               assert.NoError(t, err)
+               err = df.Persist(ctx, lvl)
+               assert.NoError(t, err)
+               l, err := df.GetStorageLevel(ctx)
+               assert.NoError(t, err)
+
+               assert.Contains(t, []utils.StorageLevel{lvl, 
utils.StorageLevelMemoryOnly}, *l)
+
+               err = df.Unpersist(ctx)
+               assert.NoError(t, err)
+       }
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       err = df.Cache(ctx)
+       assert.NoError(t, err)
+       l, err := df.GetStorageLevel(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, utils.StorageLevelMemoryOnly, *l, "%v != %v", 
utils.StorageLevelMemoryOnly, *l)
+}
+
+func TestDataFrame_SetOps(t *testing.T) {
+       ctx, spark := connect()
+       df1, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       df2, err := spark.Sql(ctx, "select * from range(5)")
+       assert.NoError(t, err)
+
+       df := df1.Union(ctx, df2)
+       res, err := df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 15, len(res))
+
+       df = df1.Intersect(ctx, df2)
+       res, err = df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 5, len(res))
+
+       df = df1.ExceptAll(ctx, df2)
+       res, err = df.Collect(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, 5, len(res))
+}
+
+func TestDataFrame_ToArrow(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       tbl, err := df.ToArrow(ctx)
+       assert.NoError(t, err)
+       assert.NotNil(t, tbl)
+}
+
+func TestDataFrame_LimitVersions(t *testing.T) {
+       ctx, spark := connect()
+       df, err := spark.Sql(ctx, "select * from range(10)")
+       assert.NoError(t, err)
+       rows, err := df.Limit(ctx, int32(5))
+       assert.NoError(t, err)
+       assert.Len(t, rows, 5)
+
+       rows, err = df.Tail(ctx, int32(3))
+       assert.NoError(t, err)
+       assert.Len(t, rows, 3)
+
+       rows, err = df.Head(ctx, int32(3))
+       assert.NoError(t, err)
+       assert.Len(t, rows, 3)
+
+       rows, err = df.Take(ctx, int32(3))
+       assert.NoError(t, err)
+       assert.Len(t, rows, 3)
+}
diff --git a/spark/client/base/base.go b/spark/client/base/base.go
index 2d4080a..25dce70 100644
--- a/spark/client/base/base.go
+++ b/spark/client/base/base.go
@@ -18,6 +18,8 @@ package base
 import (
        "context"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
        "github.com/apache/arrow/go/v17/arrow"
        "github.com/apache/spark-connect-go/v35/internal/generated"
        "github.com/apache/spark-connect-go/v35/spark/sql/types"
@@ -33,6 +35,14 @@ type SparkConnectClient interface {
        ExecutePlan(ctx context.Context, plan *generated.Plan) 
(ExecuteResponseStream, error)
        ExecuteCommand(ctx context.Context, plan *generated.Plan) (arrow.Table, 
*types.StructType, map[string]any, error)
        AnalyzePlan(ctx context.Context, plan *generated.Plan) 
(*generated.AnalyzePlanResponse, error)
+       Explain(ctx context.Context, plan *generated.Plan, explainMode 
utils.ExplainMode) (*generated.AnalyzePlanResponse, error)
+       Persist(ctx context.Context, plan *generated.Plan, storageLevel 
utils.StorageLevel) error
+       Unpersist(ctx context.Context, plan *generated.Plan) error
+       GetStorageLevel(ctx context.Context, plan *generated.Plan) 
(*utils.StorageLevel, error)
+       SparkVersion(ctx context.Context) (string, error)
+       DDLParse(ctx context.Context, sql string) (*types.StructType, error)
+       SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 
*generated.Plan) (bool, error)
+       SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
 }
 
 type ExecuteResponseStream interface {
diff --git a/spark/client/client.go b/spark/client/client.go
index 75f3c80..d1834c8 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -21,6 +21,8 @@ import (
        "fmt"
        "io"
 
+       "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
        "google.golang.org/grpc"
        "google.golang.org/grpc/metadata"
 
@@ -129,6 +131,206 @@ func (s *sparkConnectClientImpl) AnalyzePlan(ctx 
context.Context, plan *proto.Pl
        return response, nil
 }
 
+func (s *sparkConnectClientImpl) Explain(ctx context.Context, plan *proto.Plan,
+       explainMode utils.ExplainMode,
+) (*proto.AnalyzePlanResponse, error) {
+       var mode proto.AnalyzePlanRequest_Explain_ExplainMode
+       if explainMode == utils.ExplainModeExtended {
+               mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_EXTENDED
+       } else if explainMode == utils.ExplainModeSimple {
+               mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_SIMPLE
+       } else if explainMode == utils.ExplainModeCost {
+               mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_COST
+       } else if explainMode == utils.ExplainModeFormatted {
+               mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_FORMATTED
+       } else if explainMode == utils.ExplainModeCodegen {
+               mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_CODEGEN
+       } else {
+               return nil, sparkerrors.WithType(fmt.Errorf("unsupported 
explain mode %v",
+                       explainMode), sparkerrors.InvalidArgumentError)
+       }
+
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_Explain_{
+                       Explain: &proto.AnalyzePlanRequest_Explain{
+                               Plan:        plan,
+                               ExplainMode: mode,
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return response, nil
+}
+
+func (s *sparkConnectClientImpl) Persist(ctx context.Context, plan 
*proto.Plan, storageLevel utils.StorageLevel) error {
+       protoLevel := utils.ToProtoStorageLevel(storageLevel)
+
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_Persist_{
+                       Persist: &proto.AnalyzePlanRequest_Persist{
+                               Relation:     plan.GetRoot(),
+                               StorageLevel: protoLevel,
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       _, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return nil
+}
+
+func (s *sparkConnectClientImpl) Unpersist(ctx context.Context, plan 
*proto.Plan) error {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_Unpersist_{
+                       Unpersist: &proto.AnalyzePlanRequest_Unpersist{
+                               Relation: plan.GetRoot(),
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       _, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return nil
+}
+
+func (s *sparkConnectClientImpl) GetStorageLevel(ctx context.Context, plan 
*proto.Plan) (*utils.StorageLevel, error) {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_GetStorageLevel_{
+                       GetStorageLevel: 
&proto.AnalyzePlanRequest_GetStorageLevel{
+                               Relation: plan.GetRoot(),
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+
+       level := response.GetGetStorageLevel().StorageLevel
+       res := utils.FromProtoStorageLevel(level)
+       return &res, nil
+}
+
+func (s *sparkConnectClientImpl) SparkVersion(ctx context.Context) (string, 
error) {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_SparkVersion_{
+                       SparkVersion: &proto.AnalyzePlanRequest_SparkVersion{},
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return "", sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return response.GetSparkVersion().Version, nil
+}
+
+func (s *sparkConnectClientImpl) DDLParse(ctx context.Context, sql string) 
(*types.StructType, error) {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_DdlParse{
+                       DdlParse: &proto.AnalyzePlanRequest_DDLParse{
+                               DdlString: sql,
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return 
types.ConvertProtoDataTypeToStructType(response.GetDdlParse().Parsed)
+}
+
+func (s *sparkConnectClientImpl) SameSemantics(ctx context.Context, plan1 
*proto.Plan, plan2 *proto.Plan) (bool, error) {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_SameSemantics_{
+                       SameSemantics: &proto.AnalyzePlanRequest_SameSemantics{
+                               TargetPlan: plan1,
+                               OtherPlan:  plan2,
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return false, sparkerrors.WithType(se, 
sparkerrors.ExecutionError)
+       }
+       return response.GetSameSemantics().GetResult(), nil
+}
+
+func (s *sparkConnectClientImpl) SemanticHash(ctx context.Context, plan 
*proto.Plan) (int32, error) {
+       request := proto.AnalyzePlanRequest{
+               SessionId: s.sessionId,
+               Analyze: &proto.AnalyzePlanRequest_SemanticHash_{
+                       SemanticHash: &proto.AnalyzePlanRequest_SemanticHash{
+                               Plan: plan,
+                       },
+               },
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       // Append the other items to the request.
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+       response, err := s.client.AnalyzePlan(ctx, &request)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return 0, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+       }
+       return response.GetSemanticHash().GetResult(), nil
+}
+
 func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string, 
opts options.SparkClientOptions) base.SparkConnectClient {
        var client base.SparkConnectRPCClient
        if opts.ReattachExecution {
diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go
index b70c4fa..c42e3e6 100644
--- a/spark/mocks/mock_executor.go
+++ b/spark/mocks/mock_executor.go
@@ -17,6 +17,9 @@ package mocks
 
 import (
        "context"
+       "errors"
+
+       "github.com/apache/spark-connect-go/v35/spark/sql/utils"
 
        "github.com/apache/spark-connect-go/v35/spark/client/base"
 
@@ -42,9 +45,43 @@ func (t *TestExecutor) AnalyzePlan(ctx context.Context, plan 
*generated.Plan) (*
        return t.response, nil
 }
 
+func (t *TestExecutor) Explain(ctx context.Context, plan *generated.Plan,
+       explainMode utils.ExplainMode,
+) (*generated.AnalyzePlanResponse, error) {
+       return nil, errors.New("not implemented")
+}
+
 func (t *TestExecutor) ExecuteCommand(ctx context.Context, plan 
*generated.Plan) (arrow.Table, *types.StructType, map[string]interface{}, 
error) {
        if t.Err != nil {
                return nil, nil, nil, t.Err
        }
        return nil, nil, nil, nil
 }
+
+func (t *TestExecutor) Persist(ctx context.Context, plan *generated.Plan, 
storageLevel utils.StorageLevel) error {
+       return errors.New("not implemented")
+}
+
+func (t *TestExecutor) Unpersist(ctx context.Context, plan *generated.Plan) 
error {
+       return errors.New("not implemented")
+}
+
+func (t *TestExecutor) GetStorageLevel(ctx context.Context, plan 
*generated.Plan) (*utils.StorageLevel, error) {
+       return nil, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SparkVersion(ctx context.Context) (string, error) {
+       return "", errors.New("not implemented")
+}
+
+func (t *TestExecutor) DDLParse(ctx context.Context, sql string) 
(*types.StructType, error) {
+       return nil, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SameSemantics(ctx context.Context, plan1 
*generated.Plan, plan2 *generated.Plan) (bool, error) {
+       return false, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) 
(int32, error) {
+       return 0, errors.New("not implemented")
+}
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index 21616e1..4603c4c 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -66,6 +66,7 @@ var (
        TestSetupError                = errorType(errors.New("test setup 
error"))
        WriteError                    = errorType(errors.New("write error"))
        NotImplementedError           = errorType(errors.New("not implemented"))
+       InvalidArgumentError          = errorType(errors.New("invalid 
argument"))
 )
 
 // Format formats the error, supporting both short forms (v, s, q) and verbose 
form (+v)
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 4f108c0..bcd0633 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -20,6 +20,9 @@ import (
        "context"
        "fmt"
 
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
        "github.com/apache/spark-connect-go/v35/spark/sql/column"
 
        "github.com/apache/spark-connect-go/v35/spark/sql/types"
@@ -40,6 +43,8 @@ type DataFrame interface {
        PlanId() int64
        // Alias creates a new DataFrame with the specified subquery alias
        Alias(ctx context.Context, alias string) DataFrame
+       // Cache persists the DataFrame with the default storage level.
+       Cache(ctx context.Context) error
        // Coalesce returns a new DataFrame that has exactly numPartitions 
partitions.DataFrame
        //
        // Similar to coalesce defined on an :class:`RDD`, this operation 
results in a
@@ -72,20 +77,39 @@ type DataFrame interface {
        CreateTempView(ctx context.Context, viewName string, replace, global 
bool) error
        // CrossJoin joins the current DataFrame with another DataFrame using 
the cross product
        CrossJoin(ctx context.Context, other DataFrame) DataFrame
+       // Drop returns a new DataFrame that drops the specified list of 
columns.
        Drop(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
+       // DropByName returns a new DataFrame that drops the specified list of 
columns by name.
        DropByName(ctx context.Context, columns ...string) (DataFrame, error)
+       // DropDuplicates returns a new DataFrame that contains only the unique 
rows from this DataFrame.
        DropDuplicates(ctx context.Context, columns ...string) (DataFrame, 
error)
+       // ExceptAll is similar to Subtract but does not perform the distinct 
operation.
+       ExceptAll(ctx context.Context, other DataFrame) DataFrame
+       Explain(ctx context.Context, explainMode utils.ExplainMode) (string, 
error)
        // Filter filters the data frame by a column condition.
        Filter(ctx context.Context, condition column.Convertible) (DataFrame, 
error)
        // FilterByString filters the data frame by a string condition.
        FilterByString(ctx context.Context, condition string) (DataFrame, error)
+       // GetStorageLevel returns the storage level of the data frame.
+       GetStorageLevel(ctx context.Context) (*utils.StorageLevel, error)
        // GroupBy groups the DataFrame by the spcified columns so that the 
aggregation
        // can be performed on them. See GroupedData for all the available 
aggregate functions.
        GroupBy(cols ...column.Convertible) *GroupedData
+       // Head is an alias for Limit
+       Head(ctx context.Context, limit int32) ([]Row, error)
+       // Intersect performs the set intersection of two data frames and only 
returns distinct rows.
+       Intersect(ctx context.Context, other DataFrame) DataFrame
+       // IntersectAll performs the set intersection of two data frames and 
returns all rows.
+       IntersectAll(ctx context.Context, other DataFrame) DataFrame
+       // Limit returns the first `limit` rows as a list of Row.
+       Limit(ctx context.Context, limit int32) ([]Row, error)
+       Persist(ctx context.Context, storageLevel utils.StorageLevel) error
        // Repartition re-partitions a data frame.
        Repartition(ctx context.Context, numPartitions int, columns []string) 
(DataFrame, error)
        // RepartitionByRange re-partitions a data frame by range partition.
        RepartitionByRange(ctx context.Context, numPartitions int, columns 
...column.Convertible) (DataFrame, error)
+       // SameSemantics returns true if the other DataFrame has the same 
semantics.
+       SameSemantics(ctx context.Context, other DataFrame) (bool, error)
        // Show uses WriteResult to write the data frames to the console output.
        Show(ctx context.Context, numRows int, truncate bool) error
        // Schema returns the schema for the current data frame.
@@ -94,6 +118,31 @@ type DataFrame interface {
        Select(ctx context.Context, columns ...column.Convertible) (DataFrame, 
error)
        // SelectExpr projects a list of columns from the DataFrame by string 
expressions
        SelectExpr(ctx context.Context, exprs ...string) (DataFrame, error)
+       // SemanticHash returns the semantic hash of the data frame. The 
semantic hash can be used to
+       // understand if the semantic operations are similar.
+       SemanticHash(ctx context.Context) (int32, error)
+       // Subtract subtracts the other DataFrame from the current DataFrame 
and only returns
+       // distinct rows.
+       Subtract(ctx context.Context, other DataFrame) DataFrame
+       // Tail returns the last `limit` rows as a list of Row.
+       Tail(ctx context.Context, limit int32) ([]Row, error)
+       // Take is an alias for Limit
+       Take(ctx context.Context, limit int32) ([]Row, error)
+       // ToArrow returns the Arrow representation of the DataFrame.
+       ToArrow(ctx context.Context) (*arrow.Table, error)
+       // Union is an alias for UnionAll
+       Union(ctx context.Context, other DataFrame) DataFrame
+       // UnionAll returns a new DataFrame containing union of rows in this 
and another DataFrame.
+       UnionAll(ctx context.Context, other DataFrame) DataFrame
+       // UnionByName performs a SQL union operation on two dataframes but 
reorders the schema
+       // according to the matching columns. If columns are missing, it will 
throw an error.
+       UnionByName(ctx context.Context, other DataFrame) DataFrame
+       // UnionByNameWithMissingColumns performs a SQL union operation on two 
dataframes but reorders the schema
+       // according to the matching columns. Missing columns are supported.
+       UnionByNameWithMissingColumns(ctx context.Context, other DataFrame) 
DataFrame
+       // Unpersist resets the storage level for this data frame, and if 
necessary removes it
+       // from server-side caches.
+       Unpersist(ctx context.Context) error
        // WithColumn returns a new DataFrame by adding a column or replacing 
the
        // existing column that has the same name. The column expression must 
be an
        // expression over this DataFrame; attempting to add a column from some 
other
@@ -741,3 +790,319 @@ func (df *dataFrameImpl) DropDuplicates(ctx 
context.Context, columns ...string)
        }
        return NewDataFrame(df.session, rel), nil
 }
+
+// Tail wraps this DataFrame's relation in a Tail relation bounded by limit,
+// builds a new DataFrame on the same session from it and returns the rows
+// produced by Collect.
+func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]Row, error) 
{
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Tail{
+                       Tail: &proto.Tail{
+                               Input: df.relation,
+                               Limit: limit,
+                       },
+               },
+       }
+       data := NewDataFrame(df.session, rel)
+       return data.Collect(ctx)
+}
+
+// Limit wraps this DataFrame's relation in a Limit relation bounded by limit,
+// builds a new DataFrame on the same session from it and returns the rows
+// produced by Collect. Unlike the set operations in this file it executes
+// immediately instead of returning a DataFrame.
+func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) ([]Row, 
error) {
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Limit{
+                       Limit: &proto.Limit{
+                               Input: df.relation,
+                               Limit: limit,
+                       },
+               },
+       }
+       data := NewDataFrame(df.session, rel)
+       return data.Collect(ctx)
+}
+
+// Head delegates to Limit with the same limit and returns its result unchanged.
+func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error) 
{
+       return df.Limit(ctx, limit)
+}
+
+// Take delegates to Limit with the same limit and returns its result unchanged.
+func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error) 
{
+       return df.Limit(ctx, limit)
+}
+
+// ToArrow executes the plan returned by df.createPlan() through the session
+// client and converts the response with ToTable, returning a pointer to the
+// resulting Arrow table.
+//
+// A failure of ExecutePlan is wrapped and tagged as sparkerrors.ExecutionError;
+// a failure of ToTable is returned as-is.
+func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) {
+       responseClient, err := df.session.client.ExecutePlan(ctx, 
df.createPlan())
+       if err != nil {
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to execute 
plan: %w", err), sparkerrors.ExecutionError)
+       }
+
+       // The first value returned by ToTable is not needed here and is dropped.
+       _, table, err := responseClient.ToTable()
+       if err != nil {
+               return nil, err
+       }
+
+       return &table, nil
+}
+
+// UnionAll builds a SetOperation relation of type SET_OP_TYPE_UNION with IsAll
+// set to true, using this DataFrame as the left input and other as the right
+// input. No request is sent; the new DataFrame only carries the relation.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       isAll := true
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:  df.relation,
+                               RightInput: otherDf.relation,
+                               SetOpType:  
proto.SetOperation_SET_OP_TYPE_UNION,
+                               IsAll:      &isAll,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// Union delegates to UnionAll and returns its result unchanged.
+func (df *dataFrameImpl) Union(ctx context.Context, other DataFrame) DataFrame 
{
+       return df.UnionAll(ctx, other)
+}
+
+// UnionByName builds a SetOperation relation of type SET_OP_TYPE_UNION with
+// ByName set to true and AllowMissingColumns set to false. IsAll is left
+// unset. This DataFrame is the left input and other the right input.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) UnionByName(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       byName := true
+       allowMissingColumns := false
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:           df.relation,
+                               RightInput:          otherDf.relation,
+                               SetOpType:           
proto.SetOperation_SET_OP_TYPE_UNION,
+                               ByName:              &byName,
+                               AllowMissingColumns: &allowMissingColumns,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// UnionByNameWithMissingColumns builds the same SET_OP_TYPE_UNION relation as
+// UnionByName (ByName set to true, IsAll unset) but with AllowMissingColumns
+// set to true.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) UnionByNameWithMissingColumns(ctx context.Context, 
other DataFrame) DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       byName := true
+       allowMissingColumns := true
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:           df.relation,
+                               RightInput:          otherDf.relation,
+                               SetOpType:           
proto.SetOperation_SET_OP_TYPE_UNION,
+                               ByName:              &byName,
+                               AllowMissingColumns: &allowMissingColumns,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// ExceptAll builds a SetOperation relation of type SET_OP_TYPE_EXCEPT with
+// IsAll set to true, using this DataFrame as the left input and other as the
+// right input.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) ExceptAll(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       isAll := true
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:  df.relation,
+                               RightInput: otherDf.relation,
+                               SetOpType:  
proto.SetOperation_SET_OP_TYPE_EXCEPT,
+                               IsAll:      &isAll,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// Subtract builds a SetOperation relation of type SET_OP_TYPE_EXCEPT with
+// IsAll set to false, using this DataFrame as the left input and other as the
+// right input. It differs from ExceptAll only in the IsAll flag.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) Subtract(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       isAll := false
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:  df.relation,
+                               RightInput: otherDf.relation,
+                               SetOpType:  
proto.SetOperation_SET_OP_TYPE_EXCEPT,
+                               IsAll:      &isAll,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// Intersect builds a SetOperation relation of type SET_OP_TYPE_INTERSECT with
+// IsAll set to false, using this DataFrame as the left input and other as the
+// right input.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) Intersect(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       isAll := false
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:  df.relation,
+                               RightInput: otherDf.relation,
+                               SetOpType:  
proto.SetOperation_SET_OP_TYPE_INTERSECT,
+                               IsAll:      &isAll,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// IntersectAll builds a SetOperation relation of type SET_OP_TYPE_INTERSECT
+// with IsAll set to true, using this DataFrame as the left input and other as
+// the right input. It differs from Intersect only in the IsAll flag.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) IntersectAll(ctx context.Context, other DataFrame) 
DataFrame {
+       otherDf := other.(*dataFrameImpl)
+       isAll := true
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_SetOp{
+                       SetOp: &proto.SetOperation{
+                               LeftInput:  df.relation,
+                               RightInput: otherDf.relation,
+                               SetOpType:  
proto.SetOperation_SET_OP_TYPE_INTERSECT,
+                               IsAll:      &isAll,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel)
+}
+
+// Sort returns a new DataFrame whose relation is a Sort over this DataFrame's
+// relation with IsGlobal set to true. Each column is converted with ToProto
+// and its sort order is used as one entry of the ordering.
+//
+// An error is returned if a column fails to convert, or if the converted
+// expression carries no sort order. Such a column would otherwise contribute
+// a nil entry to the ordering sent to the server.
+func (df *dataFrameImpl) Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+       globalSort := true
+       sortExprs := make([]*proto.Expression_SortOrder, 0, len(columns))
+       for i, c := range columns {
+               expr, err := c.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               // GetSortOrder yields nil when the expression is not a sort
+               // order; reject it instead of appending a nil element.
+               order := expr.GetSortOrder()
+               if order == nil {
+                       return nil, fmt.Errorf("sort column %d is not a sort order expression", i)
+               }
+               sortExprs = append(sortExprs, order)
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Sort{
+                       Sort: &proto.Sort{
+                               Input:    df.relation,
+                               Order:    sortExprs,
+                               IsGlobal: &globalSort,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
+// SortWithinPartitions returns a new DataFrame whose relation is a Sort over
+// this DataFrame's relation with IsGlobal set to false. Each column is
+// converted with ToProto and its sort order is used as one entry of the
+// ordering.
+//
+// An error is returned if a column fails to convert, or if the converted
+// expression carries no sort order. Such a column would otherwise contribute
+// a nil entry to the ordering sent to the server.
+func (df *dataFrameImpl) SortWithinPartitions(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+       globalSort := false
+       sortExprs := make([]*proto.Expression_SortOrder, 0, len(columns))
+       for i, c := range columns {
+               expr, err := c.ToProto(ctx)
+               if err != nil {
+                       return nil, err
+               }
+               // GetSortOrder yields nil when the expression is not a sort
+               // order; reject it instead of appending a nil element.
+               order := expr.GetSortOrder()
+               if order == nil {
+                       return nil, fmt.Errorf("sort column %d is not a sort order expression", i)
+               }
+               sortExprs = append(sortExprs, order)
+       }
+
+       rel := &proto.Relation{
+               Common: &proto.RelationCommon{
+                       PlanId: newPlanId(),
+               },
+               RelType: &proto.Relation_Sort{
+                       Sort: &proto.Sort{
+                               Input:    df.relation,
+                               Order:    sortExprs,
+                               IsGlobal: &globalSort,
+                       },
+               },
+       }
+       return NewDataFrame(df.session, rel), nil
+}
+
+// OrderBy delegates to Sort with the same columns and returns its result
+// unchanged.
+func (df *dataFrameImpl) OrderBy(ctx context.Context, columns 
...column.Convertible) (DataFrame, error) {
+       return df.Sort(ctx, columns...)
+}
+
+// Explain sends the plan returned by df.createPlan() together with explainMode
+// to the session client's Explain call and returns the explain string carried
+// by the response. A client error is wrapped and tagged as
+// sparkerrors.ExecutionError.
+func (df *dataFrameImpl) Explain(ctx context.Context, explainMode 
utils.ExplainMode) (string, error) {
+       plan := df.createPlan()
+
+       responseClient, err := df.session.client.Explain(ctx, plan, explainMode)
+       if err != nil {
+               return "", sparkerrors.WithType(fmt.Errorf("failed to execute 
plan: %w", err), sparkerrors.ExecutionError)
+       }
+       return responseClient.GetExplain().GetExplainString(), nil
+}
+
+// Persist wraps this DataFrame's relation as the root of a plan and passes it,
+// together with storageLevel, to the session client's Persist call. The
+// client's error is returned unchanged.
+func (df *dataFrameImpl) Persist(ctx context.Context, storageLevel 
utils.StorageLevel) error {
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: df.relation,
+               },
+       }
+       return df.session.client.Persist(ctx, plan, storageLevel)
+}
+
+// Cache calls Persist with utils.StorageLevelMemoryOnly.
+//
+// NOTE(review): Spark's own Dataset.cache() is documented to default to a
+// memory-and-disk level; confirm that memory-only is the intended default.
+func (df *dataFrameImpl) Cache(ctx context.Context) error {
+       return df.Persist(ctx, utils.StorageLevelMemoryOnly)
+}
+
+// Unpersist wraps this DataFrame's relation as the root of a plan and passes
+// it to the session client's Unpersist call. The client's error is returned
+// unchanged.
+func (df *dataFrameImpl) Unpersist(ctx context.Context) error {
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: df.relation,
+               },
+       }
+       return df.session.client.Unpersist(ctx, plan)
+}
+
+// GetStorageLevel wraps this DataFrame's relation as the root of a plan and
+// returns whatever the session client's GetStorageLevel call reports for it.
+func (df *dataFrameImpl) GetStorageLevel(ctx context.Context) 
(*utils.StorageLevel, error) {
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: df.relation,
+               },
+       }
+       return df.session.client.GetStorageLevel(ctx, plan)
+}
+
+// SameSemantics wraps this DataFrame's relation and other's relation as the
+// roots of two plans and returns the result of the session client's
+// SameSemantics call on them.
+//
+// other must be a *dataFrameImpl: the unchecked type assertion panics for any
+// other DataFrame implementation.
+func (df *dataFrameImpl) SameSemantics(ctx context.Context, other DataFrame) 
(bool, error) {
+       otherDf := other.(*dataFrameImpl)
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: df.relation,
+               },
+       }
+       otherPlan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: otherDf.relation,
+               },
+       }
+       return df.session.client.SameSemantics(ctx, plan, otherPlan)
+}
+
+// SemanticHash wraps this DataFrame's relation as the root of a plan and
+// returns the result of the session client's SemanticHash call on it.
+func (df *dataFrameImpl) SemanticHash(ctx context.Context) (int32, error) {
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Root{
+                       Root: df.relation,
+               },
+       }
+       return df.session.client.SemanticHash(ctx, plan)
+}
diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go
new file mode 100644
index 0000000..d3f287f
--- /dev/null
+++ b/spark/sql/utils/consts.go
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import proto "github.com/apache/spark-connect-go/v35/internal/generated"
+
+// ExplainMode identifies one of the explain modes declared below.
+type ExplainMode int
+
+// Explain modes. Each constant repeats "= iota", so the values run from 0
+// (ExplainModeSimple) to 4 (ExplainModeFormatted) in declaration order.
+const (
+       ExplainModeSimple    ExplainMode = iota
+       ExplainModeExtended  ExplainMode = iota
+       ExplainModeCodegen   ExplainMode = iota
+       ExplainModeCost      ExplainMode = iota
+       ExplainModeFormatted ExplainMode = iota
+)
+
+// StorageLevel identifies one of the named storage levels declared below.
+// ToProtoStorageLevel and FromProtoStorageLevel translate between these
+// constants and the protobuf representation.
+type StorageLevel int
+
+// Storage levels. Each constant repeats "= iota", so the values run from 0
+// (StorageLevelDiskOnly) to 9 (StorageLevelOffHeap) in declaration order.
+//
+// NOTE(review): "Memoy" in StorageLevelMemoyAndDiskDeser looks like a
+// misspelling of "Memory"; renaming it would change the exported API.
+const (
+       StorageLevelDiskOnly          StorageLevel = iota
+       StorageLevelDiskOnly2         StorageLevel = iota
+       StorageLevelDiskOnly3         StorageLevel = iota
+       StorageLevelMemoryAndDisk     StorageLevel = iota
+       StorageLevelMemoryAndDisk2    StorageLevel = iota
+       StorageLevelMemoryOnly        StorageLevel = iota
+       StorageLevelMemoryOnly2       StorageLevel = iota
+       StorageLevelMemoyAndDiskDeser StorageLevel = iota
+       StorageLevelNone              StorageLevel = iota
+       StorageLevelOffHeap           StorageLevel = iota
+)
+
+// ToProtoStorageLevel converts a StorageLevel constant into a freshly
+// allocated protobuf StorageLevel with the matching disk, memory, off-heap,
+// deserialized and replication settings.
+//
+// StorageLevelNone has no case of its own: it, like any unknown value, takes
+// the default branch, which disables disk, memory and off-heap and sets the
+// replication to 1.
+func ToProtoStorageLevel(level StorageLevel) *proto.StorageLevel {
+       switch level {
+       case StorageLevelDiskOnly:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: false, 
Replication: 1}
+       case StorageLevelDiskOnly2:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: false, 
Replication: 2}
+       case StorageLevelDiskOnly3:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: false, 
Replication: 3}
+       case StorageLevelMemoryAndDisk:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: true, 
Replication: 1}
+       case StorageLevelMemoryAndDisk2:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: true, 
Replication: 2}
+       case StorageLevelMemoryOnly:
+               return &proto.StorageLevel{UseDisk: false, UseMemory: true, 
Replication: 1}
+       case StorageLevelMemoryOnly2:
+               return &proto.StorageLevel{UseDisk: false, UseMemory: true, 
Replication: 2}
+       case StorageLevelMemoyAndDiskDeser:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: true, 
Replication: 1, Deserialized: true}
+       case StorageLevelOffHeap:
+               return &proto.StorageLevel{UseDisk: true, UseMemory: true, 
UseOffHeap: true, Replication: 1}
+       default:
+               return &proto.StorageLevel{UseDisk: false, UseMemory: false, 
UseOffHeap: false, Replication: 1}
+       }
+}
+
+// FromProtoStorageLevel maps a protobuf StorageLevel onto the matching
+// StorageLevel constant by inspecting its disk, memory, off-heap, deserialized
+// and replication settings. A replication of 0 is treated like 1.
+//
+// Combinations that match no named constant, as well as a nil level, yield
+// StorageLevelNone.
+func FromProtoStorageLevel(level *proto.StorageLevel) StorageLevel {
+       // The generated getters tolerate a nil receiver, so a missing level
+       // reads as all-zero fields instead of causing a nil pointer dereference.
+       disk, memory := level.GetUseDisk(), level.GetUseMemory()
+       offHeap, deser := level.GetUseOffHeap(), level.GetDeserialized()
+       replication := level.GetReplication()
+       // Every named level except the deserialized and off-heap ones requires
+       // both of those flags to be cleared.
+       plain := !deser && !offHeap
+
+       switch {
+       case disk && memory && replication <= 1 && plain:
+               return StorageLevelMemoryAndDisk
+       case disk && memory && replication == 2 && plain:
+               return StorageLevelMemoryAndDisk2
+       case disk && !memory && replication == 3 && plain:
+               return StorageLevelDiskOnly3
+       case disk && !memory && replication == 2 && plain:
+               return StorageLevelDiskOnly2
+       case disk && !memory && replication <= 1 && plain:
+               return StorageLevelDiskOnly
+       case !disk && memory && replication <= 1 && plain:
+               return StorageLevelMemoryOnly
+       case !disk && memory && replication == 2 && plain:
+               return StorageLevelMemoryOnly2
+       case disk && memory && replication <= 1 && deser && !offHeap:
+               return StorageLevelMemoyAndDiskDeser
+       case !disk && !memory && plain:
+               return StorageLevelNone
+       case offHeap && !deser:
+               return StorageLevelOffHeap
+       default:
+               return StorageLevelNone
+       }
+}
diff --git a/spark/sql/utils/consts_test.go b/spark/sql/utils/consts_test.go
new file mode 100644
index 0000000..fbb1fc7
--- /dev/null
+++ b/spark/sql/utils/consts_test.go
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import "testing"
+
+// TestStorageLevelConversion checks that every StorageLevel constant survives
+// a round trip through ToProtoStorageLevel and FromProtoStorageLevel.
+func TestStorageLevelConversion(t *testing.T) {
+       // Given a list of all storage levels, convert them to and from proto
+       // and compare the result with the original value.
+       for _, level := range []StorageLevel{
+               StorageLevelDiskOnly,
+               StorageLevelDiskOnly2,
+               StorageLevelDiskOnly3,
+               StorageLevelMemoryAndDisk,
+               StorageLevelMemoryAndDisk2,
+               StorageLevelMemoryOnly,
+               StorageLevelMemoryOnly2,
+               StorageLevelMemoyAndDiskDeser,
+               StorageLevelNone,
+               StorageLevelOffHeap,
+       } {
+               protoLevel := ToProtoStorageLevel(level)
+               convertedLevel := FromProtoStorageLevel(protoLevel)
+               if level != convertedLevel {
+                       t.Errorf("Expected %v, got %v", level, convertedLevel)
+               }
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to