This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 2774c11 [#58] More DataFrame Features
2774c11 is described below
commit 2774c1128fe12bbb18af02260ca90275c174bd39
Author: Martin Grund <[email protected]>
AuthorDate: Thu Sep 26 09:05:58 2024 +0200
[#58] More DataFrame Features
### What changes were proposed in this pull request?
This patch adds support for the following DataFrame operations (a short usage sketch follows the list):
* `ExceptAll`
* `Head`
* `Intersect`
* `IntersectAll`
* `Limit`
* `Subtract`
* `Tail`
* `Take`
* `ToArrow`
* `Union`
* `UnionAll`
* `UnionByName`
* `UnionByNameWithMissingColumns`
* `Explain`
* `StorageLevel`
* `Persist`
* `Unpersist`
* `Cache`
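
For illustration, a minimal usage sketch of the new API surface (not part of this patch; it assumes an already-connected session `spark` and a `ctx` context as in the E2E tests below, with error handling elided):

    df1, _ := spark.Sql(ctx, "select * from range(10)")
    df2, _ := spark.Sql(ctx, "select * from range(5)")

    rows, _ := df1.Union(ctx, df2).Collect(ctx)
    fmt.Println(len(rows)) // 15

    head, _ := df1.Head(ctx, 3) // alias for Limit; Take behaves the same
    fmt.Println(len(head)) // 3

    plan, _ := df1.Explain(ctx, utils.ExplainModeSimple)
    fmt.Println(plan) // contains "Physical Plan"

    _ = df1.Cache(ctx) // same as Persist with utils.StorageLevelMemoryOnly
    level, _ := df1.GetStorageLevel(ctx)
    fmt.Println(*level)
    _ = df1.Unpersist(ctx)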
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
Yes, it adds the new features listed above.
### How was this patch tested?
Added E2E tests.
Closes #70 from grundprinzip/next_batch_v2.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 120 ++++++++-
spark/client/base/base.go | 10 +
spark/client/client.go | 202 +++++++++++++++
spark/mocks/mock_executor.go | 37 +++
spark/sparkerrors/errors.go | 1 +
spark/sql/dataframe.go | 365 +++++++++++++++++++++++++++
spark/sql/utils/consts.go | 98 +++++++
spark/sql/utils/consts_test.go | 41 +++
8 files changed, 872 insertions(+), 2 deletions(-)
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index a61d9ee..ec4469b 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -19,6 +19,8 @@ import (
"context"
"testing"
+ "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
"github.com/apache/spark-connect-go/v35/spark/sql/types"
"github.com/apache/spark-connect-go/v35/spark/sql/column"
@@ -248,8 +250,8 @@ func TestDataFrame_WithColumns(t *testing.T) {
vals, err := row.Values()
assert.NoError(t, err)
assert.Equal(t, 3, len(vals))
- assert.Equal(t, int64(1), vals[1])
- assert.Equal(t, int64(2), vals[2])
+ assert.Equal(t, int64(1), vals[1], "%v", vals)
+ assert.Equal(t, int64(2), vals[2], "%v", vals)
}
}
@@ -409,3 +411,117 @@ func TestDataFrame_DropDuplicates(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, int64(1), res)
}
+
+func TestDataFrame_Explain(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ res, err := df.Explain(ctx, utils.ExplainModeSimple)
+ assert.NoError(t, err)
+ assert.Contains(t, res, "Physical Plan")
+
+ res, err = df.Explain(ctx, utils.ExplainModeExtended)
+ assert.NoError(t, err)
+ assert.Contains(t, res, "Physical Plan")
+
+ res, err = df.Explain(ctx, utils.ExplainModeCodegen)
+ assert.NoError(t, err)
+ assert.Contains(t, res, "WholeStageCodegen")
+
+ res, err = df.Explain(ctx, utils.ExplainModeCost)
+ assert.NoError(t, err)
+ assert.Contains(t, res, "Physical Plan")
+
+ res, err = df.Explain(ctx, utils.ExplainModeFormatted)
+ assert.NoError(t, err)
+ assert.Contains(t, res, "Physical Plan")
+}
+
+func TestDataFrame_CachingAndPersistence(t *testing.T) {
+ ctx, spark := connect()
+ levels := []utils.StorageLevel{
+ utils.StorageLevelDiskOnly,
+ utils.StorageLevelDiskOnly2,
+ utils.StorageLevelDiskOnly3,
+ utils.StorageLevelMemoryAndDisk,
+ utils.StorageLevelMemoryAndDisk2,
+ utils.StorageLevelMemoryOnly,
+ utils.StorageLevelMemoryOnly2,
+ utils.StorageLevelMemoyAndDiskDeser,
+ utils.StorageLevelOffHeap,
+ }
+
+ for _, lvl := range levels {
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ err = df.Persist(ctx, lvl)
+ assert.NoError(t, err)
+ l, err := df.GetStorageLevel(ctx)
+ assert.NoError(t, err)
+
+ assert.Contains(t, []utils.StorageLevel{lvl, utils.StorageLevelMemoryOnly}, *l)
+
+ err = df.Unpersist(ctx)
+ assert.NoError(t, err)
+ }
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ err = df.Cache(ctx)
+ assert.NoError(t, err)
+ l, err := df.GetStorageLevel(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, utils.StorageLevelMemoryOnly, *l, "%v != %v", utils.StorageLevelMemoryOnly, *l)
+}
+
+func TestDataFrame_SetOps(t *testing.T) {
+ ctx, spark := connect()
+ df1, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ df2, err := spark.Sql(ctx, "select * from range(5)")
+ assert.NoError(t, err)
+
+ df := df1.Union(ctx, df2)
+ res, err := df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, 15, len(res))
+
+ df = df1.Intersect(ctx, df2)
+ res, err = df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, 5, len(res))
+
+ df = df1.ExceptAll(ctx, df2)
+ res, err = df.Collect(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, 5, len(res))
+}
+
+func TestDataFrame_ToArrow(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ tbl, err := df.ToArrow(ctx)
+ assert.NoError(t, err)
+ assert.NotNil(t, tbl)
+}
+
+func TestDataFrame_LimitVersions(t *testing.T) {
+ ctx, spark := connect()
+ df, err := spark.Sql(ctx, "select * from range(10)")
+ assert.NoError(t, err)
+ rows, err := df.Limit(ctx, int32(5))
+ assert.NoError(t, err)
+ assert.Len(t, rows, 5)
+
+ rows, err = df.Tail(ctx, int32(3))
+ assert.NoError(t, err)
+ assert.Len(t, rows, 3)
+
+ rows, err = df.Head(ctx, int32(3))
+ assert.NoError(t, err)
+ assert.Len(t, rows, 3)
+
+ rows, err = df.Take(ctx, int32(3))
+ assert.NoError(t, err)
+ assert.Len(t, rows, 3)
+}
diff --git a/spark/client/base/base.go b/spark/client/base/base.go
index 2d4080a..25dce70 100644
--- a/spark/client/base/base.go
+++ b/spark/client/base/base.go
@@ -18,6 +18,8 @@ package base
import (
"context"
+ "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/spark-connect-go/v35/internal/generated"
"github.com/apache/spark-connect-go/v35/spark/sql/types"
@@ -33,6 +35,14 @@ type SparkConnectClient interface {
ExecutePlan(ctx context.Context, plan *generated.Plan) (ExecuteResponseStream, error)
ExecuteCommand(ctx context.Context, plan *generated.Plan) (arrow.Table, *types.StructType, map[string]any, error)
AnalyzePlan(ctx context.Context, plan *generated.Plan) (*generated.AnalyzePlanResponse, error)
+ Explain(ctx context.Context, plan *generated.Plan, explainMode utils.ExplainMode) (*generated.AnalyzePlanResponse, error)
+ Persist(ctx context.Context, plan *generated.Plan, storageLevel utils.StorageLevel) error
+ Unpersist(ctx context.Context, plan *generated.Plan) error
+ GetStorageLevel(ctx context.Context, plan *generated.Plan) (*utils.StorageLevel, error)
+ SparkVersion(ctx context.Context) (string, error)
+ DDLParse(ctx context.Context, sql string) (*types.StructType, error)
+ SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error)
+ SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
}
type ExecuteResponseStream interface {
diff --git a/spark/client/client.go b/spark/client/client.go
index 75f3c80..d1834c8 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -21,6 +21,8 @@ import (
"fmt"
"io"
+ "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@@ -129,6 +131,206 @@ func (s *sparkConnectClientImpl) AnalyzePlan(ctx context.Context, plan *proto.Pl
return response, nil
}
+func (s *sparkConnectClientImpl) Explain(ctx context.Context, plan *proto.Plan,
+ explainMode utils.ExplainMode,
+) (*proto.AnalyzePlanResponse, error) {
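+ // Map the client-side ExplainMode constant onto the Spark Connect proto enum; unknown modes are rejected below.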
+ var mode proto.AnalyzePlanRequest_Explain_ExplainMode
+ if explainMode == utils.ExplainModeExtended {
+ mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_EXTENDED
+ } else if explainMode == utils.ExplainModeSimple {
+ mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_SIMPLE
+ } else if explainMode == utils.ExplainModeCost {
+ mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_COST
+ } else if explainMode == utils.ExplainModeFormatted {
+ mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_FORMATTED
+ } else if explainMode == utils.ExplainModeCodegen {
+ mode = proto.AnalyzePlanRequest_Explain_EXPLAIN_MODE_CODEGEN
+ } else {
+ return nil, sparkerrors.WithType(fmt.Errorf("unsupported explain mode %v",
+ explainMode), sparkerrors.InvalidArgumentError)
+ }
+
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_Explain_{
+ Explain: &proto.AnalyzePlanRequest_Explain{
+ Plan: plan,
+ ExplainMode: mode,
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return response, nil
+}
+
+func (s *sparkConnectClientImpl) Persist(ctx context.Context, plan *proto.Plan, storageLevel utils.StorageLevel) error {
+ protoLevel := utils.ToProtoStorageLevel(storageLevel)
+
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_Persist_{
+ Persist: &proto.AnalyzePlanRequest_Persist{
+ Relation: plan.GetRoot(),
+ StorageLevel: protoLevel,
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ _, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return nil
+}
+
+func (s *sparkConnectClientImpl) Unpersist(ctx context.Context, plan *proto.Plan) error {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_Unpersist_{
+ Unpersist: &proto.AnalyzePlanRequest_Unpersist{
+ Relation: plan.GetRoot(),
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ _, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return nil
+}
+
+func (s *sparkConnectClientImpl) GetStorageLevel(ctx context.Context, plan *proto.Plan) (*utils.StorageLevel, error) {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_GetStorageLevel_{
+ GetStorageLevel: &proto.AnalyzePlanRequest_GetStorageLevel{
+ Relation: plan.GetRoot(),
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+
+ level := response.GetGetStorageLevel().StorageLevel
+ res := utils.FromProtoStorageLevel(level)
+ return &res, nil
+}
+
+func (s *sparkConnectClientImpl) SparkVersion(ctx context.Context) (string, error) {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_SparkVersion_{
+ SparkVersion: &proto.AnalyzePlanRequest_SparkVersion{},
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return "", sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return response.GetSparkVersion().Version, nil
+}
+
+func (s *sparkConnectClientImpl) DDLParse(ctx context.Context, sql string) (*types.StructType, error) {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_DdlParse{
+ DdlParse: &proto.AnalyzePlanRequest_DDLParse{
+ DdlString: sql,
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return types.ConvertProtoDataTypeToStructType(response.GetDdlParse().Parsed)
+}
+
+func (s *sparkConnectClientImpl) SameSemantics(ctx context.Context, plan1 *proto.Plan, plan2 *proto.Plan) (bool, error) {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_SameSemantics_{
+ SameSemantics: &proto.AnalyzePlanRequest_SameSemantics{
+ TargetPlan: plan1,
+ OtherPlan: plan2,
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return false, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return response.GetSameSemantics().GetResult(), nil
+}
+
+func (s *sparkConnectClientImpl) SemanticHash(ctx context.Context, plan *proto.Plan) (int32, error) {
+ request := proto.AnalyzePlanRequest{
+ SessionId: s.sessionId,
+ Analyze: &proto.AnalyzePlanRequest_SemanticHash_{
+ SemanticHash: &proto.AnalyzePlanRequest_SemanticHash{
+ Plan: plan,
+ },
+ },
+ UserContext: &proto.UserContext{
+ UserId: "na",
+ },
+ }
+ // Append the other items to the request.
+ ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+
+ response, err := s.client.AnalyzePlan(ctx, &request)
+ if se := sparkerrors.FromRPCError(err); se != nil {
+ return 0, sparkerrors.WithType(se, sparkerrors.ExecutionError)
+ }
+ return response.GetSemanticHash().GetResult(), nil
+}
+
func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string, opts options.SparkClientOptions) base.SparkConnectClient {
var client base.SparkConnectRPCClient
if opts.ReattachExecution {
diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go
index b70c4fa..c42e3e6 100644
--- a/spark/mocks/mock_executor.go
+++ b/spark/mocks/mock_executor.go
@@ -17,6 +17,9 @@ package mocks
import (
"context"
+ "errors"
+
+ "github.com/apache/spark-connect-go/v35/spark/sql/utils"
"github.com/apache/spark-connect-go/v35/spark/client/base"
@@ -42,9 +45,43 @@ func (t *TestExecutor) AnalyzePlan(ctx context.Context, plan *generated.Plan) (*
return t.response, nil
}
+func (t *TestExecutor) Explain(ctx context.Context, plan *generated.Plan,
+ explainMode utils.ExplainMode,
+) (*generated.AnalyzePlanResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
func (t *TestExecutor) ExecuteCommand(ctx context.Context, plan *generated.Plan) (arrow.Table, *types.StructType, map[string]interface{}, error) {
if t.Err != nil {
return nil, nil, nil, t.Err
}
return nil, nil, nil, nil
}
+
+func (t *TestExecutor) Persist(ctx context.Context, plan *generated.Plan, storageLevel utils.StorageLevel) error {
+ return errors.New("not implemented")
+}
+
+func (t *TestExecutor) Unpersist(ctx context.Context, plan *generated.Plan) error {
+ return errors.New("not implemented")
+}
+
+func (t *TestExecutor) GetStorageLevel(ctx context.Context, plan *generated.Plan) (*utils.StorageLevel, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SparkVersion(ctx context.Context) (string, error) {
+ return "", errors.New("not implemented")
+}
+
+func (t *TestExecutor) DDLParse(ctx context.Context, sql string) (*types.StructType, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 *generated.Plan) (bool, error) {
+ return false, errors.New("not implemented")
+}
+
+func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error) {
+ return 0, errors.New("not implemented")
+}
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index 21616e1..4603c4c 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -66,6 +66,7 @@ var (
TestSetupError = errorType(errors.New("test setup error"))
WriteError = errorType(errors.New("write error"))
NotImplementedError = errorType(errors.New("not implemented"))
+ InvalidArgumentError = errorType(errors.New("invalid argument"))
)
// Format formats the error, supporting both short forms (v, s, q) and verbose form (+v)
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 4f108c0..bcd0633 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -20,6 +20,9 @@ import (
"context"
"fmt"
+ "github.com/apache/arrow/go/v17/arrow"
+ "github.com/apache/spark-connect-go/v35/spark/sql/utils"
+
"github.com/apache/spark-connect-go/v35/spark/sql/column"
"github.com/apache/spark-connect-go/v35/spark/sql/types"
@@ -40,6 +43,8 @@ type DataFrame interface {
PlanId() int64
// Alias creates a new DataFrame with the specified subquery alias
Alias(ctx context.Context, alias string) DataFrame
+ // Cache persists the DataFrame with the default storage level.
+ Cache(ctx context.Context) error
// Coalesce returns a new DataFrame that has exactly numPartitions partitions.
//
// Similar to coalesce defined on an :class:`RDD`, this operation results in a
@@ -72,20 +77,39 @@ type DataFrame interface {
CreateTempView(ctx context.Context, viewName string, replace, global bool) error
// CrossJoin joins the current DataFrame with another DataFrame using the cross product
CrossJoin(ctx context.Context, other DataFrame) DataFrame
+ // Drop returns a new DataFrame that drops the specified list of columns.
Drop(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
+ // DropByName returns a new DataFrame that drops the specified list of columns by name.
DropByName(ctx context.Context, columns ...string) (DataFrame, error)
+ // DropDuplicates returns a new DataFrame that contains only the unique rows from this DataFrame.
DropDuplicates(ctx context.Context, columns ...string) (DataFrame, error)
+ // ExceptAll is similar to Subtract but does not perform the distinct operation.
+ ExceptAll(ctx context.Context, other DataFrame) DataFrame
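+ // Explain returns a string describing the query plan of this DataFrame in the given explain mode.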
+ Explain(ctx context.Context, explainMode utils.ExplainMode) (string, error)
// Filter filters the data frame by a column condition.
Filter(ctx context.Context, condition column.Convertible) (DataFrame, error)
// FilterByString filters the data frame by a string condition.
FilterByString(ctx context.Context, condition string) (DataFrame, error)
+ // GetStorageLevel returns the storage level of the data frame.
+ GetStorageLevel(ctx context.Context) (*utils.StorageLevel, error)
// GroupBy groups the DataFrame by the specified columns so that the aggregation
// can be performed on them. See GroupedData for all the available aggregate functions.
GroupBy(cols ...column.Convertible) *GroupedData
+ // Head is an alias for Limit
+ Head(ctx context.Context, limit int32) ([]Row, error)
+ // Intersect performs the set intersection of two data frames and only returns distinct rows.
+ Intersect(ctx context.Context, other DataFrame) DataFrame
+ // IntersectAll performs the set intersection of two data frames and returns all rows.
+ IntersectAll(ctx context.Context, other DataFrame) DataFrame
+ // Limit returns the first `limit` rows as a list of Row.
+ Limit(ctx context.Context, limit int32) ([]Row, error)
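+ // Persist persists this DataFrame with the given storage level.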
+ Persist(ctx context.Context, storageLevel utils.StorageLevel) error
// Repartition re-partitions a data frame.
Repartition(ctx context.Context, numPartitions int, columns []string) (DataFrame, error)
// RepartitionByRange re-partitions a data frame by range partition.
RepartitionByRange(ctx context.Context, numPartitions int, columns ...column.Convertible) (DataFrame, error)
+ // SameSemantics returns true if the other DataFrame has the same semantics.
+ SameSemantics(ctx context.Context, other DataFrame) (bool, error)
// Show uses WriteResult to write the data frames to the console output.
Show(ctx context.Context, numRows int, truncate bool) error
// Schema returns the schema for the current data frame.
@@ -94,6 +118,31 @@ type DataFrame interface {
Select(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
// SelectExpr projects a list of columns from the DataFrame by string expressions
SelectExpr(ctx context.Context, exprs ...string) (DataFrame, error)
+ // SemanticHash returns the semantic hash of the data frame. The semantic hash can be used to
+ // understand if the semantic operations are similar.
+ SemanticHash(ctx context.Context) (int32, error)
+ // Subtract subtracts the other DataFrame from the current DataFrame and only returns
+ // distinct rows.
+ Subtract(ctx context.Context, other DataFrame) DataFrame
+ // Tail returns the last `limit` rows as a list of Row.
+ Tail(ctx context.Context, limit int32) ([]Row, error)
+ // Take is an alias for Limit
+ Take(ctx context.Context, limit int32) ([]Row, error)
+ // ToArrow returns the Arrow representation of the DataFrame.
+ ToArrow(ctx context.Context) (*arrow.Table, error)
+ // Union is an alias for UnionAll
+ Union(ctx context.Context, other DataFrame) DataFrame
+ // UnionAll returns a new DataFrame containing the union of rows in this and another DataFrame.
+ UnionAll(ctx context.Context, other DataFrame) DataFrame
+ // UnionByName performs a SQL union operation on two dataframes but reorders the schema
+ // according to the matching columns. If columns are missing, it will throw an error.
+ UnionByName(ctx context.Context, other DataFrame) DataFrame
+ // UnionByNameWithMissingColumns performs a SQL union operation on two dataframes but reorders the schema
+ // according to the matching columns. Missing columns are supported.
+ UnionByNameWithMissingColumns(ctx context.Context, other DataFrame) DataFrame
+ // Unpersist resets the storage level for this data frame, and if necessary removes it
+ // from server-side caches.
+ Unpersist(ctx context.Context) error
// WithColumn returns a new DataFrame by adding a column or replacing the
// existing column that has the same name. The column expression must be an
// expression over this DataFrame; attempting to add a column from some other
@@ -741,3 +790,319 @@ func (df *dataFrameImpl) DropDuplicates(ctx context.Context, columns ...string)
}
return NewDataFrame(df.session, rel), nil
}
+
+func (df *dataFrameImpl) Tail(ctx context.Context, limit int32) ([]Row, error) {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Tail{
+ Tail: &proto.Tail{
+ Input: df.relation,
+ Limit: limit,
+ },
+ },
+ }
+ data := NewDataFrame(df.session, rel)
+ return data.Collect(ctx)
+}
+
+func (df *dataFrameImpl) Limit(ctx context.Context, limit int32) ([]Row, error) {
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Limit{
+ Limit: &proto.Limit{
+ Input: df.relation,
+ Limit: limit,
+ },
+ },
+ }
+ data := NewDataFrame(df.session, rel)
+ return data.Collect(ctx)
+}
+
+func (df *dataFrameImpl) Head(ctx context.Context, limit int32) ([]Row, error) {
+ return df.Limit(ctx, limit)
+}
+
+func (df *dataFrameImpl) Take(ctx context.Context, limit int32) ([]Row, error) {
+ return df.Limit(ctx, limit)
+}
+
+func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) {
+ responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan())
+ if err != nil {
+ return nil, sparkerrors.WithType(fmt.Errorf("failed to execute
plan: %w", err), sparkerrors.ExecutionError)
+ }
+
+ _, table, err := responseClient.ToTable()
+ if err != nil {
+ return nil, err
+ }
+
+ return &table, nil
+}
+
+func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ isAll := true
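+ // A UNION set operation with isAll=true keeps duplicate rows (SQL UNION ALL semantics).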
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_UNION,
+ IsAll: &isAll,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Union(ctx context.Context, other DataFrame) DataFrame {
+ return df.UnionAll(ctx, other)
+}
+
+func (df *dataFrameImpl) UnionByName(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ byName := true
+ allowMissingColumns := false
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_UNION,
+ ByName: &byName,
+ AllowMissingColumns: &allowMissingColumns,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) UnionByNameWithMissingColumns(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ byName := true
+ allowMissingColumns := true
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_UNION,
+ ByName: &byName,
+ AllowMissingColumns: &allowMissingColumns,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) ExceptAll(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ isAll := true
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_EXCEPT,
+ IsAll: &isAll,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Subtract(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ isAll := false
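+ // An EXCEPT set operation with isAll=false returns only distinct rows (SQL EXCEPT semantics).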
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_EXCEPT,
+ IsAll: &isAll,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Intersect(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ isAll := false
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_INTERSECT,
+ IsAll: &isAll,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) IntersectAll(ctx context.Context, other DataFrame) DataFrame {
+ otherDf := other.(*dataFrameImpl)
+ isAll := true
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_SetOp{
+ SetOp: &proto.SetOperation{
+ LeftInput: df.relation,
+ RightInput: otherDf.relation,
+ SetOpType: proto.SetOperation_SET_OP_TYPE_INTERSECT,
+ IsAll: &isAll,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel)
+}
+
+func (df *dataFrameImpl) Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+ globalSort := true
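+ // IsGlobal=true requests a total ordering across all partitions; SortWithinPartitions below sets it to false.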
+ sortExprs := make([]*proto.Expression_SortOrder, 0, len(columns))
+ for _, c := range columns {
+ expr, err := c.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+ sortExprs = append(sortExprs, expr.GetSortOrder())
+ }
+
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Sort{
+ Sort: &proto.Sort{
+ Input: df.relation,
+ Order: sortExprs,
+ IsGlobal: &globalSort,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel), nil
+}
+
+func (df *dataFrameImpl) SortWithinPartitions(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+ globalSort := false
+ sortExprs := make([]*proto.Expression_SortOrder, 0, len(columns))
+ for _, c := range columns {
+ expr, err := c.ToProto(ctx)
+ if err != nil {
+ return nil, err
+ }
+ sortExprs = append(sortExprs, expr.GetSortOrder())
+ }
+
+ rel := &proto.Relation{
+ Common: &proto.RelationCommon{
+ PlanId: newPlanId(),
+ },
+ RelType: &proto.Relation_Sort{
+ Sort: &proto.Sort{
+ Input: df.relation,
+ Order: sortExprs,
+ IsGlobal: &globalSort,
+ },
+ },
+ }
+ return NewDataFrame(df.session, rel), nil
+}
+
+func (df *dataFrameImpl) OrderBy(ctx context.Context, columns ...column.Convertible) (DataFrame, error) {
+ return df.Sort(ctx, columns...)
+}
+
+func (df *dataFrameImpl) Explain(ctx context.Context, explainMode utils.ExplainMode) (string, error) {
+ plan := df.createPlan()
+
+ responseClient, err := df.session.client.Explain(ctx, plan, explainMode)
+ if err != nil {
+ return "", sparkerrors.WithType(fmt.Errorf("failed to execute
plan: %w", err), sparkerrors.ExecutionError)
+ }
+ return responseClient.GetExplain().GetExplainString(), nil
+}
+
+func (df *dataFrameImpl) Persist(ctx context.Context, storageLevel utils.StorageLevel) error {
+ plan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: df.relation,
+ },
+ }
+ return df.session.client.Persist(ctx, plan, storageLevel)
+}
+
+func (df *dataFrameImpl) Cache(ctx context.Context) error {
+ return df.Persist(ctx, utils.StorageLevelMemoryOnly)
+}
+
+func (df *dataFrameImpl) Unpersist(ctx context.Context) error {
+ plan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: df.relation,
+ },
+ }
+ return df.session.client.Unpersist(ctx, plan)
+}
+
+func (df *dataFrameImpl) GetStorageLevel(ctx context.Context) (*utils.StorageLevel, error) {
+ plan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: df.relation,
+ },
+ }
+ return df.session.client.GetStorageLevel(ctx, plan)
+}
+
+func (df *dataFrameImpl) SameSemantics(ctx context.Context, other DataFrame) (bool, error) {
+ otherDf := other.(*dataFrameImpl)
+ plan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: df.relation,
+ },
+ }
+ otherPlan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: otherDf.relation,
+ },
+ }
+ return df.session.client.SameSemantics(ctx, plan, otherPlan)
+}
+
+func (df *dataFrameImpl) SemanticHash(ctx context.Context) (int32, error) {
+ plan := &proto.Plan{
+ OpType: &proto.Plan_Root{
+ Root: df.relation,
+ },
+ }
+ return df.session.client.SemanticHash(ctx, plan)
+}
diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go
new file mode 100644
index 0000000..d3f287f
--- /dev/null
+++ b/spark/sql/utils/consts.go
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import proto "github.com/apache/spark-connect-go/v35/internal/generated"
+
+type ExplainMode int
+
+const (
+ ExplainModeSimple ExplainMode = iota
+ ExplainModeExtended ExplainMode = iota
+ ExplainModeCodegen ExplainMode = iota
+ ExplainModeCost ExplainMode = iota
+ ExplainModeFormatted ExplainMode = iota
+)
+
+type StorageLevel int
+
+const (
+ StorageLevelDiskOnly StorageLevel = iota
+ StorageLevelDiskOnly2 StorageLevel = iota
+ StorageLevelDiskOnly3 StorageLevel = iota
+ StorageLevelMemoryAndDisk StorageLevel = iota
+ StorageLevelMemoryAndDisk2 StorageLevel = iota
+ StorageLevelMemoryOnly StorageLevel = iota
+ StorageLevelMemoryOnly2 StorageLevel = iota
+ StorageLevelMemoyAndDiskDeser StorageLevel = iota
+ StorageLevelNone StorageLevel = iota
+ StorageLevelOffHeap StorageLevel = iota
+)
+
+func ToProtoStorageLevel(level StorageLevel) *proto.StorageLevel {
+ switch level {
+ case StorageLevelDiskOnly:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: false, Replication: 1}
+ case StorageLevelDiskOnly2:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: false, Replication: 2}
+ case StorageLevelDiskOnly3:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: false, Replication: 3}
+ case StorageLevelMemoryAndDisk:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: true, Replication: 1}
+ case StorageLevelMemoryAndDisk2:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: true, Replication: 2}
+ case StorageLevelMemoryOnly:
+ return &proto.StorageLevel{UseDisk: false, UseMemory: true, Replication: 1}
+ case StorageLevelMemoryOnly2:
+ return &proto.StorageLevel{UseDisk: false, UseMemory: true, Replication: 2}
+ case StorageLevelMemoyAndDiskDeser:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: true, Replication: 1, Deserialized: true}
+ case StorageLevelOffHeap:
+ return &proto.StorageLevel{UseDisk: true, UseMemory: true, UseOffHeap: true, Replication: 1}
+ default:
+ return &proto.StorageLevel{UseDisk: false, UseMemory: false, UseOffHeap: false, Replication: 1}
+ }
+}
+
+func FromProtoStorageLevel(level *proto.StorageLevel) StorageLevel {
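+ // Match the proto flags against each known storage level; unrecognized combinations fall back to StorageLevelNone.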
+ if level.UseDisk && level.UseMemory && level.Replication <= 1 &&
!level.Deserialized && !level.UseOffHeap {
+ return StorageLevelMemoryAndDisk
+ } else if level.UseDisk && level.UseMemory && level.Replication == 2 &&
!level.Deserialized && !level.UseOffHeap {
+ return StorageLevelMemoryAndDisk2
+ } else if level.UseDisk && !level.UseMemory && level.Replication == 3 &&
+ !level.Deserialized && !level.UseOffHeap {
+ return StorageLevelDiskOnly3
+ } else if level.UseDisk && !level.UseMemory && level.Replication == 2 &&
+ !level.Deserialized && !level.UseOffHeap {
+ return StorageLevelDiskOnly2
+ } else if level.UseDisk && !level.UseMemory && level.Replication <= 1 &&
+ !level.Deserialized && !level.UseOffHeap {
+ return StorageLevelDiskOnly
+ } else if !level.UseDisk && level.UseMemory && level.Replication <= 1 &&
+ !level.Deserialized && !level.UseOffHeap {
+ return StorageLevelMemoryOnly
+ } else if !level.UseDisk && level.UseMemory && level.Replication == 2 &&
+ !level.Deserialized && !level.UseOffHeap {
+ return StorageLevelMemoryOnly2
+ } else if level.UseDisk && level.UseMemory && level.Replication <= 1 &&
level.Deserialized && !level.UseOffHeap {
+ return StorageLevelMemoyAndDiskDeser
+ } else if !level.UseDisk && !level.UseMemory && !level.Deserialized &&
!level.UseOffHeap {
+ return StorageLevelNone
+ } else if level.UseOffHeap && !level.Deserialized {
+ return StorageLevelOffHeap
+ }
+ return StorageLevelNone
+}
diff --git a/spark/sql/utils/consts_test.go b/spark/sql/utils/consts_test.go
new file mode 100644
index 0000000..fbb1fc7
--- /dev/null
+++ b/spark/sql/utils/consts_test.go
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import "testing"
+
+func TestStorageLevelConversion(t *testing.T) {
+ // Given a list of all storage levels, convert them to and from proto and
+ // check with the original value:
+ for _, level := range []StorageLevel{
+ StorageLevelDiskOnly,
+ StorageLevelDiskOnly2,
+ StorageLevelDiskOnly3,
+ StorageLevelMemoryAndDisk,
+ StorageLevelMemoryAndDisk2,
+ StorageLevelMemoryOnly,
+ StorageLevelMemoryOnly2,
+ StorageLevelMemoyAndDiskDeser,
+ StorageLevelNone,
+ StorageLevelOffHeap,
+ } {
+ protoLevel := ToProtoStorageLevel(level)
+ convertedLevel := FromProtoStorageLevel(protoLevel)
+ if level != convertedLevel {
+ t.Errorf("Expected %v, got %v", level, convertedLevel)
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]