This is an automated email from the ASF dual-hosted git repository.
mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
new 77eb4bb Add config support
77eb4bb is described below
commit 77eb4bb6fff8fa2f79162a3a2e91ede709be4f40
Author: Magnus Pierre <[email protected]>
AuthorDate: Fri Nov 1 16:35:57 2024 +0100
Add config support
### What changes were proposed in this pull request?
Add sparksession.Config() support (backlog #48)
Implemented:
Get
Set
GetAll
GetWithDefault
Unset
IsModifiable
### Why are the changes needed?
To be able to set and modify the configuration
### Does this PR introduce _any_ user-facing change?
sparkSession gains an additional function, Config(), aligned with the
functionality already existing in pyspark.connect.
### How was this patch tested?
Tested with a simple application; dedicated test cases will be developed.
Closes #82 from magpierre/Add-Config-support.
Lead-authored-by: Magnus Pierre <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Signed-off-by: Martin Grund <[email protected]>
---
internal/tests/integration/dataframe_test.go | 59 ++++++++++++++
spark/client/base/base.go | 1 +
spark/client/client.go | 10 +++
spark/client/conf.go | 112 +++++++++++++++++++++++++++
spark/mocks/mock_executor.go | 4 +
spark/sql/sparksession.go | 5 ++
6 files changed, 191 insertions(+)
diff --git a/internal/tests/integration/dataframe_test.go
b/internal/tests/integration/dataframe_test.go
index b04b7fe..5f67be8 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -697,6 +697,65 @@ func TestDataFrame_FreqItems(t *testing.T) {
assert.Len(t, res, 1)
}
+func TestDataFrame_Config_GetAll(t *testing.T) {
+ ctx, spark := connect()
+ result, err := spark.Config().GetAll(ctx)
+ assert.NoError(t, err)
+ assert.Equal(t, "driver", result["spark.executor.id"])
+}
+
+func TestDataFrame_Config_Get(t *testing.T) {
+ ctx, spark := connect()
+ result, err := spark.Config().Get(ctx, "spark.executor.id")
+ assert.NoError(t, err)
+ assert.Equal(t, "driver", result)
+}
+
+func TestDataFrame_Config_GetWithDefault(t *testing.T) {
+ ctx, spark := connect()
+
+ result, err := spark.Config().GetWithDefault(ctx, "spark.whatever",
"whatever_not_set")
+ assert.NoError(t, err)
+ assert.Equal(t, "whatever_not_set", result)
+}
+
+func TestDataFrame_Config_Set(t *testing.T) {
+ ctx, spark := connect()
+ err := spark.Config().Set(ctx, "spark.whatever", "whatever_set")
+ assert.NoError(t, err)
+}
+
+func TestDataFrame_Config_IsModifiable(t *testing.T) {
+ ctx, spark := connect()
+ result, err := spark.Config().IsModifiable(ctx, "spark.executor.id")
+ assert.NoError(t, err)
+ assert.Equal(t, false, result)
+}
+
+func TestDataFrame_Config_Unset(t *testing.T) {
+ ctx, spark := connect()
+ err := spark.Config().Set(ctx, "spark.whatever", "whatever_set")
+ assert.NoError(t, err)
+ err = spark.Config().Unset(ctx, "spark.whatever")
+ assert.NoError(t, err)
+}
+
+func TestDataFrame_Config_e2e_test(t *testing.T) {
+ ctx, spark := connect()
+ // add keys that we know is "modifiable"
+ key := "spark.sql.ansi.enabled"
+ result, err := spark.Config().IsModifiable(ctx, key)
+ assert.NoError(t, err)
+ assert.Equal(t, true, result)
+ _, err = spark.Config().Get(ctx, key)
+ assert.NoError(t, err)
+ err = spark.Config().Set(ctx, "spark.sql.ansi.enabled", "true")
+ assert.NoError(t, err)
+ m, err := spark.Config().Get(ctx, "spark.sql.ansi.enabled")
+ assert.NoError(t, err)
+ assert.Equal(t, "true", m)
+}
+
func TestDataFrame_WithOption(t *testing.T) {
ctx, spark := connect()
file, err := os.CreateTemp("", "example")
diff --git a/spark/client/base/base.go b/spark/client/base/base.go
index 1e34346..703ab66 100644
--- a/spark/client/base/base.go
+++ b/spark/client/base/base.go
@@ -43,6 +43,7 @@ type SparkConnectClient interface {
DDLParse(ctx context.Context, sql string) (*types.StructType, error)
SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2
*generated.Plan) (bool, error)
SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
+ Config(ctx context.Context, configRequest
*generated.ConfigRequest_Operation) (*generated.ConfigResponse, error)
}
type ExecuteResponseStream interface {
diff --git a/spark/client/client.go b/spark/client/client.go
index 6a60f67..0d24bb4 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -331,6 +331,16 @@ func (s *sparkConnectClientImpl) SemanticHash(ctx
context.Context, plan *proto.P
return response.GetSemanticHash().GetResult(), nil
}
+func (s *sparkConnectClientImpl) Config(ctx context.Context, operation
*proto.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
+ request := &proto.ConfigRequest{Operation: operation}
+ request.SessionId = s.sessionId
+ resp, err := s.client.Config(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string,
opts options.SparkClientOptions) base.SparkConnectClient {
var client base.SparkConnectRPCClient
if opts.ReattachExecution {
diff --git a/spark/client/conf.go b/spark/client/conf.go
new file mode 100644
index 0000000..ddfe8fc
--- /dev/null
+++ b/spark/client/conf.go
@@ -0,0 +1,112 @@
+package client
+
import (
	"context"
	"fmt"

	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client/base"
)
+
+// Public interface RuntimeConfig
+type RuntimeConfig interface {
+ GetAll(ctx context.Context) (map[string]string, error)
+ Set(ctx context.Context, key string, value string) error
+ Get(ctx context.Context, key string) (string, error)
+ Unset(ctx context.Context, key string) error
+ IsModifiable(ctx context.Context, key string) (bool, error)
+ GetWithDefault(ctx context.Context, key string, default_value string)
(string, error)
+}
+
+// private type with private member client
+type runtimeConfig struct {
+ client *base.SparkConnectClient
+}
+
+// GetAll returns all configured keys in a map of strings
+func (r runtimeConfig) GetAll(ctx context.Context) (map[string]string, error) {
+ req := &proto.ConfigRequest_GetAll{}
+ operation := &proto.ConfigRequest_Operation_GetAll{GetAll: req}
+ op := &proto.ConfigRequest_Operation{OpType: operation}
+ resp, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return nil, err
+ }
+ m := make(map[string]string, 0)
+ for _, k := range resp.GetPairs() {
+ if k.Value != nil {
+ m[k.Key] = *k.Value
+ }
+ }
+ return m, nil
+}
+
+// Set takes a key and a value and sets it in the config
+func (r runtimeConfig) Set(ctx context.Context, key string, value string)
error {
+ reqArr := []*proto.KeyValue{{Key: key, Value: &value}}
+ req := &proto.ConfigRequest_Set{
+ Pairs: reqArr,
+ }
+ op := &proto.ConfigRequest_Operation{OpType:
&proto.ConfigRequest_Operation_Set{Set: req}}
+ _, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r runtimeConfig) Get(ctx context.Context, key string) (string, error) {
+ req := &proto.ConfigRequest_Get{Keys: []string{key}}
+ operation := &proto.ConfigRequest_Operation_Get{Get: req}
+ op := &proto.ConfigRequest_Operation{OpType: operation}
+ resp, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return "", err
+ }
+ return *resp.GetPairs()[0].Value, nil
+}
+
+func (r runtimeConfig) Unset(ctx context.Context, key string) error {
+ req := &proto.ConfigRequest_Unset{Keys: []string{key}}
+ operation := &proto.ConfigRequest_Operation_Unset{Unset: req}
+ op := &proto.ConfigRequest_Operation{OpType: operation}
+ _, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r runtimeConfig) IsModifiable(ctx context.Context, key string) (bool,
error) {
+ req := &proto.ConfigRequest_IsModifiable{Keys: []string{key}}
+ operation := &proto.ConfigRequest_Operation_IsModifiable{IsModifiable:
req}
+ op := &proto.ConfigRequest_Operation{OpType: operation}
+ resp, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return false, err
+ }
+ re := *resp.GetPairs()[0].Value
+ if re == "true" {
+ return true, nil
+ } else {
+ return false, nil
+ }
+}
+
+func (r runtimeConfig) GetWithDefault(ctx context.Context, key string,
default_value string) (string, error) {
+ p := make([]*proto.KeyValue, 0)
+ p = append(p, &proto.KeyValue{Key: key, Value: &default_value})
+ req := &proto.ConfigRequest_GetWithDefault{Pairs: p}
+ operation :=
&proto.ConfigRequest_Operation_GetWithDefault{GetWithDefault: req}
+ op := &proto.ConfigRequest_Operation{OpType: operation}
+ resp, err := (*r.client).Config(ctx, op)
+ if err != nil {
+ return "", err
+ }
+
+ return *resp.GetPairs()[0].Value, nil
+}
+
+// Constructor for runtimeConfig used by SparkSession
+func NewRuntimeConfig(client *base.SparkConnectClient) *runtimeConfig {
+ return &runtimeConfig{client: client}
+}
diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go
index a788e2c..415bf7e 100644
--- a/spark/mocks/mock_executor.go
+++ b/spark/mocks/mock_executor.go
@@ -85,3 +85,7 @@ func (t *TestExecutor) SameSemantics(ctx context.Context,
plan1 *generated.Plan,
func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan)
(int32, error) {
return 0, errors.New("not implemented")
}
+
+func (t *TestExecutor) Config(ctx context.Context, configRequest
*generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
+ return nil, errors.New("not implemented")
+}
diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go
index 3ae9300..58436be 100644
--- a/spark/sql/sparksession.go
+++ b/spark/sql/sparksession.go
@@ -47,6 +47,7 @@ type SparkSession interface {
Table(name string) (DataFrame, error)
CreateDataFrameFromArrow(ctx context.Context, data arrow.Table)
(DataFrame, error)
CreateDataFrame(ctx context.Context, data [][]any, schema
*types.StructType) (DataFrame, error)
+ Config() client.RuntimeConfig
}
// NewSessionBuilder creates a new session builder for starting a new spark
session
@@ -103,6 +104,10 @@ type sparkSessionImpl struct {
client base.SparkConnectClient
}
+func (s *sparkSessionImpl) Config() client.RuntimeConfig {
+ return client.NewRuntimeConfig(&s.client)
+}
+
func (s *sparkSessionImpl) Read() DataFrameReader {
return NewDataframeReader(s)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]