This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 77eb4bb  Add config support
77eb4bb is described below

commit 77eb4bb6fff8fa2f79162a3a2e91ede709be4f40
Author: Magnus Pierre <[email protected]>
AuthorDate: Fri Nov 1 16:35:57 2024 +0100

    Add config support
    
    ### What changes were proposed in this pull request?
    Add sparksession.Config() support   (backlog #48)
    Implemented:
    Get
    Set
    GetAll
    GetWithDefault
    Unset
    IsModifiable
    
    ### Why are the changes needed?
    To be able to set and modify the configuration
    
    ### Does this PR introduce _any_ user-facing change?
    sparkSession has an additional function Config(), aligned with the 
functionality already existing in pyspark.connect
    
    ### How was this patch tested?
    Tested manually with a simple application; dedicated test cases will be developed.
    
    Closes #82 from magpierre/Add-Config-support.
    
    Lead-authored-by: Magnus Pierre <[email protected]>
    Co-authored-by: Martin Grund <[email protected]>
    Co-authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 internal/tests/integration/dataframe_test.go |  59 ++++++++++++++
 spark/client/base/base.go                    |   1 +
 spark/client/client.go                       |  10 +++
 spark/client/conf.go                         | 112 +++++++++++++++++++++++++++
 spark/mocks/mock_executor.go                 |   4 +
 spark/sql/sparksession.go                    |   5 ++
 6 files changed, 191 insertions(+)

diff --git a/internal/tests/integration/dataframe_test.go 
b/internal/tests/integration/dataframe_test.go
index b04b7fe..5f67be8 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -697,6 +697,65 @@ func TestDataFrame_FreqItems(t *testing.T) {
        assert.Len(t, res, 1)
 }
 
+func TestDataFrame_Config_GetAll(t *testing.T) {
+       ctx, spark := connect()
+       result, err := spark.Config().GetAll(ctx)
+       assert.NoError(t, err)
+       assert.Equal(t, "driver", result["spark.executor.id"])
+}
+
+func TestDataFrame_Config_Get(t *testing.T) {
+       ctx, spark := connect()
+       result, err := spark.Config().Get(ctx, "spark.executor.id")
+       assert.NoError(t, err)
+       assert.Equal(t, "driver", result)
+}
+
+func TestDataFrame_Config_GetWithDefault(t *testing.T) {
+       ctx, spark := connect()
+
+       result, err := spark.Config().GetWithDefault(ctx, "spark.whatever", 
"whatever_not_set")
+       assert.NoError(t, err)
+       assert.Equal(t, "whatever_not_set", result)
+}
+
+func TestDataFrame_Config_Set(t *testing.T) {
+       ctx, spark := connect()
+       err := spark.Config().Set(ctx, "spark.whatever", "whatever_set")
+       assert.NoError(t, err)
+}
+
+func TestDataFrame_Config_IsModifiable(t *testing.T) {
+       ctx, spark := connect()
+       result, err := spark.Config().IsModifiable(ctx, "spark.executor.id")
+       assert.NoError(t, err)
+       assert.Equal(t, false, result)
+}
+
+func TestDataFrame_Config_Unset(t *testing.T) {
+       ctx, spark := connect()
+       err := spark.Config().Set(ctx, "spark.whatever", "whatever_set")
+       assert.NoError(t, err)
+       err = spark.Config().Unset(ctx, "spark.whatever")
+       assert.NoError(t, err)
+}
+
+func TestDataFrame_Config_e2e_test(t *testing.T) {
+       ctx, spark := connect()
+       //  add keys that we know is "modifiable"
+       key := "spark.sql.ansi.enabled"
+       result, err := spark.Config().IsModifiable(ctx, key)
+       assert.NoError(t, err)
+       assert.Equal(t, true, result)
+       _, err = spark.Config().Get(ctx, key)
+       assert.NoError(t, err)
+       err = spark.Config().Set(ctx, "spark.sql.ansi.enabled", "true")
+       assert.NoError(t, err)
+       m, err := spark.Config().Get(ctx, "spark.sql.ansi.enabled")
+       assert.NoError(t, err)
+       assert.Equal(t, "true", m)
+}
+
 func TestDataFrame_WithOption(t *testing.T) {
        ctx, spark := connect()
        file, err := os.CreateTemp("", "example")
diff --git a/spark/client/base/base.go b/spark/client/base/base.go
index 1e34346..703ab66 100644
--- a/spark/client/base/base.go
+++ b/spark/client/base/base.go
@@ -43,6 +43,7 @@ type SparkConnectClient interface {
        DDLParse(ctx context.Context, sql string) (*types.StructType, error)
        SameSemantics(ctx context.Context, plan1 *generated.Plan, plan2 
*generated.Plan) (bool, error)
        SemanticHash(ctx context.Context, plan *generated.Plan) (int32, error)
+       Config(ctx context.Context, configRequest 
*generated.ConfigRequest_Operation) (*generated.ConfigResponse, error)
 }
 
 type ExecuteResponseStream interface {
diff --git a/spark/client/client.go b/spark/client/client.go
index 6a60f67..0d24bb4 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -331,6 +331,16 @@ func (s *sparkConnectClientImpl) SemanticHash(ctx 
context.Context, plan *proto.P
        return response.GetSemanticHash().GetResult(), nil
 }
 
+func (s *sparkConnectClientImpl) Config(ctx context.Context, operation 
*proto.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
+       request := &proto.ConfigRequest{Operation: operation}
+       request.SessionId = s.sessionId
+       resp, err := s.client.Config(ctx, request)
+       if err != nil {
+               return nil, err
+       }
+       return resp, nil
+}
+
 func NewSparkExecutor(conn *grpc.ClientConn, md metadata.MD, sessionId string, 
opts options.SparkClientOptions) base.SparkConnectClient {
        var client base.SparkConnectRPCClient
        if opts.ReattachExecution {
diff --git a/spark/client/conf.go b/spark/client/conf.go
new file mode 100644
index 0000000..ddfe8fc
--- /dev/null
+++ b/spark/client/conf.go
@@ -0,0 +1,112 @@
package client

import (
	"context"
	"fmt"

	proto "github.com/apache/spark-connect-go/v35/internal/generated"
	"github.com/apache/spark-connect-go/v35/spark/client/base"
)
+
// RuntimeConfig is the user-facing API for reading and writing the Spark
// session configuration over the Connect protocol, mirroring the runtime
// configuration API available in pyspark.connect.
type RuntimeConfig interface {
	// GetAll returns all configured entries of the session as a map.
	GetAll(ctx context.Context) (map[string]string, error)
	// Set assigns value to key in the session configuration.
	Set(ctx context.Context, key string, value string) error
	// Get returns the value for key, or an error if no value is available.
	Get(ctx context.Context, key string) (string, error)
	// Unset removes key from the session configuration.
	Unset(ctx context.Context, key string) error
	// IsModifiable reports whether key can be changed at runtime.
	IsModifiable(ctx context.Context, key string) (bool, error)
	// GetWithDefault returns the value for key, or defaultValue when the
	// key is not set on the server.
	GetWithDefault(ctx context.Context, key string, defaultValue string) (string, error)
}
+
// runtimeConfig is the package-private implementation of RuntimeConfig.
// Every method builds one ConfigRequest_Operation and forwards it through
// the stored client.
type runtimeConfig struct {
	// client points at the session's SparkConnectClient interface value and
	// is dereferenced on every call.
	// NOTE(review): a pointer to an interface is unusual in Go — storing the
	// interface value directly would be more idiomatic; confirm no caller
	// relies on late rebinding of the session's client.
	client *base.SparkConnectClient
}
+
+// GetAll returns all configured keys in a map of strings
+func (r runtimeConfig) GetAll(ctx context.Context) (map[string]string, error) {
+       req := &proto.ConfigRequest_GetAll{}
+       operation := &proto.ConfigRequest_Operation_GetAll{GetAll: req}
+       op := &proto.ConfigRequest_Operation{OpType: operation}
+       resp, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return nil, err
+       }
+       m := make(map[string]string, 0)
+       for _, k := range resp.GetPairs() {
+               if k.Value != nil {
+                       m[k.Key] = *k.Value
+               }
+       }
+       return m, nil
+}
+
+// Set takes a key and a value and sets it in the config
+func (r runtimeConfig) Set(ctx context.Context, key string, value string) 
error {
+       reqArr := []*proto.KeyValue{{Key: key, Value: &value}}
+       req := &proto.ConfigRequest_Set{
+               Pairs: reqArr,
+       }
+       op := &proto.ConfigRequest_Operation{OpType: 
&proto.ConfigRequest_Operation_Set{Set: req}}
+       _, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func (r runtimeConfig) Get(ctx context.Context, key string) (string, error) {
+       req := &proto.ConfigRequest_Get{Keys: []string{key}}
+       operation := &proto.ConfigRequest_Operation_Get{Get: req}
+       op := &proto.ConfigRequest_Operation{OpType: operation}
+       resp, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return "", err
+       }
+       return *resp.GetPairs()[0].Value, nil
+}
+
+func (r runtimeConfig) Unset(ctx context.Context, key string) error {
+       req := &proto.ConfigRequest_Unset{Keys: []string{key}}
+       operation := &proto.ConfigRequest_Operation_Unset{Unset: req}
+       op := &proto.ConfigRequest_Operation{OpType: operation}
+       _, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func (r runtimeConfig) IsModifiable(ctx context.Context, key string) (bool, 
error) {
+       req := &proto.ConfigRequest_IsModifiable{Keys: []string{key}}
+       operation := &proto.ConfigRequest_Operation_IsModifiable{IsModifiable: 
req}
+       op := &proto.ConfigRequest_Operation{OpType: operation}
+       resp, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return false, err
+       }
+       re := *resp.GetPairs()[0].Value
+       if re == "true" {
+               return true, nil
+       } else {
+               return false, nil
+       }
+}
+
+func (r runtimeConfig) GetWithDefault(ctx context.Context, key string, 
default_value string) (string, error) {
+       p := make([]*proto.KeyValue, 0)
+       p = append(p, &proto.KeyValue{Key: key, Value: &default_value})
+       req := &proto.ConfigRequest_GetWithDefault{Pairs: p}
+       operation := 
&proto.ConfigRequest_Operation_GetWithDefault{GetWithDefault: req}
+       op := &proto.ConfigRequest_Operation{OpType: operation}
+       resp, err := (*r.client).Config(ctx, op)
+       if err != nil {
+               return "", err
+       }
+
+       return *resp.GetPairs()[0].Value, nil
+}
+
+// Constructor for runtimeConfig used by SparkSession
+func NewRuntimeConfig(client *base.SparkConnectClient) *runtimeConfig {
+       return &runtimeConfig{client: client}
+}
diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go
index a788e2c..415bf7e 100644
--- a/spark/mocks/mock_executor.go
+++ b/spark/mocks/mock_executor.go
@@ -85,3 +85,7 @@ func (t *TestExecutor) SameSemantics(ctx context.Context, 
plan1 *generated.Plan,
 func (t *TestExecutor) SemanticHash(ctx context.Context, plan *generated.Plan) 
(int32, error) {
        return 0, errors.New("not implemented")
 }
+
// Config satisfies base.SparkConnectClient for the test mock; config
// operations are not supported and always return an error.
func (t *TestExecutor) Config(ctx context.Context, configRequest *generated.ConfigRequest_Operation) (*generated.ConfigResponse, error) {
	return nil, errors.New("not implemented")
}
diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go
index 3ae9300..58436be 100644
--- a/spark/sql/sparksession.go
+++ b/spark/sql/sparksession.go
@@ -47,6 +47,7 @@ type SparkSession interface {
        Table(name string) (DataFrame, error)
        CreateDataFrameFromArrow(ctx context.Context, data arrow.Table) 
(DataFrame, error)
        CreateDataFrame(ctx context.Context, data [][]any, schema 
*types.StructType) (DataFrame, error)
+       Config() client.RuntimeConfig
 }
 
 // NewSessionBuilder creates a new session builder for starting a new spark 
session
@@ -103,6 +104,10 @@ type sparkSessionImpl struct {
        client    base.SparkConnectClient
 }
 
// Config returns the RuntimeConfig for this session, backed by the session's
// SparkConnectClient. A new wrapper is constructed on every call.
func (s *sparkSessionImpl) Config() client.RuntimeConfig {
	return client.NewRuntimeConfig(&s.client)
}
+
// Read returns a DataFrameReader bound to this session.
func (s *sparkSessionImpl) Read() DataFrameReader {
	return NewDataframeReader(s)
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to