This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e254ba  [SPARK-48754] Adding tests, structure, best practices to productionize code base
9e254ba is described below

commit 9e254ba7583a06bfee2acc4766fd4dde00198f3c
Author: Mathias Schwarz <165780420+mathiasschw...@users.noreply.github.com>
AuthorDate: Tue Jul 2 17:28:39 2024 +0900

    [SPARK-48754] Adding tests, structure, best practices to productionize code base
    
    ### What changes were proposed in this pull request?
    This change contains several improvements that all aim to increase the code quality of the spark-connect-go repo. Taken together, these changes move the repo much closer to idiomatic, best-practice Go without major semantic changes.
    
    The changes fall in these categories:
    - Improved unit test coverage by about 30 percentage points
    - Decoupled the components in the sql package so that they are individually testable and depend only on each other's interfaces rather than implementations
    - Added context propagation to the code base, which allows users of the library to set connection timeouts, auth headers, etc.
    - Added method/function-level comments where they were missing for public functions
    - Removed the global var builder 'entry point' and replaced it with a normal constructor, so that each builder is a fresh instance instead of relying on the previous copy semantics
    - Added a simple error hierarchy so that errors can be handled by checking error types instead of matching string values (see the sketch after this list)
    - Created constructors with required params for all structs instead of having users create the structs directly
    - Removed a strange case of panicking the whole process if some input was invalid
    - Updated documentation and examples to reflect these changes
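
    A minimal sketch of how the error hierarchy and the context-aware channel builder are meant to be consumed, based on the `sparkerrors` and `channel` packages in this patch (the connection string is only an example):

    ```go
    package main

    import (
        "context"
        "errors"
        "log"

        "github.com/apache/spark-connect-go/v1/client/channel"
        "github.com/apache/spark-connect-go/v1/client/sparkerrors"
    )

    func main() {
        ctx := context.Background()

        cb, err := channel.NewBuilder("sc://localhost:15002")
        if err != nil {
            // Malformed connection strings are now tagged with InvalidInputError.
            if errors.Is(err, sparkerrors.InvalidInputError) {
                log.Fatalf("bad connection string: %v", err)
            }
            log.Fatal(err)
        }

        // Build now takes a context, so callers can attach timeouts or cancellation.
        conn, err := cb.Build(ctx)
        if err != nil {
            if errors.Is(err, sparkerrors.ConnectionError) {
                log.Fatalf("could not reach the server: %v", err)
            }
            log.Fatal(err)
        }
        defer conn.Close()
    }
    ```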
    
    ### Why are the changes needed?
    These changes aim (along with subsequent changes) to get this code base to a point where it will eventually be fit for production use, something that is strictly forbidden right now.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR aims to preserve the API as much as possible, but in a few cases this has not been possible. In particular, functions that eventually result in an outbound gRPC call now take a context parameter, which is necessary for real production-grade code. In addition, the builder is now actually instantiated (instead of being a global var), but its API otherwise remains the same; see the sketch below.
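
    As a rough sketch of the adjusted call pattern (assuming the `client/sql` import path; the connection string and query are placeholders, not taken from this patch):

    ```go
    package main

    import (
        "context"
        "log"
        "time"

        "github.com/apache/spark-connect-go/v1/client/sql"
    )

    func main() {
        // Callers now own the context, so deadlines and cancellation propagate to gRPC.
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()

        // The builder is created by a constructor instead of being a global var.
        spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
        if err != nil {
            log.Fatal(err)
        }
        defer spark.Stop()

        df, err := spark.Sql(ctx, "select 'hello' as greeting")
        if err != nil {
            log.Fatal(err)
        }
        // Show (and the other outbound calls) now take the context as well.
        if err := df.Show(ctx, 10, false); err != nil {
            log.Fatal(err)
        }
    }
    ```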
    
    ### How was this patch tested?
    All the code that was touched has gotten some degree of unit testing that at least ensures coverage as well as checks the output.
    
    Closes #20 from mathiasschw-db/mathiasschw-db/productionize.
    
    Authored-by: Mathias Schwarz <165780420+mathiasschw...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build.yml                       |   4 +-
 client/channel/channel.go                         |  24 ++--
 client/channel/channel_test.go                    |  12 +-
 client/channel/compat.go                          |   6 +
 client/{sql/row.go => sparkerrors/errors.go}      |  36 ++++--
 client/sparkerrors/errors_test.go                 |  17 +++
 client/sql/dataframe.go                           | 134 +++++++++++++---------
 client/sql/dataframe_test.go                      |  78 ++++++++++++-
 client/sql/dataframereader.go                     |  34 +++---
 client/sql/dataframereader_test.go                |  26 +++++
 client/sql/dataframewriter.go                     |  35 ++++--
 client/sql/dataframewriter_test.go                |  55 ++++++++-
 client/sql/datatype.go                            |  17 +--
 client/sql/executeplanclient.go                   |  37 ++++++
 client/sql/mocks_test.go                          | 112 ++++++++++++++++++
 client/sql/plan_test.go                           |  13 +++
 client/sql/row.go                                 |  17 ++-
 client/sql/row_test.go                            |  25 ++++
 client/sql/sparksession.go                        |  82 +++++--------
 client/sql/sparksession_test.go                   | 111 ++++++++++++++++++
 client/sql/structtype.go                          |   2 +
 cmd/spark-connect-example-raw-grpc-client/main.go |  12 +-
 cmd/spark-connect-example-spark-session/main.go   |  82 +++++++------
 go.mod                                            |   2 +-
 quick-start.md                                    |  57 ++++++---
 25 files changed, 790 insertions(+), 240 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 9d0c13b..b6b1c4c 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -45,8 +45,8 @@ jobs:
           ref: master
       - name: Install Golang
         run: |
-          curl -LO https://go.dev/dl/go1.19.9.linux-amd64.tar.gz
-          sudo tar -C /usr/local -xzf go1.19.9.linux-amd64.tar.gz
+          curl -LO https://go.dev/dl/go1.21.11.linux-amd64.tar.gz
+          sudo tar -C /usr/local -xzf go1.21.11.linux-amd64.tar.gz
       - name: Install Buf
         run: |
           # See more in "Installation" https://docs.buf.build/installation#tarball
diff --git a/client/channel/channel.go b/client/channel/channel.go
index 6cf7696..b5d459d 100644
--- a/client/channel/channel.go
+++ b/client/channel/channel.go
@@ -17,6 +17,7 @@
 package channel
 
 import (
+       "context"
        "crypto/tls"
        "crypto/x509"
        "errors"
@@ -26,6 +27,7 @@ import (
        "strconv"
        "strings"
 
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
        "golang.org/x/oauth2"
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials"
@@ -35,11 +37,11 @@ import (
 // Reserved header parameters that must not be injected as variables.
 var reservedParams = []string{"user_id", "token", "use_ssl"}
 
-// The ChannelBuilder is used to parse the different parameters of the connection
+// Builder is used to parse the different parameters of the connection
 // string according to the specification documented here:
 //
 //     https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
-type ChannelBuilder struct {
+type Builder struct {
        Host    string
        Port    int
        Token   string
@@ -47,10 +49,10 @@ type ChannelBuilder struct {
        Headers map[string]string
 }
 
-// Finalizes the creation of the gprc.ClientConn by creating a GRPC channel
+// Build finalizes the creation of the gprc.ClientConn by creating a GRPC channel
 // with the necessary options extracted from the connection string. For
 // TLS connections, this function will load the system certificates.
-func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
+func (cb *Builder) Build(ctx context.Context) (*grpc.ClientConn, error) {
        var opts []grpc.DialOption
 
        opts = append(opts, grpc.WithAuthority(cb.Host))
@@ -76,16 +78,16 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) {
        }
 
        remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port)
-       conn, err := grpc.Dial(remote, opts...)
+       conn, err := grpc.DialContext(ctx, remote, opts...)
        if err != nil {
-               return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", remote, err), sparkerrors.ConnectionError)
        }
        return conn, nil
 }
 
-// Creates a new instance of the ChannelBuilder. This constructor effectively
+// NewBuilder creates a new instance of the Builder. This constructor effectively
 // parses the connection string and extracts the relevant parameters directly.
-func NewBuilder(connection string) (*ChannelBuilder, error) {
+func NewBuilder(connection string) (*Builder, error) {
 
        u, err := url.Parse(connection)
        if err != nil {
@@ -93,7 +95,7 @@ func NewBuilder(connection string) (*ChannelBuilder, error) {
        }
 
        if u.Scheme != "sc" {
-               return nil, errors.New("URL schema must be set to `sc`.")
+               return nil, sparkerrors.WithType(errors.New("URL schema must be set to `sc`"), sparkerrors.InvalidInputError)
        }
 
        var port = 15002
@@ -115,10 +117,10 @@ func NewBuilder(connection string) (*ChannelBuilder, error) {
 
        // Validate that the URL path is empty or follows the right format.
        if u.Path != "" && !strings.HasPrefix(u.Path, "/;") {
-               return nil, fmt.Errorf("The URL path (%v) must be empty or have a proper parameter syntax.", u.Path)
+               return nil, sparkerrors.WithType(fmt.Errorf("the URL path (%v) must be empty or have a proper parameter syntax", u.Path), sparkerrors.InvalidInputError)
        }
 
-       cb := &ChannelBuilder{
+       cb := &Builder{
                Host:    host,
                Port:    port,
                Headers: map[string]string{},
diff --git a/client/channel/channel_test.go b/client/channel/channel_test.go
index f4c4bc3..1395d99 100644
--- a/client/channel/channel_test.go
+++ b/client/channel/channel_test.go
@@ -17,10 +17,12 @@
 package channel_test
 
 import (
+       "context"
        "strings"
        "testing"
 
        "github.com/apache/spark-connect-go/v1/client/channel"
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
        "github.com/stretchr/testify/assert"
 )
 
@@ -49,7 +51,8 @@ func TestBasicChannelParsing(t *testing.T) {
        assert.Nilf(t, err, "Port must be a valid number %v", err)
 
        _, err = channel.NewBuilder("sc://abcd/this")
-       assert.True(t, strings.Contains(err.Error(), "The URL path"), "URL path elements are not allowed")
+       assert.True(t, strings.Contains(err.Error(), "URL path"), "URL path elements are not allowed")
+       assert.ErrorIs(t, err, sparkerrors.InvalidInputError)
 
        cb, err = channel.NewBuilder(goodChannelURL)
        assert.Equal(t, "host", cb.Host)
@@ -60,7 +63,7 @@ func TestBasicChannelParsing(t *testing.T) {
        assert.Equal(t, "b", cb.Token)
 
        cb, err = channel.NewBuilder("sc://localhost:443/;token=token;user_id=user_id;cluster_id=a")
-       assert.Nilf(t, err, "Unexpected error: %v", err)
+       assert.NoError(t, err)
        assert.Equal(t, 443, cb.Port)
        assert.Equal(t, "localhost", cb.Host)
        assert.Equal(t, "token", cb.Token)
@@ -68,15 +71,16 @@ func TestBasicChannelParsing(t *testing.T) {
 }
 
 func TestChannelBuildConnect(t *testing.T) {
+       ctx := context.Background()
        cb, err := channel.NewBuilder("sc://localhost")
        assert.Nil(t, err, "Should not have an error for a proper URL.")
-       conn, err := cb.Build()
+       conn, err := cb.Build(ctx)
        assert.Nil(t, err, "no error for proper connection")
        assert.NotNil(t, conn)
 
        cb, err = channel.NewBuilder("sc://localhost:443/;token=abcd;user_id=a")
        assert.Nil(t, err, "Should not have an error for a proper URL.")
-       conn, err = cb.Build()
+       conn, err = cb.Build(ctx)
        assert.Nil(t, err, "no error for proper connection")
        assert.NotNil(t, conn)
 }
diff --git a/client/channel/compat.go b/client/channel/compat.go
new file mode 100644
index 0000000..33f8a3e
--- /dev/null
+++ b/client/channel/compat.go
@@ -0,0 +1,6 @@
+package channel
+
+// ChannelBuilder re-exports Builder as its previous name for compatibility.
+//
+// Deprecated: use Builder instead.
+type ChannelBuilder = Builder
diff --git a/client/sql/row.go b/client/sparkerrors/errors.go
similarity index 52%
copy from client/sql/row.go
copy to client/sparkerrors/errors.go
index 3bee2ac..770cdef 100644
--- a/client/sql/row.go
+++ b/client/sparkerrors/errors.go
@@ -14,22 +14,36 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package sql
+package sparkerrors
 
-type Row interface {
-       Schema() (*StructType, error)
-       Values() ([]any, error)
+import (
+       "errors"
+       "fmt"
+)
+
+type wrappedError struct {
+       errorType error
+       cause     error
 }
 
-type GenericRowWithSchema struct {
-       values []any
-       schema *StructType
+func (w *wrappedError) Unwrap() []error {
+       return []error{w.errorType, w.cause}
 }
 
-func (r *GenericRowWithSchema) Schema() (*StructType, error) {
-       return r.schema, nil
+func (w *wrappedError) Error() string {
+       return fmt.Sprintf("%s: %s", w.errorType, w.cause)
 }
 
-func (r *GenericRowWithSchema) Values() ([]any, error) {
-       return r.values, nil
+// WithType wraps an error with a type that can later be checked using `errors.Is`
+func WithType(err error, errType errorType) error {
+       return &wrappedError{cause: err, errorType: errType}
 }
+
+type errorType error
+
+var (
+       ConnectionError   = errorType(errors.New("connection error"))
+       ReadError         = errorType(errors.New("read error"))
+       ExecutionError    = errorType(errors.New("execution error"))
+       InvalidInputError = errorType(errors.New("invalid input"))
+)
diff --git a/client/sparkerrors/errors_test.go b/client/sparkerrors/errors_test.go
new file mode 100644
index 0000000..f5857ec
--- /dev/null
+++ b/client/sparkerrors/errors_test.go
@@ -0,0 +1,17 @@
+package sparkerrors
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestWithTypeGivesAndErrorThatIsOfThatType(t *testing.T) {
+       err := WithType(assert.AnError, ConnectionError)
+       assert.ErrorIs(t, err, ConnectionError)
+}
+
+func TestErrorStringContainsErrorType(t *testing.T) {
+       err := WithType(assert.AnError, ConnectionError)
+       assert.Contains(t, err.Error(), ConnectionError.Error())
+}
diff --git a/client/sql/dataframe.go b/client/sql/dataframe.go
index 6fa7515..a7f5f51 100644
--- a/client/sql/dataframe.go
+++ b/client/sql/dataframe.go
@@ -18,27 +18,41 @@ package sql
 
 import (
        "bytes"
+       "context"
        "errors"
        "fmt"
+       "io"
+
        "github.com/apache/arrow/go/v12/arrow"
        "github.com/apache/arrow/go/v12/arrow/array"
        "github.com/apache/arrow/go/v12/arrow/ipc"
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
        proto "github.com/apache/spark-connect-go/v1/internal/generated"
-       "io"
 )
 
+// ResultCollector receives a stream of result rows
+type ResultCollector interface {
+       // WriteRow receives a single row from the data frame
+       WriteRow(values []any)
+}
+
 // DataFrame is a wrapper for data frame, representing a distributed collection of data row.
 type DataFrame interface {
-       // Show prints out data frame data.
-       Show(numRows int, truncate bool) error
+       // WriteResult streams the data frames to a result collector
+       WriteResult(ctx context.Context, collector ResultCollector, numRows int, truncate bool) error
+       // Show uses WriteResult to write the data frames to the console output.
+       Show(ctx context.Context, numRows int, truncate bool) error
        // Schema returns the schema for the current data frame.
-       Schema() (*StructType, error)
+       Schema(ctx context.Context) (*StructType, error)
        // Collect returns the data rows of the current data frame.
-       Collect() ([]Row, error)
-       // Write returns a data frame writer, which could be used to save data frame to supported storage.
+       Collect(ctx context.Context) ([]Row, error)
+       // Writer returns a data frame writer, which could be used to save data frame to supported storage.
+       Writer() DataFrameWriter
+       // Write is an alias for Writer
+       // Deprecated: Use Writer
        Write() DataFrameWriter
        // CreateTempView creates or replaces a temporary view.
-       CreateTempView(viewName string, replace bool, global bool) error
+       CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error
        // Repartition re-partitions a data frame.
        Repartition(numPartitions int, columns []string) (DataFrame, error)
        // RepartitionByRange re-partitions a data frame by range partition.
@@ -52,11 +66,29 @@ type RangePartitionColumn struct {
 
 // dataFrameImpl is an implementation of DataFrame interface.
 type dataFrameImpl struct {
-       sparkSession *sparkSessionImpl
-       relation     *proto.Relation // TODO change to proto.Plan?
+       sparkExecutor sparkExecutor
+       relation      *proto.Relation // TODO change to proto.Plan?
+}
+
+func newDataFrame(sparkExecutor sparkExecutor, relation *proto.Relation) DataFrame {
+       return &dataFrameImpl{
+               sparkExecutor: sparkExecutor,
+               relation:      relation,
+       }
+}
+
+type consoleCollector struct {
+}
+
+func (c consoleCollector) WriteRow(values []any) {
+       fmt.Println(values...)
 }
 
-func (df *dataFrameImpl) Show(numRows int, truncate bool) error {
+func (df *dataFrameImpl) Show(ctx context.Context, numRows int, truncate bool) error {
+       return df.WriteResult(ctx, &consoleCollector{}, numRows, truncate)
+}
+
+func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollector, numRows int, truncate bool) error {
        truncateValue := 0
        if truncate {
                truncateValue = 20
@@ -81,45 +113,42 @@ func (df *dataFrameImpl) Show(numRows int, truncate bool) error {
                },
        }
 
-       responseClient, err := df.sparkSession.executePlan(plan)
+       responseClient, err := df.sparkExecutor.executePlan(ctx, plan)
        if err != nil {
-               return fmt.Errorf("failed to show dataframe: %w", err)
+               return sparkerrors.WithType(fmt.Errorf("failed to show dataframe: %w", err), sparkerrors.ExecutionError)
        }
 
        for {
                response, err := responseClient.Recv()
                if err != nil {
-                       return fmt.Errorf("failed to receive show response: %w", err)
+                       return sparkerrors.WithType(fmt.Errorf("failed to receive show response: %w", err), sparkerrors.ReadError)
                }
                arrowBatch := response.GetArrowBatch()
                if arrowBatch == nil {
                        continue
                }
-               err = showArrowBatch(arrowBatch)
+               err = showArrowBatch(arrowBatch, collector)
                if err != nil {
                        return err
                }
                return nil
        }
-
-       return fmt.Errorf("did not get arrow batch in response")
 }
 
-func (df *dataFrameImpl) Schema() (*StructType, error) {
-       response, err := df.sparkSession.analyzePlan(df.createPlan())
+func (df *dataFrameImpl) Schema(ctx context.Context) (*StructType, error) {
+       response, err := df.sparkExecutor.analyzePlan(ctx, df.createPlan())
        if err != nil {
-               return nil, fmt.Errorf("failed to analyze plan: %w", err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to analyze plan: %w", err), sparkerrors.ExecutionError)
        }
 
        responseSchema := response.GetSchema().Schema
-       result := convertProtoDataTypeToStructType(responseSchema)
-       return result, nil
+       return convertProtoDataTypeToStructType(responseSchema)
 }
 
-func (df *dataFrameImpl) Collect() ([]Row, error) {
-       responseClient, err := df.sparkSession.executePlan(df.createPlan())
+func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) {
+       responseClient, err := df.sparkExecutor.executePlan(ctx, df.createPlan())
        if err != nil {
-               return nil, fmt.Errorf("failed to execute plan: %w", err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError)
        }
 
        var schema *StructType
@@ -131,13 +160,16 @@ func (df *dataFrameImpl) Collect() ([]Row, error) {
                        if errors.Is(err, io.EOF) {
                                return allRows, nil
                        } else {
-                               return nil, fmt.Errorf("failed to receive plan execution response: %w", err)
+                               return nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)
                        }
                }
 
                dataType := response.GetSchema()
                if dataType != nil {
-                       schema = convertProtoDataTypeToStructType(dataType)
+                       schema, err = convertProtoDataTypeToStructType(dataType)
+                       if err != nil {
+                               return nil, err
+                       }
                        continue
                }
 
@@ -156,19 +188,17 @@ func (df *dataFrameImpl) Collect() ([]Row, error) {
                }
                allRows = append(allRows, rowBatch...)
        }
-
-       return allRows, nil
 }
 
 func (df *dataFrameImpl) Write() DataFrameWriter {
-       writer := dataFrameWriterImpl{
-               sparkSession: df.sparkSession,
-               relation:     df.relation,
-       }
-       return &writer
+       return df.Writer()
 }
 
-func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bool) error {
+func (df *dataFrameImpl) Writer() DataFrameWriter {
+       return newDataFrameWriter(df.sparkExecutor, df.relation)
+}
+
+func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error {
        plan := &proto.Plan{
                OpType: &proto.Plan_Command{
                        Command: &proto.Command{
@@ -184,12 +214,12 @@ func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bo
                },
        }
 
-       responseClient, err := df.sparkSession.executePlan(plan)
+       responseClient, err := df.sparkExecutor.executePlan(ctx, plan)
        if err != nil {
-               return fmt.Errorf("failed to create temp view %s: %w", viewName, err)
+               return sparkerrors.WithType(fmt.Errorf("failed to create temp view %s: %w", viewName, err), sparkerrors.ExecutionError)
        }
 
-       return consumeExecutePlanClient(responseClient)
+       return responseClient.consumeAll()
 }
 
 func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) {
@@ -272,17 +302,14 @@ func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionEx
                        },
                },
        }
-       return &dataFrameImpl{
-               sparkSession: df.sparkSession,
-               relation:     newRelation,
-       }, nil
+       return newDataFrame(df.sparkExecutor, newRelation), nil
 }
 
-func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch) error {
-       return showArrowBatchData(arrowBatch.Data)
+func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch, collector ResultCollector) error {
+       return showArrowBatchData(arrowBatch.Data, collector)
 }
 
-func showArrowBatchData(data []byte) error {
+func showArrowBatchData(data []byte, collector ResultCollector) error {
        rows, err := readArrowBatchData(data, nil)
        if err != nil {
                return err
@@ -290,9 +317,9 @@ func showArrowBatchData(data []byte) error {
        for _, row := range rows {
                values, err := row.Values()
                if err != nil {
-                       return fmt.Errorf("failed to get values in the row: %w", err)
+                       return sparkerrors.WithType(fmt.Errorf("failed to get values in the row: %w", err), sparkerrors.ReadError)
                }
-               fmt.Println(values...)
+               collector.WriteRow(values)
        }
        return nil
 }
@@ -301,7 +328,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) {
        reader := bytes.NewReader(data)
        arrowReader, err := ipc.NewReader(reader)
        if err != nil {
-               return nil, fmt.Errorf("failed to create arrow reader: %w", err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to create arrow reader: %w", err), sparkerrors.ReadError)
        }
        defer arrowReader.Release()
 
@@ -313,7 +340,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) {
                        if errors.Is(err, io.EOF) {
                                return rows, nil
                        } else {
-                               return nil, fmt.Errorf("failed to read arrow: %w", err)
+                               return nil, sparkerrors.WithType(fmt.Errorf("failed to read arrow: %w", err), sparkerrors.ReadError)
                        }
                }
 
@@ -328,10 +355,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) {
                }
 
                for _, v := range values {
-                       row := &GenericRowWithSchema{
-                               schema: schema,
-                               values: v,
-                       }
+                       row := NewRowWithSchema(v, schema)
                        rows = append(rows, row)
                }
 
@@ -445,14 +469,14 @@ func readArrowRecordColumn(record arrow.Record, columnIndex int, values [][]any)
        return nil
 }
 
-func convertProtoDataTypeToStructType(input *proto.DataType) *StructType {
+func convertProtoDataTypeToStructType(input *proto.DataType) (*StructType, error) {
        dataTypeStruct := input.GetStruct()
        if dataTypeStruct == nil {
-               panic("dataType.GetStruct() is nil")
+               return nil, sparkerrors.WithType(errors.New("dataType.GetStruct() is nil"), sparkerrors.InvalidInputError)
        }
        return &StructType{
                Fields: convertProtoStructFields(dataTypeStruct.Fields),
-       }
+       }, nil
 }
 
 func convertProtoStructFields(input []*proto.DataType_StructField) []StructField {
diff --git a/client/sql/dataframe_test.go b/client/sql/dataframe_test.go
index ac5b162..98675fd 100644
--- a/client/sql/dataframe_test.go
+++ b/client/sql/dataframe_test.go
@@ -18,6 +18,9 @@ package sql
 
 import (
        "bytes"
+       "context"
+       "testing"
+
        "github.com/apache/arrow/go/v12/arrow"
        "github.com/apache/arrow/go/v12/arrow/array"
        "github.com/apache/arrow/go/v12/arrow/decimal128"
@@ -28,7 +31,6 @@ import (
        proto "github.com/apache/spark-connect-go/v1/internal/generated"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
-       "testing"
 )
 
 func TestShowArrowBatchData(t *testing.T) {
@@ -56,8 +58,10 @@ func TestShowArrowBatchData(t *testing.T) {
        err := arrowWriter.Write(record)
        require.Nil(t, err)
 
-       err = showArrowBatchData(buf.Bytes())
+       collector := &testCollector{}
+       err = showArrowBatchData(buf.Bytes(), collector)
        assert.Nil(t, err)
+       assert.Equal(t, []any{"str2"}, collector.row)
 }
 
 func TestReadArrowRecord(t *testing.T) {
@@ -304,3 +308,73 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) {
        }
        assert.Equal(t, "Unsupported", convertProtoDataTypeToDataType(unsupportedDataType).TypeName())
 }
+
+type testCollector struct {
+       row []any
+}
+
+func (t *testCollector) WriteRow(values []any) {
+       t.row = values
+}
+
+func TestWriteResultStreamsArrowResultToCollector(t *testing.T) {
+       ctx := context.Background()
+
+       arrowFields := []arrow.Field{
+               {
+                       Name: "show_string",
+                       Type: &arrow.StringType{},
+               },
+       }
+       arrowSchema := arrow.NewSchema(arrowFields, nil)
+       var buf bytes.Buffer
+       arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema))
+       defer arrowWriter.Close()
+
+       alloc := memory.NewGoAllocator()
+       recordBuilder := array.NewRecordBuilder(alloc, arrowSchema)
+       defer recordBuilder.Release()
+
+       recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b")
+       recordBuilder.Field(0).(*array.StringBuilder).Append("str2")
+
+       record := recordBuilder.NewRecord()
+       defer record.Release()
+
+       err := arrowWriter.Write(record)
+       require.Nil(t, err)
+
+       query := "select * from bla"
+
+       session := &sparkSessionImpl{
+               client: &connectServiceClient{
+                       executePlanClient: &executePlanClient{&protoClient{
+                               recvResponses: []*proto.ExecutePlanResponse{
+                                       {
+                                               ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
+                                                       SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{},
+                                               },
+                                       },
+                                       {
+                                               ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                                                       ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                                                               RowCount: 1,
+                                                               Data:     buf.Bytes(),
+                                                       },
+                                               },
+                                       },
+                               }},
+                       },
+                       t: t,
+               },
+       }
+       resp, err := session.Sql(ctx, query)
+       assert.NoError(t, err)
+       assert.NotNil(t, resp)
+       writer, err := resp.Repartition(1, []string{"1"})
+       assert.NoError(t, err)
+       collector := &testCollector{}
+       err = writer.WriteResult(ctx, collector, 1, false)
+       assert.NoError(t, err)
+       assert.Equal(t, []any{"str2"}, collector.row)
+}
diff --git a/client/sql/dataframereader.go b/client/sql/dataframereader.go
index 2f0e215..59f21ff 100644
--- a/client/sql/dataframereader.go
+++ b/client/sql/dataframereader.go
@@ -14,34 +14,40 @@ type DataFrameReader interface {
 
 // dataFrameReaderImpl is an implementation of DataFrameReader interface.
 type dataFrameReaderImpl struct {
-       sparkSession *sparkSessionImpl
+       sparkSession sparkExecutor
        formatSource string
 }
 
+func newDataframeReader(session sparkExecutor) DataFrameReader {
+       return &dataFrameReaderImpl{
+               sparkSession: session,
+       }
+}
+
 func (w *dataFrameReaderImpl) Format(source string) DataFrameReader {
        w.formatSource = source
        return w
 }
 
 func (w *dataFrameReaderImpl) Load(path string) (DataFrame, error) {
-       var format *string
+       var format string
        if w.formatSource != "" {
-               format = &w.formatSource
+               format = w.formatSource
        }
-       df := &dataFrameImpl{
-               sparkSession: w.sparkSession,
-               relation: &proto.Relation{
-                       RelType: &proto.Relation_Read{
-                               Read: &proto.Read{
-                                       ReadType: &proto.Read_DataSource_{
-                                               DataSource: &proto.Read_DataSource{
-                                                       Format: format,
-                                                       Paths:  []string{path},
-                                               },
+       return newDataFrame(w.sparkSession, toRelation(path, format)), nil
+}
+
+func toRelation(path string, format string) *proto.Relation {
+       return &proto.Relation{
+               RelType: &proto.Relation_Read{
+                       Read: &proto.Read{
+                               ReadType: &proto.Read_DataSource_{
+                                       DataSource: &proto.Read_DataSource{
+                                               Format: &format,
+                                               Paths:  []string{path},
                                        },
                                },
                        },
                },
        }
-       return df, nil
 }
diff --git a/client/sql/dataframereader_test.go b/client/sql/dataframereader_test.go
new file mode 100644
index 0000000..c5e73c0
--- /dev/null
+++ b/client/sql/dataframereader_test.go
@@ -0,0 +1,26 @@
+package sql
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestLoadCreatesADataFrame(t *testing.T) {
+       reader := newDataframeReader(nil)
+       source := "source"
+       path := "path"
+       reader.Format(source)
+       frame, err := reader.Load(path)
+       assert.NoError(t, err)
+       assert.NotNil(t, frame)
+}
+
+func TestRelationContainsPathAndFormat(t *testing.T) {
+       formatSource := "source"
+       path := "path"
+       relation := toRelation(path, formatSource)
+       assert.NotNil(t, relation)
+       assert.Equal(t, &formatSource, relation.GetRead().GetDataSource().Format)
+       assert.Equal(t, path, relation.GetRead().GetDataSource().Paths[0])
+}
diff --git a/client/sql/dataframewriter.go b/client/sql/dataframewriter.go
index a398ebd..dee74e8 100644
--- a/client/sql/dataframewriter.go
+++ b/client/sql/dataframewriter.go
@@ -1,9 +1,12 @@
 package sql
 
 import (
+       "context"
        "fmt"
-       proto "github.com/apache/spark-connect-go/v1/internal/generated"
        "strings"
+
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
+       proto "github.com/apache/spark-connect-go/v1/internal/generated"
 )
 
 // DataFrameWriter supports writing data frame to storage.
@@ -13,15 +16,27 @@ type DataFrameWriter interface {
        // Format specifies data format (data source type) for the underlying data, e.g. parquet.
        Format(source string) DataFrameWriter
        // Save writes data frame to the given path.
-       Save(path string) error
+       Save(ctx context.Context, path string) error
+}
+
+type sparkExecutor interface {
+       executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error)
+       analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error)
+}
+
+func newDataFrameWriter(sparkExecutor sparkExecutor, relation *proto.Relation) DataFrameWriter {
+       return &dataFrameWriterImpl{
+               sparkExecutor: sparkExecutor,
+               relation:      relation,
+       }
 }
 
 // dataFrameWriterImpl is an implementation of DataFrameWriter interface.
 type dataFrameWriterImpl struct {
-       sparkSession *sparkSessionImpl
-       relation     *proto.Relation
-       saveMode     string
-       formatSource string
+       sparkExecutor sparkExecutor
+       relation      *proto.Relation
+       saveMode      string
+       formatSource  string
 }
 
 func (w *dataFrameWriterImpl) Mode(saveMode string) DataFrameWriter {
@@ -34,7 +49,7 @@ func (w *dataFrameWriterImpl) Format(source string) DataFrameWriter {
        return w
 }
 
-func (w *dataFrameWriterImpl) Save(path string) error {
+func (w *dataFrameWriterImpl) Save(ctx context.Context, path string) error {
        saveMode, err := getSaveMode(w.saveMode)
        if err != nil {
                return err
@@ -59,12 +74,12 @@ func (w *dataFrameWriterImpl) Save(path string) error {
                        },
                },
        }
-       responseClient, err := w.sparkSession.executePlan(plan)
+       responseClient, err := w.sparkExecutor.executePlan(ctx, plan)
        if err != nil {
                return err
        }
 
-       return consumeExecutePlanClient(responseClient)
+       return responseClient.consumeAll()
 }
 
 func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) {
@@ -79,6 +94,6 @@ func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) {
        } else if strings.EqualFold(mode, "Ignore") {
                return proto.WriteOperation_SAVE_MODE_IGNORE, nil
        } else {
-               return 0, fmt.Errorf("unsupported save mode: %s", mode)
+               return 0, sparkerrors.WithType(fmt.Errorf("unsupported save mode: %s", mode), sparkerrors.InvalidInputError)
        }
 }
diff --git a/client/sql/dataframewriter_test.go b/client/sql/dataframewriter_test.go
index d61266f..649e013 100644
--- a/client/sql/dataframewriter_test.go
+++ b/client/sql/dataframewriter_test.go
@@ -1,9 +1,12 @@
 package sql
 
 import (
+       "context"
+       "io"
+       "testing"
+
        proto "github.com/apache/spark-connect-go/v1/internal/generated"
        "github.com/stretchr/testify/assert"
-       "testing"
 )
 
 func TestGetSaveMode(t *testing.T) {
@@ -31,3 +34,53 @@ func TestGetSaveMode(t *testing.T) {
        assert.NotNil(t, err)
        assert.Equal(t, proto.WriteOperation_SAVE_MODE_UNSPECIFIED, mode)
 }
+
+func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) {
+       relation := &proto.Relation{}
+       executor := &testExecutor{
+               client: newExecutePlanClient(&protoClient{
+                       err: io.EOF,
+               }),
+       }
+       ctx := context.Background()
+       path := "path"
+
+       writer := newDataFrameWriter(executor, relation)
+       writer.Format("format")
+       writer.Mode("append")
+       err := writer.Save(ctx, path)
+       assert.NoError(t, err)
+}
+
+func TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) {
+       relation := &proto.Relation{}
+       executor := &testExecutor{
+               client: newExecutePlanClient(&protoClient{
+                       err: assert.AnError,
+               }),
+       }
+       ctx := context.Background()
+       path := "path"
+
+       writer := newDataFrameWriter(executor, relation)
+       writer.Format("format")
+       writer.Mode("append")
+       err := writer.Save(ctx, path)
+       assert.Error(t, err)
+}
+
+func TestSaveFailsIfAnotherErrorHappensWhenExecuting(t *testing.T) {
+       relation := &proto.Relation{}
+       executor := &testExecutor{
+               client: newExecutePlanClient(&protoClient{}),
+               err:    assert.AnError,
+       }
+       ctx := context.Background()
+       path := "path"
+
+       writer := newDataFrameWriter(executor, relation)
+       writer.Format("format")
+       writer.Mode("append")
+       err := writer.Save(ctx, path)
+       assert.Error(t, err)
+}
diff --git a/client/sql/datatype.go b/client/sql/datatype.go
index e201114..ecbadf3 100644
--- a/client/sql/datatype.go
+++ b/client/sql/datatype.go
@@ -17,7 +17,7 @@
 package sql
 
 import (
-       "reflect"
+       "fmt"
        "strings"
 )
 
@@ -125,16 +125,7 @@ func (t UnsupportedType) TypeName() string {
 }
 
 func getDataTypeName(dataType DataType) string {
-       t := reflect.TypeOf(dataType)
-       if t == nil {
-               return "(nil)"
-       }
-       var name string
-       if t.Kind() == reflect.Ptr {
-               name = t.Elem().Name()
-       } else {
-               name = t.Name()
-       }
-       name = strings.TrimSuffix(name, "Type")
-       return name
+       typeName := fmt.Sprintf("%T", dataType)
+       nonQualifiedTypeName := strings.Split(typeName, ".")[1]
+       return strings.TrimSuffix(nonQualifiedTypeName, "Type")
 }
diff --git a/client/sql/executeplanclient.go b/client/sql/executeplanclient.go
new file mode 100644
index 0000000..ece0a20
--- /dev/null
+++ b/client/sql/executeplanclient.go
@@ -0,0 +1,37 @@
+package sql
+
+import (
+       "errors"
+       "fmt"
+       "io"
+
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
+       proto "github.com/apache/spark-connect-go/v1/internal/generated"
+)
+
+type executePlanClient struct {
+       proto.SparkConnectService_ExecutePlanClient
+}
+
+func newExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) *executePlanClient {
+       return &executePlanClient{
+               responseClient,
+       }
+}
+
+// consumeAll reads through the returned GRPC stream from Spark Connect Driver. It will
+// discard the returned data if there is no error. This is necessary for handling GRPC response for
+// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame.
+// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame.
+func (c *executePlanClient) consumeAll() error {
+       for {
+               _, err := c.Recv()
+               if err != nil {
+                       if errors.Is(err, io.EOF) {
+                               return nil
+                       } else {
+                               return sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError)
+                       }
+               }
+       }
+}
diff --git a/client/sql/mocks_test.go b/client/sql/mocks_test.go
new file mode 100644
index 0000000..9d87d50
--- /dev/null
+++ b/client/sql/mocks_test.go
@@ -0,0 +1,112 @@
+package sql
+
+import (
+       "context"
+       "testing"
+
+       proto "github.com/apache/spark-connect-go/v1/internal/generated"
+       "github.com/stretchr/testify/assert"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/metadata"
+)
+
+type testExecutor struct {
+       client   *executePlanClient
+       response *proto.AnalyzePlanResponse
+       err      error
+}
+
+func (t *testExecutor) executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error) {
+       if t.err != nil {
+               return nil, t.err
+       }
+       return t.client, nil
+}
+
+func (t *testExecutor) analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
+       return t.response, nil
+}
+
+type protoClient struct {
+       recvResponse  *proto.ExecutePlanResponse
+       recvResponses []*proto.ExecutePlanResponse
+
+       err error
+}
+
+func (p *protoClient) Recv() (*proto.ExecutePlanResponse, error) {
+       if len(p.recvResponses) != 0 {
+               p.recvResponse = p.recvResponses[0]
+               p.recvResponses = p.recvResponses[1:]
+       }
+       return p.recvResponse, p.err
+}
+
+func (p *protoClient) Header() (metadata.MD, error) {
+       return nil, p.err
+}
+
+func (p *protoClient) Trailer() metadata.MD {
+       return nil
+}
+
+func (p *protoClient) CloseSend() error {
+       return p.err
+}
+
+func (p *protoClient) Context() context.Context {
+       return nil
+}
+
+func (p *protoClient) SendMsg(m interface{}) error {
+       return p.err
+}
+
+func (p *protoClient) RecvMsg(m interface{}) error {
+       return p.err
+}
+
+type connectServiceClient struct {
+       t *testing.T
+
+       analysePlanResponse        *proto.AnalyzePlanResponse
+       executePlanClient          proto.SparkConnectService_ExecutePlanClient
+       expectedExecutePlanRequest *proto.ExecutePlanRequest
+
+       err error
+}
+
+func (c *connectServiceClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ExecutePlanClient, error) {
+       if c.expectedExecutePlanRequest != nil {
+               assert.Equal(c.t, c.expectedExecutePlanRequest, in)
+       }
+       return c.executePlanClient, c.err
+}
+
+func (c *connectServiceClient) AnalyzePlan(ctx context.Context, in *proto.AnalyzePlanRequest, opts ...grpc.CallOption) (*proto.AnalyzePlanResponse, error) {
+       return c.analysePlanResponse, c.err
+}
+
+func (c *connectServiceClient) Config(ctx context.Context, in *proto.ConfigRequest, opts ...grpc.CallOption) (*proto.ConfigResponse, error) {
+       return nil, c.err
+}
+
+func (c *connectServiceClient) AddArtifacts(ctx context.Context, opts ...grpc.CallOption) (proto.SparkConnectService_AddArtifactsClient, error) {
+       return nil, c.err
+}
+
+func (c *connectServiceClient) ArtifactStatus(ctx context.Context, in *proto.ArtifactStatusesRequest, opts ...grpc.CallOption) (*proto.ArtifactStatusesResponse, error) {
+       return nil, c.err
+}
+
+func (c *connectServiceClient) Interrupt(ctx context.Context, in *proto.InterruptRequest, opts ...grpc.CallOption) (*proto.InterruptResponse, error) {
+       return nil, c.err
+}
+
+func (c *connectServiceClient) ReattachExecute(ctx context.Context, in *proto.ReattachExecuteRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ReattachExecuteClient, error) {
+       return nil, c.err
+}
+
+func (c *connectServiceClient) ReleaseExecute(ctx context.Context, in *proto.ReleaseExecuteRequest, opts ...grpc.CallOption) (*proto.ReleaseExecuteResponse, error) {
+       return nil, c.err
+}
diff --git a/client/sql/plan_test.go b/client/sql/plan_test.go
new file mode 100644
index 0000000..c733862
--- /dev/null
+++ b/client/sql/plan_test.go
@@ -0,0 +1,13 @@
+package sql
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestNewPlanIdGivesNewIDs(t *testing.T) {
+       id1 := newPlanId()
+       id2 := newPlanId()
+       assert.NotEqual(t, id1, id2)
+}
diff --git a/client/sql/row.go b/client/sql/row.go
index 3bee2ac..bea2ab7 100644
--- a/client/sql/row.go
+++ b/client/sql/row.go
@@ -16,20 +16,31 @@
 
 package sql
 
+// Row represents a row in a DataFrame.
 type Row interface {
+       // Schema returns the schema of the row.
        Schema() (*StructType, error)
+       // Values returns the values of the row.
        Values() ([]any, error)
 }
 
-type GenericRowWithSchema struct {
+// genericRowWithSchema represents a row in a DataFrame with schema.
+type genericRowWithSchema struct {
        values []any
        schema *StructType
 }
 
-func (r *GenericRowWithSchema) Schema() (*StructType, error) {
+func NewRowWithSchema(values []any, schema *StructType) Row {
+       return &genericRowWithSchema{
+               values: values,
+               schema: schema,
+       }
+}
+
+func (r *genericRowWithSchema) Schema() (*StructType, error) {
        return r.schema, nil
 }
 
-func (r *GenericRowWithSchema) Values() ([]any, error) {
+func (r *genericRowWithSchema) Values() ([]any, error) {
        return r.values, nil
 }
diff --git a/client/sql/row_test.go b/client/sql/row_test.go
new file mode 100644
index 0000000..7ae4f97
--- /dev/null
+++ b/client/sql/row_test.go
@@ -0,0 +1,25 @@
+package sql
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestSchema(t *testing.T) {
+       values := []any{1}
+       schema := &StructType{}
+       row := NewRowWithSchema(values, schema)
+       schema2, err := row.Schema()
+       assert.NoError(t, err)
+       assert.Equal(t, schema, schema2)
+}
+
+func TestValues(t *testing.T) {
+       values := []any{1}
+       schema := &StructType{}
+       row := NewRowWithSchema(values, schema)
+       values2, err := row.Values()
+       assert.NoError(t, err)
+       assert.Equal(t, values, values2)
+}
diff --git a/client/sql/sparksession.go b/client/sql/sparksession.go
index a5ac466..e1ecef8 100644
--- a/client/sql/sparksession.go
+++ b/client/sql/sparksession.go
@@ -18,48 +18,46 @@ package sql
 
 import (
        "context"
-       "errors"
        "fmt"
 
        "github.com/apache/spark-connect-go/v1/client/channel"
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
        proto "github.com/apache/spark-connect-go/v1/internal/generated"
        "github.com/google/uuid"
        "google.golang.org/grpc/metadata"
-       "io"
 )
 
-var SparkSession sparkSessionBuilderEntrypoint
-
-type sparkSession interface {
+type SparkSession interface {
        Read() DataFrameReader
-       Sql(query string) (DataFrame, error)
+       Sql(ctx context.Context, query string) (DataFrame, error)
        Stop() error
 }
 
-type sparkSessionBuilderEntrypoint struct {
-       Builder SparkSessionBuilder
+// NewSessionBuilder creates a new session builder for starting a new spark session
+func NewSessionBuilder() *SparkSessionBuilder {
+       return &SparkSessionBuilder{}
 }
 
 type SparkSessionBuilder struct {
        connectionString string
 }
 
-func (s SparkSessionBuilder) Remote(connectionString string) SparkSessionBuilder {
-       copy := s
-       copy.connectionString = connectionString
-       return copy
+// Remote sets the connection string for remote connection
+func (s *SparkSessionBuilder) Remote(connectionString string) *SparkSessionBuilder {
+       s.connectionString = connectionString
+       return s
+       s.connectionString = connectionString
+       return s
 }
 
-func (s SparkSessionBuilder) Build() (sparkSession, error) {
+func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) {
 
        cb, err := channel.NewBuilder(s.connectionString)
        if err != nil {
-               return nil, fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError)
        }
 
-       conn, err := cb.Build()
+       conn, err := cb.Build(ctx)
        if err != nil {
-               return nil, fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError)
        }
 
        // Add metadata to the request.
@@ -83,12 +81,10 @@ type sparkSessionImpl struct {
 }
 
 func (s *sparkSessionImpl) Read() DataFrameReader {
-       return &dataFrameReaderImpl{
-               sparkSession: s,
-       }
+       return newDataframeReader(s)
 }
 
-func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) {
+func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, error) {
        plan := &proto.Plan{
                OpType: &proto.Plan_Command{
                        Command: &proto.Command{
@@ -100,32 +96,28 @@ func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) {
                        },
                },
        }
-       responseClient, err := s.executePlan(plan)
+       responseClient, err := s.executePlan(ctx, plan)
        if err != nil {
-               return nil, fmt.Errorf("failed to execute sql: %s: %w", query, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to execute sql: %s: %w", query, err), sparkerrors.ExecutionError)
        }
        for {
                response, err := responseClient.Recv()
                if err != nil {
-                       return nil, fmt.Errorf("failed to receive ExecutePlan response: %w", err)
+                       return nil, sparkerrors.WithType(fmt.Errorf("failed to receive ExecutePlan response: %w", err), sparkerrors.ReadError)
                }
                sqlCommandResult := response.GetSqlCommandResult()
                if sqlCommandResult == nil {
                        continue
                }
-               return &dataFrameImpl{
-                       sparkSession: s,
-                       relation:     sqlCommandResult.GetRelation(),
-               }, nil
+               return newDataFrame(s, sqlCommandResult.GetRelation()), nil
        }
-       return nil, fmt.Errorf("failed to get SqlCommandResult in ExecutePlan response")
 }
 
 func (s *sparkSessionImpl) Stop() error {
        return nil
 }
 
-func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectService_ExecutePlanClient, error) {
+func (s *sparkSessionImpl) executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error) {
        request := proto.ExecutePlanRequest{
                SessionId: s.sessionId,
                Plan:      plan,
@@ -134,15 +126,15 @@ func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectServ
                },
        }
        // Append the other items to the request.
-       ctx := metadata.NewOutgoingContext(context.Background(), s.metadata)
-       executePlanClient, err := s.client.ExecutePlan(ctx, &request)
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
+       client, err := s.client.ExecutePlan(ctx, &request)
        if err != nil {
-               return nil, fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
        }
-       return executePlanClient, nil
+       return newExecutePlanClient(client), nil
 }
 
-func (s *sparkSessionImpl) analyzePlan(plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
+func (s *sparkSessionImpl) analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) {
        request := proto.AnalyzePlanRequest{
                SessionId: s.sessionId,
                Analyze: &proto.AnalyzePlanRequest_Schema_{
@@ -155,29 +147,11 @@ func (s *sparkSessionImpl) analyzePlan(plan *proto.Plan) (*proto.AnalyzePlanResp
                },
        }
        // Append the other items to the request.
-       ctx := metadata.NewOutgoingContext(context.Background(), s.metadata)
+       ctx = metadata.NewOutgoingContext(ctx, s.metadata)
 
        response, err := s.client.AnalyzePlan(ctx, &request)
        if err != nil {
-               return nil, fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err)
+               return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
        }
        return response, nil
 }
-
-// consumeExecutePlanClient reads through the returned GRPC stream from Spark Connect Driver. It will
-// discard the returned data if there is no error. This is necessary for handling GRPC response for
-// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame.
-// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame.
-func consumeExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) error {
-       for {
-               _, err := responseClient.Recv()
-               if err != nil {
-                       if errors.Is(err, io.EOF) {
-                               return nil
-                       } else {
-                               return fmt.Errorf("failed to receive plan 
execution response: %w", err)
-                       }
-               }
-       }
-       return nil
-}
diff --git a/client/sql/sparksession_test.go b/client/sql/sparksession_test.go
new file mode 100644
index 0000000..0a3b105
--- /dev/null
+++ b/client/sql/sparksession_test.go
@@ -0,0 +1,111 @@
+package sql
+
+import (
+       "context"
+       "testing"
+
+       "github.com/apache/spark-connect-go/v1/client/sparkerrors"
+       proto "github.com/apache/spark-connect-go/v1/internal/generated"
+       "github.com/stretchr/testify/assert"
+)
+
+func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) {
+       ctx := context.Background()
+       reponse := &proto.AnalyzePlanResponse{}
+       session := &sparkSessionImpl{
+               client: &connectServiceClient{
+                       analysePlanResponse: reponse,
+               },
+       }
+       resp, err := session.analyzePlan(ctx, &proto.Plan{})
+       assert.NoError(t, err)
+       assert.NotNil(t, resp)
+}
+
+func TestAnalyzePlanFailsIfClientFails(t *testing.T) {
+       ctx := context.Background()
+       session := &sparkSessionImpl{
+               client: &connectServiceClient{
+                       err: assert.AnError,
+               },
+       }
+       resp, err := session.analyzePlan(ctx, &proto.Plan{})
+       assert.Nil(t, resp)
+       assert.Error(t, err)
+}
+
+func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) {
+       ctx := context.Background()
+
+       plan := &proto.Plan{}
+       request := &proto.ExecutePlanRequest{
+               Plan: plan,
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       session := &sparkSessionImpl{
+               client: &connectServiceClient{
+                       executePlanClient:          &executePlanClient{},
+                       expectedExecutePlanRequest: request,
+                       t:                          t,
+               },
+       }
+       resp, err := session.executePlan(ctx, plan)
+       assert.NoError(t, err)
+       assert.NotNil(t, resp)
+}
+
+func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) {
+       ctx := context.Background()
+
+       query := "select * from bla"
+       plan := &proto.Plan{
+               OpType: &proto.Plan_Command{
+                       Command: &proto.Command{
+                               CommandType: &proto.Command_SqlCommand{
+                                       SqlCommand: &proto.SqlCommand{
+                                               Sql: query,
+                                       },
+                               },
+                       },
+               },
+       }
+       request := &proto.ExecutePlanRequest{
+               Plan: plan,
+               UserContext: &proto.UserContext{
+                       UserId: "na",
+               },
+       }
+       session := &sparkSessionImpl{
+               client: &connectServiceClient{
+                       executePlanClient: &executePlanClient{&protoClient{
+                               recvResponse: &proto.ExecutePlanResponse{
+                                       ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
+                                               SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{},
+                                       },
+                               },
+                       }},
+                       expectedExecutePlanRequest: request,
+                       t:                          t,
+               },
+       }
+       resp, err := session.Sql(ctx, query)
+       assert.NoError(t, err)
+       assert.NotNil(t, resp)
+}
+
+func TestNewSessionBuilderCreatesASession(t *testing.T) {
+       ctx := context.Background()
+       spark, err := NewSessionBuilder().Remote("sc:connection").Build(ctx)
+       assert.NoError(t, err)
+       assert.NotNil(t, spark)
+}
+
+func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) {
+       ctx := context.Background()
+       spark, err := NewSessionBuilder().Remote("invalid").Build(ctx)
+       assert.Error(t, err)
+       assert.ErrorIs(t, err, sparkerrors.InvalidInputError)
+       assert.Nil(t, spark)
+}
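
The tests above also show the intended consumption of the new error hierarchy: errors carry a sparkerrors type, so callers can branch with errors.Is instead of matching message strings. A minimal sketch (the connection string and the handling are assumptions, not part of the patch):

    package main

    import (
        "context"
        "errors"
        "log"

        "github.com/apache/spark-connect-go/v1/client/sparkerrors"
        "github.com/apache/spark-connect-go/v1/client/sql"
    )

    func main() {
        ctx := context.Background()
        spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
        if errors.Is(err, sparkerrors.InvalidInputError) {
            // The connection string failed validation.
            log.Fatalf("malformed connection string: %s", err)
        }
        if err != nil {
            log.Fatalf("Failed: %s", err)
        }
        defer spark.Stop()
    }
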
diff --git a/client/sql/structtype.go b/client/sql/structtype.go
index 2a59c30..fd75236 100644
--- a/client/sql/structtype.go
+++ b/client/sql/structtype.go
@@ -16,12 +16,14 @@
 
 package sql
 
+// StructField represents a field in a StructType.
 type StructField struct {
        Name     string
        DataType DataType
        Nullable bool // default should be true
 }
 
+// StructType represents a struct type.
 type StructType struct {
        TypeName string
        Fields   []StructField
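
The newly documented fields are what df.Schema(ctx) returns; a short sketch of reading them, assuming df and ctx from one of the examples further below:

    schema, err := df.Schema(ctx)
    if err != nil {
        log.Fatalf("Failed: %s", err)
    }
    for _, f := range schema.Fields {
        // Name, DataType and Nullable come straight from the StructField above.
        log.Printf("%s: %s (nullable: %v)", f.Name, f.DataType.TypeName(), f.Nullable)
    }
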
diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go
index f48500b..25e48d9 100644
--- a/cmd/spark-connect-example-raw-grpc-client/main.go
+++ b/cmd/spark-connect-example-raw-grpc-client/main.go
@@ -19,12 +19,13 @@ package main
 import (
        "context"
        "flag"
+       "log"
+       "time"
+
        proto "github.com/apache/spark-connect-go/v1/internal/generated"
        "github.com/google/uuid"
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
-       "log"
-       "time"
 )
 
 var (
@@ -32,13 +33,14 @@ var (
 )
 
 func main() {
+       ctx := context.Background()
        opts := []grpc.DialOption{
                grpc.WithTransportCredentials(insecure.NewCredentials()),
        }
 
-       conn, err := grpc.Dial(*remote, opts...)
+       conn, err := grpc.DialContext(ctx, *remote, opts...)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
        defer conn.Close()
 
@@ -57,7 +59,7 @@ func main() {
        }
        configResponse, err := client.Config(ctx, &configRequest)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("configResponse: %v", configResponse)
diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go
index aeb95b8..f19a199 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -17,6 +17,7 @@
 package main
 
 import (
+       "context"
        "flag"
        "log"
 
@@ -30,40 +31,41 @@ var (
 
 func main() {
        flag.Parse()
-       spark, err := sql.SparkSession.Builder.Remote(*remote).Build()
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
        defer spark.Stop()
 
-       df, err := spark.Sql("select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
+       df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from sql: select 'apple' as word, 123 as count 
union all select 'orange' as word, 456 as count")
-       err = df.Show(100, false)
+       err = df.Show(ctx, 100, false)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       schema, err := df.Schema()
+       schema, err := df.Schema(ctx)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        for _, f := range schema.Fields {
                log.Printf("Field in dataframe schema: %s - %s", f.Name, 
f.DataType.TypeName())
        }
 
-       rows, err := df.Collect()
+       rows, err := df.Collect(ctx)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        schema, err = rows[0].Schema()
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        for _, f := range schema.Fields {
@@ -74,72 +76,78 @@ func main() {
                log.Printf("Row: %v", row)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        df, err = spark.Read().Format("parquet").
                Load("file:///tmp/spark-connect-write-example-output.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from reading parquet")
-       df.Show(100, false)
+       err = df.Show(ctx, 100, false)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
 
-       err = df.CreateTempView("view1", true, false)
+       err = df.CreateTempView(ctx, "view1", true, false)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       df, err = spark.Sql("select count, word from view1 order by count")
+       df, err = spark.Sql(ctx, "select count, word from view1 order by count")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from sql: select count, word from view1 order by 
count")
-       df.Show(100, false)
+       err = df.Show(ctx, 100, false)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
 
        log.Printf("Repartition with one partition")
        df, err = df.Repartition(1, nil)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output-one-partition.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output-one-partition.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("Repartition with two partitions")
        df, err = df.Repartition(2, nil)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output-two-partition.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output-two-partition.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("Repartition with columns")
        df, err = df.Repartition(0, []string{"word", "count"})
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output-repartition-with-column.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output-repartition-with-column.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("Repartition by range with columns")
@@ -150,13 +158,13 @@ func main() {
                },
        })
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output-repartition-by-range-with-column.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output-repartition-by-range-with-column.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 }
diff --git a/go.mod b/go.mod
index 72f627c..ffdb132 100644
--- a/go.mod
+++ b/go.mod
@@ -15,7 +15,7 @@
 
 module github.com/apache/spark-connect-go/v1
 
-go 1.19
+go 1.21
 
 require (
        github.com/apache/arrow/go/v12 v12.0.0
diff --git a/quick-start.md b/quick-start.md
index 8e1b1e1..2cf83a4 100644
--- a/quick-start.md
+++ b/quick-start.md
@@ -18,6 +18,7 @@ Create `main.go` file with following code:
 package main
 
 import (
+       "context"
        "flag"
        "log"
 
@@ -31,56 +32,78 @@ var (
 
 func main() {
        flag.Parse()
-       spark, err := sql.SparkSession.Builder.Remote(*remote).Build()
+       ctx := context.Background()
+       spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
        defer spark.Stop()
 
-       df, err := spark.Sql("select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
+       df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from sql: select 'apple' as word, 123 as count 
union all select 'orange' as word, 456 as count")
-       err = df.Show(100, false)
+       err = df.Show(ctx, 100, false)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       rows, err := df.Collect()
+       schema, err := df.Schema(ctx)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
+       }
+
+       for _, f := range schema.Fields {
+               log.Printf("Field in dataframe schema: %s - %s", f.Name, f.DataType.TypeName())
+       }
+
+       rows, err := df.Collect(ctx)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       schema, err = rows[0].Schema()
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       for _, f := range schema.Fields {
+               log.Printf("Field in row: %s - %s", f.Name, f.DataType.TypeName())
        }
 
        for _, row := range rows {
                log.Printf("Row: %v", row)
        }
 
-       err = df.Write().Mode("overwrite").
+       err = df.Writer().Mode("overwrite").
                Format("parquet").
-               Save("file:///tmp/spark-connect-write-example-output.parquet")
+               Save(ctx, "file:///tmp/spark-connect-write-example-output.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        df, err = spark.Read().Format("parquet").
                Load("file:///tmp/spark-connect-write-example-output.parquet")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from reading parquet")
-       df.Show(100, false)
+       err = df.Show(ctx, 100, false)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
 
-       err = df.CreateTempView("view1", true, false)
+       err = df.CreateTempView(ctx, "view1", true, false)
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
-       df, err = spark.Sql("select count, word from view1 order by count")
+       df, err = spark.Sql(ctx, "select count, word from view1 order by count")
        if err != nil {
-               log.Fatalf("Failed: %s", err.Error())
+               log.Fatalf("Failed: %s", err)
        }
 
        log.Printf("DataFrame from sql: select count, word from view1 order by 
count")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
